feat: TrainableState wrapper for efficient JIT compilation#322
Closed
nicholasjpaterno wants to merge 2 commits into oxiglade:main from
Conversation
- Make RandomState methods public (new, with_seed, next_key, seed)
- Implement Updatable trait for compile_with_state compatibility
- Add Default impl for RandomState
- Add helper methods (from_key, as_array, as_array_mut)

This enables using RandomState with compile_with_state for JIT-compiled random number generation, equivalent to Python's `@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)`.
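The commit above can be sketched in plain Rust. Everything here is an illustrative stand-in, not the actual crate API: the `Updatable` trait shape, the `u64` key (a real implementation would hold an MLX array key), and the splitmix64 step are all assumptions made so the example is self-contained.

```rust
/// Hypothetical stand-in for the Updatable trait: expose mutable
/// references to every piece of state that compile_with_state should
/// capture as an input and restore as an output.
trait Updatable {
    fn updatable_states(&mut self) -> Vec<&mut u64>;
}

/// Toy RandomState: a single splittable PRNG key.
struct RandomState {
    key: u64,
}

impl RandomState {
    fn with_seed(seed: u64) -> Self {
        Self { key: seed }
    }

    /// Advance the internal key and return a fresh subkey.
    /// (splitmix64 here; MLX actually splits an array-valued key.)
    fn next_key(&mut self) -> u64 {
        self.key = self.key.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.key;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }
}

impl Default for RandomState {
    fn default() -> Self {
        Self::with_seed(0)
    }
}

impl Updatable for RandomState {
    // The PRNG key is the only state the compiler needs to thread through.
    fn updatable_states(&mut self) -> Vec<&mut u64> {
        vec![&mut self.key]
    }
}
```

Because `updatable_states` returns the key by mutable reference, a `compile_with_state`-style wrapper can both read the key before a compiled call and write the advanced key back afterward, mirroring the Python `inputs=/outputs=mx.random.state` pattern.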
When using compile_with_state with LoRA or other parameter-efficient fine-tuning methods, the standard Updatable implementation returns ALL model parameters. For a model with 10M parameters but only 500K trainable, this causes MLX's compiler to prune 9.5M unchanged arrays from the output, breaking state tracking.

TrainableState wraps a model and optimizer, implementing Updatable to return only trainable (non-frozen) parameters plus optimizer state. This dramatically reduces the state count, enabling successful JIT compilation.

Key additions:
- TrainableState<M, O> wrapper type
- CompileTrainingExt trait for ergonomic wrapping
- Consistent state ordering via cached trainable keys
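The filtering idea in this commit can be sketched with std-only stand-ins. The `Param`, `Model`, and `Optimizer` types and the `updatable_count` helper are hypothetical, chosen only to show how caching the trainable keys keeps the exposed state small and consistently ordered; they are not the crate's actual API.

```rust
use std::collections::BTreeMap;

/// A parameter: a value plus a frozen flag
/// (e.g. frozen base weights vs trainable LoRA adapters).
struct Param {
    value: f32,
    frozen: bool,
}

struct Model {
    params: BTreeMap<String, Param>,
}

struct Optimizer {
    /// Per-parameter optimizer state (e.g. momentum), keyed by name.
    state: BTreeMap<String, f32>,
}

/// Wrapper that reports only trainable parameters plus optimizer state.
struct TrainableState {
    model: Model,
    optimizer: Optimizer,
    /// Cached key order, so the state list is identical on every call —
    /// a compiled function relies on a stable state layout.
    trainable_keys: Vec<String>,
}

impl TrainableState {
    fn new(model: Model, optimizer: Optimizer) -> Self {
        // Compute the trainable subset once, at wrap time.
        let trainable_keys: Vec<String> = model
            .params
            .iter()
            .filter(|(_, p)| !p.frozen)
            .map(|(k, _)| k.clone())
            .collect();
        Self { model, optimizer, trainable_keys }
    }

    /// Only non-frozen params plus optimizer entries count as state,
    /// so the compiler never sees the frozen base weights.
    fn updatable_count(&self) -> usize {
        self.trainable_keys.len() + self.optimizer.state.len()
    }
}
```

With three frozen base weights and one LoRA adapter, the wrapper reports two state entries (the adapter plus its optimizer slot) instead of four-plus, which is the count reduction the commit message describes.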
Collaborator
@nicholasjpaterno I thought we ended up not needing this? Did the last solution not end up working in the end?
Contributor
Author
@dcvz You're right: #314 fixed the root cause. I forgot to rebase my local fork against those upstream changes. The proper fix is in place, so this is definitely no longer needed. Closing this one. Thanks! I have opened PRs for all other changes in my fork. I'm hoping to deprecate it entirely as I'm gearing up to launch https://github.com/Epistates/pmetal
Contributor
Author
Closing now
Summary

Adds `TrainableState<M, O>`, a wrapper that exposes only trainable (non-frozen) parameters to `compile_with_state`, enabling efficient JIT compilation for LoRA and other parameter-efficient fine-tuning methods.

The problem: The standard `Updatable` implementation returns ALL model parameters. For a model with 10M params but only 500K trainable (LoRA adapters), the MLX compiler captures all 10M arrays, prunes 9.5M unchanged ones from the output, and breaks state tracking due to the count mismatch.

The solution: `TrainableState` wraps model + optimizer, implementing `Updatable` to return only trainable parameters plus optimizer state, reducing the state count by orders of magnitude.

- `TrainableState<M, O>` wrapper type with cached key ordering for compile correctness
- `CompileTrainingExt` trait for ergonomic `model.with_optimizer(opt)` construction

Depends on #321 (public RandomState API).
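The `model.with_optimizer(opt)` ergonomics can be sketched as a blanket extension trait. The names mirror the PR description, but the bodies are illustrative only — a minimal stand-in, not the actual implementation:

```rust
/// Wrapper pairing a model with its optimizer; in the real crate this
/// is what implements Updatable over the trainable subset.
struct TrainableState<M, O> {
    model: M,
    optimizer: O,
}

/// Extension trait giving every model type an ergonomic constructor.
trait CompileTrainingExt: Sized {
    fn with_optimizer<O>(self, optimizer: O) -> TrainableState<Self, O> {
        TrainableState { model: self, optimizer }
    }
}

// Blanket impl: any sized type can be wrapped.
// (The real trait would presumably be bounded by a Model/Updatable trait.)
impl<T: Sized> CompileTrainingExt for T {}
```

The blanket impl is what makes `model.with_optimizer(opt)` read like a built-in method: bringing the trait into scope adds the constructor to any model type without that type opting in.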