Fix ModelParallel OOM issue during weight loading #21723
base: master
Conversation
amitsrivastava78
commented
Oct 7, 2025
- Modified load_own_variables() to use _direct_assign() for sharded variables (see the sketch after this list)
- Prevents loading full weight tensors on single device before distribution
- Resolves RESOURCE_EXHAUSTED errors when loading large models with ModelParallel
- Maintains backward compatibility for non-sharded variables
- Enables loading of models like Gemma2 2B/7B without OOM errors
- Added EinsumDense layer testing to ModelParallel sharded variable loading
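A minimal sketch of the loading flow described in the list above, written as a free function for illustration; the actual method lives on the layer classes and the Keras internals may differ:

def load_own_variables(layer, store):
    """Sketch: load a layer's weights from a checkpoint store, shard-aware."""
    for i, variable in enumerate(layer.weights):
        value = store[str(i)]  # host-side array read from the checkpoint
        if getattr(variable, "_layout", None) is not None:
            # Sharded variable: hand the host value to the backend so it can
            # be distributed shard-by-shard instead of being materialized in
            # full on a single device first.
            variable._direct_assign(value)
        else:
            # Non-sharded variable: keep the existing assignment path.
            variable.assign(value)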
Summary of Changes
Hello @amitsrivastava78, I'm Gemini Code Assist! Here's a summary to help you and other reviewers quickly get up to speed. This pull request addresses a critical Out-Of-Memory (OOM) issue encountered when loading large models using ModelParallel in the JAX backend. The core problem stemmed from loading entire weight tensors onto a single device before distributing them, leading to RESOURCE_EXHAUSTED errors.
Code Review
This pull request provides a crucial fix for an out-of-memory issue when loading sharded models with ModelParallel in the JAX backend. The core change, which involves using a new _direct_assign method to distribute weights before assigning them to device variables, is well-implemented and effectively prevents loading full tensors onto a single device. The addition of _ProtectedShardedArray and strong referencing to prevent premature garbage collection of sharded arrays is a clever solution to a common problem in JAX. The refactoring of sharding logic into a shared _initialize_variable_with_sharding helper improves code clarity and maintainability. The comprehensive end-to-end tests for sharded variable loading across various layer types provide strong confidence in the correctness of the fix. Overall, this is an excellent and thorough contribution that significantly improves the usability of model parallelism in Keras.
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #21723 +/- ##
=======================================
Coverage 82.59% 82.59%
=======================================
Files 572 572
Lines 58535 58543 +8
Branches 9158 9158
=======================================
+ Hits 48345 48353 +8
- Misses 7853 7854 +1
+ Partials 2337 2336 -1
Flags with carried forward coverage won't be shown.
- Fix PyTorch backend CI failures by adding _direct_assign method for proper numpy-to-tensor conversion - Restore JAX export functionality using jax_export.symbolic_shape for dynamic shape handling - Refactor variable loading logic to eliminate duplication between Dense and EinsumDense layers - Create shared utility function get_quantized_variable_load_order in keras/src/utils/variable_loading.py - Update layer implementations to use the shared variable loading utility - All tests passing: PyTorch backend, JAX backend, and layer-specific legacy loading tests
Force-pushed from b0a7824 to 5da9108.
- Improve host memory allocation for sharded variables by preferring JAX arrays over NumPy conversion - Remove unnecessary jax.block_until_ready() calls as JAX automatically blocks when needed - Add comprehensive documentation for memory stability protection and host allocation - Enhance logging for variable initialization and assignment operations - Add support for both NumPy and JAX arrays in variable assignment methods
One overall note: we basically never log in the successful case; the logs are just way too noisy.
Can you create two separate small PRs:
- One for the fix for the initializer. My reading of it is that this is the only change needed:
def _initialize_with_initializer(self, initializer):
    """Initialize variable with initializer, running on CPU if sharding
    is needed."""
    if self._layout is not None:
        # For sharded variables, run initializer on CPU to avoid device
        # placement issues
        with jax.default_device(jax.devices("cpu")[0]):
            value = self._convert_to_tensor(
                initializer(self._shape, dtype=self._dtype)
            )
    else:
        # For non-sharded variables, use the default behavior
        value = self._convert_to_tensor(
            initializer(self._shape, dtype=self._dtype)
        )
    self._initialize(value)
But we should check that it does what we think it does.
- One for the fix for the weight loading. It will have this change, and only this change, for the relevant layers:
for i, variable in enumerate(target_variables):
    variable._direct_assign(store[str(i)])
And jax_memory_cleanup should be addressed differently, which we can talk about.
keras/src/backend/jax/core.py (Outdated)
class _ProtectedShardedArray:
    """Wrapper that prevents deletion of sharded JAX arrays.

    This wrapper intercepts delete() calls from jax_memory_cleanup
Ok, so the problem stems from jax_memory_cleanup. And all this very complex logic (many hundreds of lines of code) does is undo what jax_memory_cleanup is doing. So the obvious choice is to remove jax_memory_cleanup (4 lines of code) instead. The way jax_memory_cleanup works is intrinsically dangerous and we should find a better way to do it.
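For context, jax_memory_cleanup is a small helper on the KerasHub side that frees build-time buffers before weights are loaded. A rough sketch of that style of cleanup, with names and attribute access as assumptions rather than the exact KerasHub code:

import jax


def jax_memory_cleanup_sketch(layer):
    """Aggressively free the build-time buffers of a layer's variables.

    Sketch only: anything still referencing a deleted buffer breaks on its
    next read, which is what makes this style of cleanup dangerous.
    """
    for variable in layer.weights:
        value = getattr(variable, "_value", None)
        if isinstance(value, jax.Array):
            # Release the device memory immediately instead of waiting for GC.
            value.delete()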
Actually you are right, let me explain what is happening:
- During model build time, the JAX arrays are created.
- During model loading time, since JAX arrays are immutable, they get duplicated: the loaded value becomes a new array alongside the build-time one.
- The idea was that eventually GC would run and delete the old reference.
This is why jax_memory_cleanup was introduced: to aggressively delete memory between model build time and weight loading time, so that the memory duplication does not happen.
With the ModelParallel approach we have implemented, we do not need this aggressive approach, because the rolling, variable-by-variable loading does not put as much strain on memory.
So yes, I have reverted all these changes, and in KerasHub I will call jax_memory_cleanup conditionally, i.e. it will not be called if ModelParallel is active. Let me know what you think.
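A minimal sketch of that conditional call on the KerasHub side, assuming the active distribution is queried via keras.distribution.distribution(); the wrapper name is illustrative:

import keras


def maybe_jax_memory_cleanup(layer):
    """Run the aggressive cleanup only when ModelParallel is not active."""
    dist = keras.distribution.distribution()
    if isinstance(dist, keras.distribution.ModelParallel):
        # Sharded, variable-by-variable loading keeps peak memory low, so the
        # aggressive build-time buffer deletion is unnecessary here.
        return
    jax_memory_cleanup(layer)  # the existing KerasHub helper discussed above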
- Remove _ProtectedShardedArray class and _maybe_create_strong_reference method from core.py - Remove jax.block_until_ready calls that are no longer needed - Simplify variable initialization and assignment logic - Remove all test cases related to reference holding from core_test.py - Tests now pass and are consistent with the simplified implementation
Force-pushed from ec69373 to 92bf1ed.
Force-pushed from 37f7475 to 92bf1ed.
…1713 - Remove variable_loading.py (quantization/saving related) - Fix duplicate import in core_test.py - Revert layer files to remove quantization changes - Keep only core JAX memory management changes for OOM fix
- Remove get_quantized_variable_load_order imports from dense.py and einsum_dense.py - Replace function calls with inline variable ordering logic - Maintain compatibility with quantization loading
…n usage - Remove quantization-specific variable ordering in _legacy_load_own_variables - Keep _direct_assign usage for OOM prevention during sharded variable loading - Maintain compatibility with quantization_variable_spec
- Change dense.py and einsum_dense.py _legacy_load_own_variables to use _direct_assign - Maintains OOM prevention for ModelParallel while ensuring consistency across all layers - All layers now use _direct_assign for variable loading
This looks good to me. But what about base_conv, embedding? Don't they need this change too?
https://github.com/search?q=repo%3Akeras-team%2Fkeras%20variable.assign(store%5Bstr(i)%5D)&type=code
Oh and what about base_optimizer?