feat: TrainableState wrapper for efficient JIT compilation#322

Closed
nicholasjpaterno wants to merge 2 commits into oxiglade:main from nicholasjpaterno:feat/trainable-state

Conversation

@nicholasjpaterno
Contributor

Summary

Adds TrainableState<M, O> — a wrapper that exposes only trainable (non-frozen) parameters to compile_with_state, enabling efficient JIT compilation for LoRA and other parameter-efficient fine-tuning methods.

The problem: The standard Updatable implementation returns ALL model parameters. For a model with 10M parameters but only 500K trainable (LoRA adapters), the MLX compiler captures all 10M arrays, prunes the 9.5M unchanged ones from its outputs, and then breaks state tracking because the input and output state counts no longer match.

The solution: TrainableState wraps model + optimizer, implementing Updatable to return only trainable parameters plus optimizer state — reducing state count by orders of magnitude.

  • TrainableState<M, O> wrapper type with cached key ordering for compile correctness
  • CompileTrainingExt trait for ergonomic model.with_optimizer(opt) construction
  • Unit tests for creation and frozen parameter handling
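The shape of the wrapper can be sketched in plain Rust. This is a toy illustration of the idea, not the mlx-rs API: `Array` is a stand-in `f32`, the `Updatable` trait and `TrainableState` name mirror the PR, and the `Model`/`Optimizer` types are hypothetical.

```rust
use std::collections::HashMap;

// Stand-in for an MLX array, so the sketch is self-contained.
type Array = f32;

trait Updatable {
    // Flat, consistently ordered list of state arrays that a
    // compile_with_state-style API would capture and thread through.
    fn updatable_states(&self) -> Vec<Array>;
}

struct Model {
    params: HashMap<String, Array>,
    frozen: Vec<String>, // e.g. base weights under LoRA
}

impl Model {
    fn trainable_keys(&self) -> Vec<String> {
        let mut keys: Vec<String> = self
            .params
            .keys()
            .filter(|k| !self.frozen.contains(*k))
            .cloned()
            .collect();
        keys.sort(); // deterministic ordering
        keys
    }
}

struct Optimizer {
    state: Vec<Array>,
}

struct TrainableState {
    model: Model,
    optimizer: Optimizer,
    // Key order is cached at construction so every call returns
    // arrays in the same order the compiled function was traced with.
    cached_keys: Vec<String>,
}

impl TrainableState {
    fn new(model: Model, optimizer: Optimizer) -> Self {
        let cached_keys = model.trainable_keys();
        Self { model, optimizer, cached_keys }
    }
}

impl Updatable for TrainableState {
    fn updatable_states(&self) -> Vec<Array> {
        let mut out: Vec<Array> = self
            .cached_keys
            .iter()
            .map(|k| self.model.params[k])
            .collect();
        out.extend(self.optimizer.state.iter().copied());
        out
    }
}

fn main() {
    let mut params = HashMap::new();
    params.insert("base.weight".to_string(), 1.0);
    params.insert("lora.a".to_string(), 2.0);
    params.insert("lora.b".to_string(), 3.0);
    let model = Model { params, frozen: vec!["base.weight".to_string()] };
    let opt = Optimizer { state: vec![0.0, 0.0] };
    let ts = TrainableState::new(model, opt);
    // Only the two LoRA params plus two optimizer slots are exposed;
    // the frozen base weight never enters the compiled state list.
    assert_eq!(ts.updatable_states().len(), 4);
    println!("state count: {}", ts.updatable_states().len());
}
```

The point of the cached key list is that `updatable_states` must return arrays in an identical order on every call; rebuilding the key set from the HashMap each time would not guarantee that.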

Depends on #321 (public RandomState API).

- Make RandomState methods public (new, with_seed, next_key, seed)
- Implement Updatable trait for compile_with_state compatibility
- Add Default impl for RandomState
- Add helper methods (from_key, as_array, as_array_mut)

This enables using RandomState with compile_with_state for JIT-compiled
random number generation, equivalent to Python's:
  @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
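A rough sketch of the RandomState surface the commit describes (`new`, `with_seed`, `next_key`, `seed`, plus `Default`), using a `u64` counter as a stand-in so the example is self-contained; the real type wraps an MLX key array, and the splitmix64-style step below is purely illustrative, not MLX's actual key scheme.

```rust
#[derive(Default)]
struct RandomState {
    key: u64,
}

impl RandomState {
    fn new() -> Self {
        Self::with_seed(0)
    }

    fn with_seed(seed: u64) -> Self {
        Self { key: seed }
    }

    // Advance the internal state and hand back a fresh subkey, so a
    // compiled function can thread RNG state in and out explicitly
    // instead of relying on hidden global state.
    fn next_key(&mut self) -> u64 {
        // splitmix64-style mixing (illustrative only)
        self.key = self.key.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.key;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    fn seed(&mut self, seed: u64) {
        self.key = seed;
    }
}

fn main() {
    let mut rs = RandomState::with_seed(42);
    let k1 = rs.next_key();
    let k2 = rs.next_key();
    assert_ne!(k1, k2); // state advances between calls

    // Reseeding reproduces the same key stream.
    rs.seed(42);
    assert_eq!(rs.next_key(), k1);
    println!("ok");
}
```

Mutable explicit state is what lets this plug into the same `Updatable` machinery as model parameters, mirroring how the Python decorator lists `mx.random.state` as both an input and an output of the compiled function.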

When using compile_with_state with LoRA or other parameter-efficient
fine-tuning methods, the standard Updatable implementation returns ALL
model parameters. For a model with 10M parameters but only 500K trainable,
this causes MLX's compiler to prune 9.5M unchanged arrays from output,
breaking state tracking.

TrainableState wraps a model and optimizer, implementing Updatable to
return only trainable (non-frozen) parameters plus optimizer state.
This dramatically reduces state count, enabling successful JIT compilation.

Key additions:
- TrainableState<M, O> wrapper type
- CompileTrainingExt trait for ergonomic wrapping
- Consistent state ordering via cached trainable keys
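The `CompileTrainingExt` ergonomics can be sketched as a blanket extension trait. The names follow the PR; the `Model` and `Optimizer` types here are hypothetical stand-ins, and the real trait would presumably be bounded to actual model types rather than implemented for every `T`.

```rust
struct Model;

struct Optimizer {
    lr: f32,
}

struct TrainableState<M, O> {
    model: M,
    optimizer: O,
}

trait CompileTrainingExt: Sized {
    // Ergonomic constructor: model.with_optimizer(opt) instead of
    // TrainableState::new(model, opt).
    fn with_optimizer<O>(self, optimizer: O) -> TrainableState<Self, O> {
        TrainableState { model: self, optimizer }
    }
}

// Blanket impl so any owned type gains .with_optimizer(...).
impl<T> CompileTrainingExt for T {}

fn main() {
    let state = Model.with_optimizer(Optimizer { lr: 0.1 });
    assert_eq!(state.optimizer.lr, 0.1);
    let _ = &state.model; // model is moved into the wrapper
    println!("wrapped");
}
```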
@dcvz
Collaborator

dcvz commented Feb 20, 2026

@nicholasjpaterno I thought we ended up not needing this? Did the last solution not end up working in the end?

@nicholasjpaterno
Contributor Author

@dcvz You're right — #314 fixed the root cause. I forgot to rebase my local fork against those upstream changes. The proper fix is in place, so this is definitely no longer needed. Closing this one. Thanks! I have opened PRs for all other changes in my fork; I'm hoping to deprecate it entirely as I gear up to launch https://github.com/Epistates/pmetal

@nicholasjpaterno
Contributor Author

Closing now
