Add Autoconfig, Coordinated_Optimizer and Sharding keras implementations for Tensor Parallel Autosharding #21707
Conversation
Summary of Changes
Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly enhances Keras's capabilities for large-scale model training by introducing foundational support for tensor parallelism autosharding. It provides mechanisms to automatically determine how model layers should be split across multiple devices, and a specialized optimizer to manage the distributed training process, including sharding optimizer states and synchronizing gradients. This enables users to train models that exceed the memory capacity of a single accelerator, making distributed training more accessible and efficient within the Keras ecosystem.
Code Review
This pull request introduces significant new functionality for tensor parallelism autosharding in Keras, including modules for automatic configuration and a coordinated optimizer. The implementation is well-structured, with new logic for analyzing models, generating sharding plans, and synchronizing gradients. However, I've identified a few issues that need attention:
- Critical: the CoordinatedOptimizer calls a method for applying gradients with sharded states that is never defined.
- High: the logic for matching optimizer states and for gathering sharded parameters is incorrect, which could lead to runtime errors or incorrect behavior (see the sketch below for what gathering involves).
- Medium: code-clarity issues, such as unused parameters.
- The accompanying tests are a good start, but they do not cover the code path containing the critical bug.
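For context, "gathering" a sharded parameter means reassembling the full tensor from its per-device shards, normally by concatenating them along the axis they were split on; getting the axis or shard order wrong silently produces a malformed weight. The snippet below is a minimal, hypothetical sketch of that operation (the function name and signature are illustrative, not the PR's API):

```python
from keras import ops

def gather_sharded_parameter(shards, split_axis):
    """Reassemble a full parameter from its per-device shards by
    concatenating along the axis it was split on (illustrative)."""
    return ops.concatenate(shards, axis=split_axis)

# A (128, 512) kernel split column-wise across 4 devices is rebuilt
# from four (128, 128) shards; a wrong axis would yield (512, 128).
full = gather_sharded_parameter(
    [ops.zeros((128, 128)) for _ in range(4)], split_axis=1
)
assert tuple(full.shape) == (128, 512)
```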
This PR introduces support for tensor parallelism autosharding in Keras, enabling users to shard large model layers across multiple devices. This is a crucial feature for training models that are too large to fit into the memory of a single accelerator.
The implementation is centered around two new components:
autoconfig.py: This module contains the logic to analyze a Keras model, identify sharding candidates (e.g., Dense, EinsumDense layers), and generate a sharding plan.
coordinated_optimizer.py: This is an optimizer wrapper that consumes the sharding plan. During training, it intercepts gradients for sharded variables and performs a collective AllReduce to ensure weight updates are correctly synchronized across all devices. A rough sketch of both components is given below.
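To make the division of responsibilities concrete, here is a minimal, self-contained sketch of both ideas. It is not the API added by this PR: the helper names (build_sharding_plan, allreduce_mean), the plan format, and the column-wise splitting heuristic are assumptions made purely for illustration.

```python
import keras
from keras import layers, ops

def build_sharding_plan(model, world_size):
    """Illustrative stand-in for autoconfig.py: walk the model, flag
    Dense/EinsumDense kernels as sharding candidates, and record the
    axis along which each kernel would be split across devices."""
    plan = {}
    for layer in model.layers:
        if isinstance(layer, (layers.Dense, layers.EinsumDense)):
            kernel = layer.kernel
            # Only shard kernels whose output dimension divides evenly
            # across the available devices.
            if kernel.shape[-1] % world_size == 0:
                plan[f"{layer.name}/kernel"] = {
                    "split_axis": len(kernel.shape) - 1,
                    "num_shards": world_size,
                }
    return plan

def allreduce_mean(per_device_grads):
    """Illustrative stand-in for the collective step performed by
    coordinated_optimizer.py: average one variable's gradient across
    all device shards before the weight update is applied."""
    return ops.mean(ops.stack(per_device_grads, axis=0), axis=0)

if __name__ == "__main__":
    model = keras.Sequential([
        keras.Input(shape=(128,)),
        layers.Dense(512, name="up_proj"),
        layers.Dense(128, name="down_proj"),
    ])
    print(build_sharding_plan(model, world_size=4))
```

In this sketch the plan is derived once from the model structure, while the simple in-process averaging stands in for the collective AllReduce that the coordinated optimizer performs with the backend's communication primitives during gradient application.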