Enable Automatic Tensor Parallelism #21726
@@ -1,4 +1,14 @@
-"""Utilities for distribution strategy with JAX backend."""
+"""Utilities for distribution strategy with JAX backend.
+
+This file contains the core JAX distribution primitives from Keras,
+along with higher-level device management and auto-configuration utilities.
+This version does not use try-except blocks for error handling.
+"""
+
+import logging
+from typing import Dict
+from typing import List
+from typing import Optional
+
 import jax
 import numpy as np

@@ -8,6 +18,8 @@
 from keras.src.utils import jax_utils
 from keras.src.utils import rng_utils

+logger = logging.getLogger(__name__)
+

 def list_devices(device_type=None):
     """Return all the available devices based on the device type.

@@ -27,6 +39,153 @@ def list_devices(device_type=None):
     return [f"{device.platform}:{device.id}" for device in jax_devices]

+
+def get_device_info(device_id: str) -> Dict[str, any]:
+    """
+    Get detailed information about a specific device.
+
+    Args:
+        device_id: Device identifier (e.g., 'gpu:0', 'tpu:0', 'cpu:0')
+
+    Returns:
+        Dictionary containing device information
+    """
+    device_info = {
+        "id": device_id,
+        "type": None,
+        "index": None,
+        "memory": None,
+        "capabilities": None,
+    }
+
+    device_type, device_index = device_id.split(":")
+    device_info["type"] = device_type.upper()
+    device_info["index"] = int(device_index)
Comment on lines +60 to +62:

The use of `device_id.split(":")` assumes the identifier always contains a colon-separated index and will fail on an ID such as `"cpu"`. To make this function more robust, you should handle cases where the device ID might not have an index. A simple way is to check for the presence of `":"` before splitting.
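The reviewer's suggested replacement was not captured above. Purely as an illustration of the kind of guard described — not the reviewer's actual suggestion — a parse that tolerates an ID without an index (assuming such IDs should default to index 0) could look like:

```python
# Illustrative sketch only, not the reviewer's captured suggestion.
# Handles IDs like "gpu:0" as well as bare IDs like "cpu".
if ":" in device_id:
    device_type, device_index = device_id.split(":", 1)
    device_info["index"] = int(device_index)
else:
    device_type = device_id
    device_info["index"] = 0  # assumed default when no index is given
device_info["type"] = device_type.upper()
```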
+
+    return device_info
+
+
+def get_best_devices(count: int = 1) -> List[str]:
Review comment:

The function name `get_best_devices` is misleading: the implementation simply returns the first `count` available devices, with no notion of which devices are "best". This violates the Keras API design guideline on naming (lines 68-69), which states that the meaning of an API element should be clear from its name. A developer might incorrectly assume this function performs some sort of device capability analysis. Consider renaming this function to something more descriptive of its actual behavior.
""" | ||||||||||||||
Get the best available devices for tensor parallelism. | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
count: Number of devices needed | ||||||||||||||
|
||||||||||||||
Returns: | ||||||||||||||
List of best device identifiers | ||||||||||||||
""" | ||||||||||||||
all_devices = list_devices() | ||||||||||||||
|
||||||||||||||
if count <= 0: | ||||||||||||||
return [] | ||||||||||||||
|
||||||||||||||
if count > len(all_devices): | ||||||||||||||
logger.warning( | ||||||||||||||
f"Requested {count} devices but only {len(all_devices)} available" | ||||||||||||||
) | ||||||||||||||
count = len(all_devices) | ||||||||||||||
|
||||||||||||||
return all_devices[:count] | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def get_device_backend(device_type: str) -> str: | ||||||||||||||
""" | ||||||||||||||
Get the recommended backend for a device type. | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
device_type: Device type ('tpu', 'gpu', 'cpu') | ||||||||||||||
|
||||||||||||||
Returns: | ||||||||||||||
Recommended backend name | ||||||||||||||
""" | ||||||||||||||
backend_mapping = {"tpu": "jax", "gpu": "jax", "cpu": "jax"} | ||||||||||||||
|
||||||||||||||
return backend_mapping.get(device_type.lower(), "jax") | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def validate_device_placement(device_id: str) -> bool: | ||||||||||||||
""" | ||||||||||||||
Validate if a device can be used for tensor operations. | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
device_id: Device identifier | ||||||||||||||
|
||||||||||||||
Returns: | ||||||||||||||
True if device is valid and available | ||||||||||||||
""" | ||||||||||||||
all_devices = list_devices() | ||||||||||||||
return device_id in all_devices | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def get_device_memory_info(device_id: str) -> Optional[Dict[str, any]]: | ||||||||||||||
Review comment:

This function returns hardcoded, non-factual strings for memory information (e.g., `"Available"`, `"TPU Memory"`, `"System RAM"`). Per Keras API guidelines (lines 68-69), names should be clear. A function named `get_device_memory_info` implies that it reports real memory statistics, which this implementation does not do. I recommend either querying actual memory information from the backend or removing the placeholder fields.
""" | ||||||||||||||
Get memory information for a device (if available). | ||||||||||||||
|
||||||||||||||
Args: | ||||||||||||||
device_id: Device identifier | ||||||||||||||
|
||||||||||||||
Returns: | ||||||||||||||
Memory information dictionary or None if not available | ||||||||||||||
""" | ||||||||||||||
if device_id.startswith("gpu:"): | ||||||||||||||
return { | ||||||||||||||
"type": "GPU", | ||||||||||||||
"index": int(device_id.split(":")[1]), | ||||||||||||||
"memory": "Available", | ||||||||||||||
} | ||||||||||||||
elif device_id.startswith("tpu:"): | ||||||||||||||
return { | ||||||||||||||
"type": "TPU", | ||||||||||||||
"index": int(device_id.split(":")[1]), | ||||||||||||||
"memory": "TPU Memory", | ||||||||||||||
} | ||||||||||||||
elif device_id.startswith("cpu:"): | ||||||||||||||
return { | ||||||||||||||
"type": "CPU", | ||||||||||||||
"index": int(device_id.split(":")[1]), | ||||||||||||||
"memory": "System RAM", | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
return None | ||||||||||||||
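If real figures are wanted instead of the placeholder strings flagged above, one possible direction — a sketch, not part of this PR — is to query the JAX device object itself. `Device.memory_stats()` is only exposed on some backends (typically GPUs) and may return nothing elsewhere, so the sketch treats it as optional:

```python
def _query_device_memory(device_id: str) -> Optional[Dict[str, any]]:
    """Sketch: return per-device memory stats when the backend exposes them."""
    device_type, _, index = device_id.partition(":")
    for device in jax.devices():
        if device.platform == device_type.lower() and device.id == int(index or 0):
            stats = getattr(device, "memory_stats", lambda: None)()
            if stats:  # e.g. {'bytes_in_use': ..., 'bytes_limit': ...} on GPU
                return {"type": device_type.upper(), "index": device.id, **stats}
            return None
    return None
```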
+
+
+def auto_configure_tensor_parallel(
+    world_size: int = None, backend: str = None
+) -> Dict[str, any]:
+    """
+    Automatically configure tensor parallelism with the best available devices.
+
+    Args:
+        world_size: Number of devices to use (if None, uses all available)
+        backend: Backend to use (if None, will be set to 'jax')
+
+    Returns:
+        Configuration dictionary with devices, backend, and other settings
+    """
+    all_devices = list_devices()
+
+    if not all_devices:
+        raise RuntimeError("No devices available for tensor parallelism")
+
+    if world_size is None:
+        world_size = len(all_devices)
+    else:
+        world_size = min(world_size, len(all_devices))
+
+    selected_devices = all_devices[:world_size]
+
+    recommended_backend = "jax"
Review comment:

The `backend` argument is accepted but never used; `recommended_backend` is hardcoded to `"jax"`. To fix this, you should use the provided `backend` value when the caller supplies one.
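The body of the suggestion was not captured; a minimal sketch of honoring the argument could be:

```python
# Sketch: fall back to "jax" only when the caller did not request a backend.
recommended_backend = backend if backend is not None else "jax"
```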
+
+    config = {
+        "devices": selected_devices,
+        "world_size": world_size,
+        "backend": recommended_backend,
+    }
+
+    logger.info(f"Auto-configured tensor parallelism: {config}")
+    return config
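For orientation, a hypothetical call to the helper above might look like this (the device names are examples and depend on the local hardware):

```python
# Hypothetical usage sketch of auto_configure_tensor_parallel.
config = auto_configure_tensor_parallel(world_size=2)
print(config["devices"])     # e.g. ["gpu:0", "gpu:1"] on a two-GPU host
print(config["world_size"])  # 2
print(config["backend"])     # "jax"
```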
+
+
 def distribute_variable(value, layout):
     """Create a distributed variable for JAX.

@@ -245,4 +404,4 @@ def _to_backend_layout(tensor_layout):
     )
     partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes)
     jax_mesh = tensor_layout.device_mesh.backend_mesh
-    return jax.sharding.NamedSharding(jax_mesh, partition_spec)
+    return jax.sharding.NamedSharding(jax_mesh, partition_spec)
keras/src/distribution/distribution_lib.py
@@ -17,6 +17,11 @@
 from keras.src.backend import distribution_lib
 from keras.src.backend.common import global_state

+# Add these imports at the top of keras/src/distribution/distribution_lib.py
+# from keras.src.distribution.tensor_parallel.tensor_parallel_keras import (
+#     TensorParallelKeras,
+# )
+
 DEFAULT_BATCH_DIM_NAME = "batch"
 GLOBAL_ATTRIBUTE_NAME = "distribution"

@@ -39,6 +44,24 @@ def list_devices(device_type=None):
     return distribution_lib.list_devices(device_type)

+
+@keras_export("keras.distribution.get_best_devices")
+def get_best_devices(count):
+    """Return all the available devices based on the device type.
+
+    Note: in a distributed setting, global devices are returned.
+
+    Args:
+        device_type: string, one of `"cpu"`, `"gpu"` or `"tpu"`.
+            Defaults to `"gpu"` or `"tpu"` if available when
+            `device_type` is not provided. Otherwise
+            will return the `"cpu"` devices.
+
+    Return:
+        List of devices that are available for distribute computation.
+    """
Comment on lines +49 to +61:

The docstring for `get_best_devices` looks copied from `list_devices`: it documents a `device_type` argument, but the function's only parameter is `count`. This is a critical documentation issue that violates the Keras API design guidelines on documentation clarity (lines 135-137, 144-145). Please update the docstring to accurately reflect the function's actual signature and behavior, for example:

    """Return the first `count` available devices.

    Note: in a distributed setting, global devices are returned. This function
    does not guarantee that the returned devices are the 'best' in terms of
    performance, only that they are available.

    Args:
        count: The number of devices to return.

    Return:
        List of device identifier strings.
    """
+    return distribution_lib.get_best_devices(count)
+
+
 @keras_export("keras.distribution.initialize")
 def initialize(job_addresses=None, num_processes=None, process_id=None):
     """Initialize the distribution system for multi-host/process setting.

@@ -534,6 +557,129 @@ def distribute_dataset(self, dataset):
         return distributed_dataset.prefetch(tf.data.AUTOTUNE)

+
+# Place this in keras/src/distribution/distribution_lib.py
+
+
+@keras_export("keras.distribution.AutoTPDistribution")
+class AutoTPDistribution(Distribution):
+    """Distribution for automatic tensor parallelism.
+
+    This strategy uses a set of heuristics to automatically analyze a model and
+    apply tensor parallelism.
+
+    This distribution acts as a factory to create a sharded version of a
+    Keras model. The standard workflow is to:
+    1. Create an instance of this distribution with a `DeviceMesh`.
+    2. Pass your original model to the `shard()` method.
+    3. Compile and train the new, sharded model that is returned.
+
+    Example:
+    ```python
+    # Define the hardware topology (e.g., 4 devices for model parallelism)
+    device_mesh = DeviceMesh(shape=(4,), axis_names=('model',))
+
+    # Create an instance of the strategy
+    distribution = AutoTPDistribution(device_mesh=device_mesh)
+
+    # Define the original model
+    model = keras.applications.ResNet50()
+
+    # Use the distribution to create the sharded, tensor-parallel model
+    sharded_model = distribution.shard(model)
+
+    # Compile and fit the new sharded model
+    sharded_model.compile(...)
+    sharded_model.fit(...)
+    ```
+
+    Args:
+        device_mesh: `DeviceMesh` instance that describes the hardware
+            topology.
+        batch_dim_name: Optional string, the axis name in the `device_mesh`
+            that will be used for data parallelism. Defaults to the first
+            axis name in the mesh.
+    """
+
+    def __init__(
+        self,
+        device_mesh=None,
+        batch_dim_name=None,
+        auto_shard_dataset=True,
+    ):
+        if device_mesh is None:
+            # Auto-create a 1D mesh with all available devices
+            devices = list_devices()
+            device_mesh = DeviceMesh(
+                shape=(len(devices),),
+                axis_names=("model",),
+                devices=devices,
+            )
+        batch_dim_name = batch_dim_name or device_mesh.axis_names[0]
+        super().__init__(device_mesh, batch_dim_name, auto_shard_dataset)
+
+    def shard(self, model: "keras.Model") -> "TensorParallelKeras":
+        from keras.src.distribution.tensor_parallel.tensor_parallel_keras import (
+            TensorParallelKeras,
+        )
+
+        """
+        Applies automatic tensor parallelism to a Keras model.
+
+        This method takes a standard Keras model, analyzes its layers,
+        and returns a new `TensorParallelKeras` model instance where the
+        weights have been sharded across the devices specified in the
+        `DeviceMesh`.
+
+        Args:
+            model: The original `keras.Model` instance to be sharded.
+
+        Returns:
+            A `TensorParallelKeras` model instance ready for distributed
+            training.
+        """
+        print(f"INFO: Sharding model `{model.name}` for Tensor Parallelism...")
Review comment:

This method uses `print` for informational output. Please replace it with the `logging` module so the messages respect the standard logging configuration, for example:

    import logging

    logging.info(f"Sharding model `{model.name}` for Tensor Parallelism...")
+
+        world_size = np.prod(self.device_mesh.shape)
+        device_ids = np.ravel(self.device_mesh.devices).tolist()
+
+        # The `TensorParallelKeras` class contains all the sharding logic.
+        # This distribution strategy is a clean, high-level entry point to it.
+        sharded_model = TensorParallelKeras(
+            model, world_size=world_size, device_ids=device_ids
+        )
+        print(f"INFO: Model `{model.name}` has been successfully sharded.")
+        return sharded_model
+
+    def get_data_layout(self, data_shape):
+        """Returns the layout for data, sharding across the batch dimension."""
+        data_shard_spec = [None] * len(data_shape)
+        if self.batch_dim_name in self.device_mesh.axis_names:
+            data_shard_spec[0] = self.batch_dim_name
+        return TensorLayout(data_shard_spec, self.device_mesh)
+
+    def get_variable_layout(self, variable):
+        """Returns the layout for a variable (replicated by default)."""
+        # In this pattern, the sharding logic is self-contained within the
+        # TensorParallelKeras model. The global distribution mechanism is
+        # primarily for data sharding. Variables outside the model are replicated.
+        return TensorLayout([None] * len(variable.shape), self.device_mesh)
+
+    def get_tensor_layout(self, path):
+        return (
+            None  # Not needed as communication is handled by the model's call()
+        )
+
+    def distribute_dataset(self, dataset):
+        if distribution_lib.num_processes() <= 1 or not self.auto_shard_dataset:
+            return dataset
+        from keras.src.utils.module_utils import tensorflow as tf
+
+        if not tf.available or not isinstance(dataset, tf.data.Dataset):
+            raise ValueError(
+                "Only `tf.data.Dataset` is supported for auto-sharding."
+            )
+        return dataset.with_options(tf.data.Options())
Comment on lines +671 to +680:

The implementation of `distribute_dataset` does not actually shard the dataset; it only applies an empty `tf.data.Options()` and returns the dataset unchanged. This is problematic in a multi-process setting, as each process would get the full dataset, leading to redundant computation and incorrect results. The implementation should handle dataset sharding, similar to how the other `Distribution` subclasses do.
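A sketch of what process-aware sharding could look like here, assuming the backend's `num_processes()`/`process_id()` helpers and `tf.data.Dataset.shard`; this is an illustration, not the PR's implementation:

```python
def distribute_dataset(self, dataset):
    # Sketch: give each process a disjoint shard instead of the full dataset.
    if distribution_lib.num_processes() <= 1 or not self.auto_shard_dataset:
        return dataset
    from keras.src.utils.module_utils import tensorflow as tf

    if not tf.available or not isinstance(dataset, tf.data.Dataset):
        raise ValueError(
            "Only `tf.data.Dataset` is supported for auto-sharding."
        )
    dataset = dataset.shard(
        num_shards=distribution_lib.num_processes(),
        index=distribution_lib.process_id(),
    )
    return dataset.prefetch(tf.data.AUTOTUNE)
```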
+
+
 @keras_export("keras.distribution.ModelParallel")
 class ModelParallel(Distribution):
     """Distribution that shards model variables.

@@ -895,4 +1041,4 @@ def set_distribution(value):
     Args:
         value: a `Distribution` instance.
     """
-    global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value)
+    global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value)
Review comment:

The `device_info` dictionary is initialized with `memory` and `capabilities` keys, but they are never populated and always remain `None`. This can be misleading for consumers of this function who might expect these fields to contain data. According to the Keras API design guidelines, we should seek to minimize cognitive load (lines 46-48) and avoid exposing options or fields that are not used (lines 52-54). It would be clearer to remove these keys from the dictionary. If they are intended for future use, they can be added back when the logic to populate them is implemented.
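A trimmed initialization along the lines suggested might look like this sketch:

```python
# Sketch: keep only the fields get_device_info currently populates.
device_info = {
    "id": device_id,
    "type": None,
    "index": None,
}
```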