Enable Automatic Tensor Parallelism #21726
Conversation
Summary of Changes
Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This PR introduces automatic tensor parallelism to Keras, allowing models to be sharded across multiple devices for training larger models. It provides a high-level `AutoTPDistribution` API for sharding an existing model.
Code Review
This is a great PR that introduces a powerful and much-needed feature for automatic tensor parallelism in Keras. The high-level API via `AutoTPDistribution` and the core engine `TensorParallelKeras` are well-designed, making large-scale model training more accessible. The use of a unified Functional model to encapsulate the parallel logic is particularly clever and should simplify execution and JIT compilation.
I've identified a few areas for improvement, mainly concerning API clarity, robustness, and consistency with Keras design principles. My comments focus on improving docstrings, handling edge cases more gracefully, and ensuring the code is as clear and maintainable as possible. I've also pointed out a few potential bugs and inconsistencies.
Overall, this is a fantastic contribution. Addressing these points will help ensure the new API is robust, intuitive, and easy for users to adopt.
"""Return all the available devices based on the device type. | ||
Note: in a distributed setting, global devices are returned. | ||
Args: | ||
device_type: string, one of `"cpu"`, `"gpu"` or `"tpu"`. | ||
Defaults to `"gpu"` or `"tpu"` if available when | ||
`device_type` is not provided. Otherwise | ||
will return the `"cpu"` devices. | ||
Return: | ||
List of devices that are available for distribute computation. | ||
""" |
The docstring for `get_best_devices` is incorrect and misleading. It describes a `device_type` argument, but the function's signature is `get_best_devices(count)`. This will cause significant confusion for users trying to use this function.
This is a critical documentation issue that violates the Keras API design guidelines on documentation clarity (lines 135-137, 144-145).
Please update the docstring to accurately reflect the `count` parameter and the function's behavior.
"""Return the first `count` available devices.
Note: in a distributed setting, global devices are returned. This function
does not guarantee that the returned devices are the 'best' in terms of
performance, only that they are available.
Args:
count: The number of devices to return.
Return:
List of device identifier strings.
"""
```python
selected_devices = all_devices[:world_size]
# ...
recommended_backend = "jax"
```
The `backend` parameter is accepted by the function but its value is never used; `recommended_backend` is always hardcoded to `"jax"`. This is a bug, as the function does not behave as its signature suggests.
To fix this, you should use the provided `backend` parameter, falling back to `"jax"` if it's not provided:
```python
recommended_backend = backend or "jax"
```
```python
else:
    return "cpu"
```
The `else` case here silently converts any non-string/non-int `device_spec` to `"cpu"`. This can hide bugs. For example, if a `jax.Device` object is passed by mistake, it will be incorrectly treated as a CPU device.
It would be more robust to raise a `TypeError` for unsupported types to make such errors explicit:
```python
else:
    raise TypeError(
        f"Expected `device_spec` to be a str or int, but got "
        f"{type(device_spec)}"
    )
```
```python
def distribute_dataset(self, dataset):
    if distribution_lib.num_processes() <= 1 or not self.auto_shard_dataset:
        return dataset
    from keras.src.utils.module_utils import tensorflow as tf

    if not tf.available or not isinstance(dataset, tf.data.Dataset):
        raise ValueError(
            "Only `tf.data.Dataset` is supported for auto-sharding."
        )
    return dataset.with_options(tf.data.Options())
```
The implementation of `distribute_dataset` for `AutoTPDistribution` seems incorrect or incomplete. Calling `dataset.with_options(tf.data.Options())` with empty options is a no-op and does not actually shard the dataset across processes.
This is problematic in a multi-process setting, as each process would get the full dataset, leading to redundant computation and incorrect results.
The implementation should handle dataset sharding, similar to how `DataParallel.distribute_dataset` does. It needs to consider the number of processes and the data parallelism dimension of the device mesh to correctly shard the data.
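A minimal sketch of per-process sharding only, not the full `DataParallel` logic; it assumes `distribution_lib.process_id()` is available alongside `num_processes()` and ignores the device-mesh data-parallel dimension mentioned above:
```python
from keras.src.distribution import distribution_lib
from keras.src.utils.module_utils import tensorflow as tf


def _shard_dataset_per_process(dataset):
    """Sketch: give each process a disjoint slice of the dataset."""
    num_processes = distribution_lib.num_processes()
    if num_processes <= 1:
        return dataset
    if not tf.available or not isinstance(dataset, tf.data.Dataset):
        raise ValueError(
            "Only `tf.data.Dataset` is supported for auto-sharding."
        )
    # `tf.data.Dataset.shard` keeps every `num_processes`-th element,
    # starting at this process' index, so processes see disjoint data.
    return dataset.shard(
        num_shards=num_processes, index=distribution_lib.process_id()
    )
```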
```python
        A `TensorParallelKeras` model instance ready for distributed
        training.
    """
    print(f"INFO: Sharding model `{model.name}` for Tensor Parallelism...")
```
This method uses `print()` for logging informational messages (here and on line 649). It's better to use the `logging` module for consistency with other parts of the codebase (like `keras/src/backend/jax/distribution_lib.py`) and to allow users more control over log verbosity.
Please replace `print()` with `logging.info()`. You will need to ensure `logging` is imported and a logger is configured at the module level.
```python
import logging

logging.info(f"Sharding model `{model.name}` for Tensor Parallelism...")
```
```python
params_per_shard = []
for i, shard in enumerate(self.model_shards):
    total_params = 0
    for p in shard.weights:
        if hasattr(p, "num_elements"):
            total_params += p.num_elements()
        elif hasattr(p, "numel"):
            total_params += p.numel()
        elif hasattr(p.shape, "num_elements"):
            total_params += p.shape.num_elements()
        else:
            total_params += np.prod(p.shape)

    params_per_shard.append(int(total_params))
```
```python
from keras.src.distribution import distributed_backend

self.distributed_backend = distributed_backend
```
Here, `self.distributed_backend` (which was a string from the constructor argument) is being reassigned to the imported `distributed_backend` module. The original string value is stored in `self.distributed_backend_name`. This is confusing because the type of `self.distributed_backend` changes, and its name no longer reflects that it's a module.
To improve clarity and avoid type confusion, consider renaming the attribute that holds the module, for example:
```python
from keras.src.distribution import distributed_backend as dist_backend_module

self.distributed_backend_module = dist_backend_module
```
```python
def variables(self):
    """Returns a unique list of all variables from all model shards."""
    unique_vars = {}
    for shard in self.model_shards:
        for var in shard.variables:
            if id(var) not in unique_vars:
                unique_vars[id(var)] = var
    return list(unique_vars.values())
```
This property re-calculates the list of unique variables every time it's accessed. This is inefficient, especially if accessed frequently in a loop. The same applies to the other properties like `trainable_variables`, `weights`, etc.
Since the model shards are created once in `__init__`, you can cache the result of this computation to improve performance:
```python
@property
def variables(self):
    """Returns a unique list of all variables from all model shards."""
    if hasattr(self, "_variables"):
        return self._variables
    unique_vars = {}
    for shard in self.model_shards:
        for var in shard.variables:
            if id(var) not in unique_vars:
                unique_vars[id(var)] = var
    self._variables = list(unique_vars.values())
    return self._variables
```
```python
device_type, device_index = device_id.split(":")
device_info["type"] = device_type.upper()
device_info["index"] = int(device_index)
```
The use of `device_id.split(":")` is not robust and will raise a `ValueError` if `device_id` does not contain a colon (e.g., for a device like `"cpu"`). This could lead to unexpected crashes if such a device ID is passed.
To make this function more robust, you should handle cases where the device ID might not have an index. A simple way is to check for the presence of `":"` before splitting:
```python
parts = device_id.split(":")
device_info["type"] = parts[0].upper()
device_info["index"] = int(parts[1]) if len(parts) > 1 else 0
```
```python
# 3. ASSERT
# For a dimension of 17 split across 3 devices, the expected
# sharded shapes are (6, 5, 5).
```
The comment here states that the expected sharded shapes are `(6, 5, 5)`. However, for a dimension of 17 split across 3 devices, the split is `(6, 6, 5)`. The test assertions correctly check for this `(6, 6, 5)` split. The comment is incorrect and misleading.
Please update the comment to match the actual implementation and assertion:
```python
# sharded shapes are (6, 6, 5).
```
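For reference, `(6, 6, 5)` is the standard as-even-as-possible partition of 17 across 3 shards, e.g. as NumPy computes it:
```python
import numpy as np

# 17 elements over 3 shards: the first 17 % 3 = 2 shards get one extra element.
shard_sizes = [len(s) for s in np.array_split(np.arange(17), 3)]
print(shard_sizes)  # [6, 6, 5]
```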
This pull request introduces a comprehensive framework for automatic tensor parallelism in Keras, enabling users to train models that are too large to fit on a single accelerator. The core of this feature is a new distribution strategy, AutoTPDistribution, which provides a simple, high-level API to shard an existing Keras model across multiple devices.
Description
This framework is designed to automate the complex process of model sharding and inter-device communication, making large-scale model training more accessible. The implementation is broken down into several key components:
A new distribution strategy, `AutoTPDistribution`, is introduced in keras/src/distribution/distribution_lib.py. This class serves as the primary user entry point. The workflow is straightforward: construct an `AutoTPDistribution`, call `distribution.shard(model)` to obtain a sharded model, then compile and train it as usual.
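A minimal sketch of that workflow, based on this description (the constructor arguments and the toy model are illustrative assumptions, not confirmed API details):
```python
import keras
from keras.src.distribution.distribution_lib import AutoTPDistribution

# Sketch only: how AutoTPDistribution is configured (e.g. device lists or
# a mesh) is not spelled out here, so the default constructor is assumed.
distribution = AutoTPDistribution()

model = keras.Sequential(
    [
        keras.layers.Dense(4096, activation="relu"),
        keras.layers.Dense(10),
    ]
)

# `shard()` wraps the model in a TensorParallelKeras instance whose
# weights are partitioned across the available devices.
sharded_model = distribution.shard(model)

# compile() is overridden to wrap the optimizer in a TensorParallelOptimizer
# (per the PR description); training then proceeds as usual.
sharded_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
```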
A new `TensorParallelKeras` model class (keras/src/distribution/tensor_parallel/tensor_parallel.py) acts as the core engine. When distribution.shard() is called, this class wraps the original model and performs the following actions:
Auto-Configures Hardware: Discovers and assigns available devices (TPU, GPU, or CPU).
Shards Parameters: It analyzes the model's layers and applies column-parallel or row-parallel sharding strategies to the weights and biases of relevant layers (e.g., Dense); a small numeric illustration follows this list.
Builds a Unified Graph: It creates a single, assembled Keras Functional model that internally manages the parallel computation. This clever design encapsulates the communication logic (e.g., AllGather, ReduceScatter) within the model's call graph, simplifying the execution and enabling JIT compilation. Partial outputs from each device shard are correctly combined (e.g., concatenation for column-parallel, summation for row-parallel).
Coordinates Gradients: It overrides the compile method to wrap the user's optimizer in a TensorParallelOptimizer, which handles the synchronized computation and application of gradients across all shards.
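To make the column- vs row-parallel split and the output combination concrete, here is a small NumPy illustration of the underlying math (not the PR's actual code): a Dense kernel split column-wise yields output slices that are concatenated, while a row-wise split yields partial sums that are added, which corresponds to the AllGather/ReduceScatter communication in the assembled model.
```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8))  # batch of activations
w = rng.normal(size=(8, 6))  # Dense kernel

# Column-parallel: split the kernel along its output dimension.
# Each shard produces a slice of the output; slices are concatenated.
w_cols = np.split(w, 2, axis=1)
y_col = np.concatenate([x @ part for part in w_cols], axis=1)

# Row-parallel: split the kernel along its input dimension (and the
# activations correspondingly). Each shard produces a partial result;
# the partial results are summed.
w_rows = np.split(w, 2, axis=0)
x_parts = np.split(x, 2, axis=1)
y_row = sum(xp @ wp for xp, wp in zip(x_parts, w_rows))

# Both strategies reconstruct the unsharded Dense output.
assert np.allclose(y_col, x @ w)
assert np.allclose(y_row, x @ w)
```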