163 changes: 161 additions & 2 deletions keras/src/backend/jax/distribution_lib.py
@@ -1,4 +1,14 @@
"""Utilities for distribution strategy with JAX backend."""
"""Utilities for distribution strategy with JAX backend.

This file contains the core JAX distribution primitives from Keras,
along with higher-level device management and auto-configuration utilities.
This version does not use try-except blocks for error handling.
"""

import logging
from typing import Any
from typing import Dict
from typing import List
from typing import Optional

import jax
import numpy as np
@@ -8,6 +18,8 @@
from keras.src.utils import jax_utils
from keras.src.utils import rng_utils

logger = logging.getLogger(__name__)


def list_devices(device_type=None):
"""Return all the available devices based on the device type.
@@ -27,6 +39,153 @@ def list_devices(device_type=None):
return [f"{device.platform}:{device.id}" for device in jax_devices]


def get_device_info(device_id: str) -> Dict[str, Any]:
"""
Get detailed information about a specific device.

Args:
device_id: Device identifier (e.g., 'gpu:0', 'tpu:0', 'cpu:0')

Returns:
Dictionary containing device information
"""
device_info = {
"id": device_id,
"type": None,
"index": None,
"memory": None,
"capabilities": None,
Comment on lines +56 to +57
Contributor (medium):

The device_info dictionary is initialized with memory and capabilities keys, but they are never populated and always remain None. This can be misleading for consumers of this function who might expect these fields to contain data.

According to the Keras API design guidelines, we should seek to minimize cognitive load (lines 46-48) and avoid exposing options or fields that are not used (lines 52-54).

It would be clearer to remove these keys from the dictionary. If they are intended for future use, they can be added back when the logic to populate them is implemented.
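
Concretely, the trimmed initializer being proposed might look like this (a sketch of the reviewer's suggestion, not the code as submitted):

```python
# Only keep the fields that the function actually populates.
device_info = {
    "id": device_id,
    "type": None,   # filled in below from the device_id prefix
    "index": None,  # filled in below from the device_id suffix
}
```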

}

device_type, device_index = device_id.split(":")
device_info["type"] = device_type.upper()
device_info["index"] = int(device_index)
Comment on lines +60 to +62
Contributor (medium):

The use of device_id.split(":") is not robust and will raise a ValueError if device_id does not contain a colon (e.g., for a device like "cpu"). This could lead to unexpected crashes if such a device ID is passed.

To make this function more robust, you should handle cases where the device ID might not have an index. A simple way is to check for the presence of ":" before splitting.

Suggested change
-    device_type, device_index = device_id.split(":")
-    device_info["type"] = device_type.upper()
-    device_info["index"] = int(device_index)
+    parts = device_id.split(":")
+    device_info["type"] = parts[0].upper()
+    device_info["index"] = int(parts[1]) if len(parts) > 1 else 0


return device_info


def get_best_devices(count: int = 1) -> List[str]:
Contributor (medium):

The function name get_best_devices is misleading. The implementation simply returns the first count devices from list_devices(). However, the underlying jax.devices() call does not guarantee any specific order, let alone an ordering from 'best' to 'worst' in terms of performance.

This violates the Keras API design guideline on naming (lines 68-69), which states that the meaning of an API element should be clear from its name. A developer might incorrectly assume this function performs some sort of device capability analysis.

Consider renaming this function to something more descriptive of its actual behavior, such as get_first_n_devices, and updating the docstring to clarify that no performance-based ordering is guaranteed.
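
For illustration, a minimal sketch of the rename being suggested, assuming it stays in the same module (the name `get_first_n_devices` is the reviewer's proposal, not existing API):

```python
def get_first_n_devices(count: int = 1) -> List[str]:
    """Return the first `count` devices reported by `list_devices()`.

    No performance ordering is implied; devices are returned in whatever
    order the backend enumerates them.
    """
    if count <= 0:
        return []
    return list_devices()[:count]
```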

"""
Get the best available devices for tensor parallelism.

Args:
count: Number of devices needed

Returns:
List of best device identifiers
"""
all_devices = list_devices()

if count <= 0:
return []

if count > len(all_devices):
logger.warning(
f"Requested {count} devices but only {len(all_devices)} available"
)
count = len(all_devices)

return all_devices[:count]


def get_device_backend(device_type: str) -> str:
"""
Get the recommended backend for a device type.

Args:
device_type: Device type ('tpu', 'gpu', 'cpu')

Returns:
Recommended backend name
"""
backend_mapping = {"tpu": "jax", "gpu": "jax", "cpu": "jax"}

return backend_mapping.get(device_type.lower(), "jax")


def validate_device_placement(device_id: str) -> bool:
"""
Validate if a device can be used for tensor operations.

Args:
device_id: Device identifier

Returns:
True if device is valid and available
"""
all_devices = list_devices()
return device_id in all_devices


def get_device_memory_info(device_id: str) -> Optional[Dict[str, Any]]:
Contributor (medium):

This function returns hardcoded, non-factual strings for memory information (e.g., "Available", "TPU Memory"). This is misleading as it doesn't provide actual memory stats. It also suffers from the same non-robust split(":") issue as get_device_info.

Per Keras API guidelines (lines 68-69), names should be clear. A function named get_device_memory_info implies it retrieves real data.

I recommend one of the following:

  • Implement actual memory querying for each device type (see the sketch after this list).
  • If that's not feasible, remove the "memory" key or set its value to None.
  • At a minimum, update the docstring to explicitly state that the memory information is a placeholder.
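
If real numbers are wanted for the first option, JAX exposes best-effort per-device statistics via `Device.memory_stats()` (unsupported platforms may return `None` or raise). A rough sketch under that assumption; the helper name `_query_device_memory` and the exact stat keys are illustrative, not part of this PR:

```python
import jax


def _query_device_memory(device_id: str):
    """Best-effort raw memory stats for one device, or None if not found."""
    parts = device_id.split(":")
    platform = parts[0].lower()
    index = int(parts[1]) if len(parts) > 1 else 0
    for device in jax.devices():
        if device.platform == platform and device.id == index:
            # Backend-dependent dict, e.g. bytes_in_use / bytes_limit on GPU;
            # CPU devices typically report nothing useful here.
            return device.memory_stats()
    return None
```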

"""
Get memory information for a device (if available).

Args:
device_id: Device identifier

Returns:
Memory information dictionary or None if not available
"""
if device_id.startswith("gpu:"):
return {
"type": "GPU",
"index": int(device_id.split(":")[1]),
"memory": "Available",
}
elif device_id.startswith("tpu:"):
return {
"type": "TPU",
"index": int(device_id.split(":")[1]),
"memory": "TPU Memory",
}
elif device_id.startswith("cpu:"):
return {
"type": "CPU",
"index": int(device_id.split(":")[1]),
"memory": "System RAM",
}

return None


def auto_configure_tensor_parallel(
world_size: int = None, backend: str = None
) -> Dict[str, Any]:
"""
Automatically configure tensor parallelism with the best available devices.

Args:
world_size: Number of devices to use (if None, uses all available)
backend: Backend to use (if None, will be set to 'jax')

Returns:
Configuration dictionary with devices, backend, and other settings
"""
all_devices = list_devices()

if not all_devices:
raise RuntimeError("No devices available for tensor parallelism")

if world_size is None:
world_size = len(all_devices)
else:
world_size = min(world_size, len(all_devices))

selected_devices = all_devices[:world_size]

recommended_backend = "jax"
Contributor (high):

The backend parameter is accepted by the function but its value is never used. recommended_backend is always hardcoded to "jax". This is a bug, as the function does not behave as its signature suggests.

To fix this, you should use the provided backend parameter, falling back to "jax" if it's not provided.

Suggested change
-    recommended_backend = "jax"
+    recommended_backend = backend or "jax"


config = {
"devices": selected_devices,
"world_size": world_size,
"backend": recommended_backend,
}

logger.info(f"Auto-configured tensor parallelism: {config}")
return config


def distribute_variable(value, layout):
"""Create a distributed variable for JAX.

@@ -245,4 +404,4 @@ def _to_backend_layout(tensor_layout):
)
partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes)
jax_mesh = tensor_layout.device_mesh.backend_mesh
return jax.sharding.NamedSharding(jax_mesh, partition_spec)
return jax.sharding.NamedSharding(jax_mesh, partition_spec)
2 changes: 1 addition & 1 deletion keras/src/backend/jax/distribution_lib_test.py
@@ -451,4 +451,4 @@ def call(self, inputs):
return inputs

def capture_input_sharding(self, sharding):
self.captured_input_sharding = sharding
self.captured_input_sharding = sharding
148 changes: 147 additions & 1 deletion keras/src/distribution/distribution_lib.py
@@ -17,6 +17,11 @@
from keras.src.backend import distribution_lib
from keras.src.backend.common import global_state

# Add these imports at the top of keras/src/distribution/distribution_lib.py
# from keras.src.distribution.tensor_parallel.tensor_parallel_keras import (
# TensorParallelKeras,
# )

DEFAULT_BATCH_DIM_NAME = "batch"
GLOBAL_ATTRIBUTE_NAME = "distribution"

@@ -39,6 +44,24 @@ def list_devices(device_type=None):
return distribution_lib.list_devices(device_type)


@keras_export("keras.distribution.get_best_devices")
def get_best_devices(count):
"""Return all the available devices based on the device type.

Note: in a distributed setting, global devices are returned.

Args:
device_type: string, one of `"cpu"`, `"gpu"` or `"tpu"`.
Defaults to `"gpu"` or `"tpu"` if available when
`device_type` is not provided. Otherwise
will return the `"cpu"` devices.

Return:
List of devices that are available for distribute computation.
"""
Comment on lines +49 to +61
Contributor (critical):

The docstring for get_best_devices is incorrect and misleading. It describes a device_type argument, but the function's signature is get_best_devices(count). This will cause significant confusion for users trying to use this function.

This is a critical documentation issue that violates the Keras API design guidelines on documentation clarity (lines 135-137, 144-145).

Please update the docstring to accurately reflect the count parameter and the function's behavior.

    """Return the first `count` available devices.

    Note: in a distributed setting, global devices are returned. This function
    does not guarantee that the returned devices are the 'best' in terms of
    performance, only that they are available.

    Args:
        count: The number of devices to return.

    Return:
        List of device identifier strings.
    """

return distribution_lib.get_best_devices(count)


@keras_export("keras.distribution.initialize")
def initialize(job_addresses=None, num_processes=None, process_id=None):
"""Initialize the distribution system for multi-host/process setting.
@@ -534,6 +557,129 @@ def distribute_dataset(self, dataset):
return distributed_dataset.prefetch(tf.data.AUTOTUNE)


# Place this in keras/src/distribution/distribution_lib.py


@keras_export("keras.distribution.AutoTPDistribution")
class AutoTPDistribution(Distribution):
"""Distribution for automatic tensor parallelism.

This strategy uses a set of heuristics to automatically analyze a model and
apply tensor parallelism.

This distribution acts as a factory to create a sharded version of a
Keras model. The standard workflow is to:
1. Create an instance of this distribution with a `DeviceMesh`.
2. Pass your original model to the `shard()` method.
3. Compile and train the new, sharded model that is returned.

Example:
```python
# Define the hardware topology (e.g., 4 devices for model parallelism)
device_mesh = DeviceMesh(shape=(4,), axis_names=('model',))

# Create an instance of the strategy
distribution = AutoTPDistribution(device_mesh=device_mesh)

# Define the original model
model = keras.applications.ResNet50()

# Use the distribution to create the sharded, tensor-parallel model
sharded_model = distribution.shard(model)

# Compile and fit the new sharded model
sharded_model.compile(...)
sharded_model.fit(...)
```

Args:
device_mesh: `DeviceMesh` instance that describes the hardware
topology.
batch_dim_name: Optional string, the axis name in the `device_mesh`
that will be used for data parallelism. Defaults to the first
axis name in the mesh.
"""

def __init__(
self,
device_mesh=None,
batch_dim_name=None,
auto_shard_dataset=True,
):
if device_mesh is None:
# Auto-create a 1D mesh with all available devices
devices = list_devices()
device_mesh = DeviceMesh(
shape=(len(devices),),
axis_names=("model",),
devices=devices,
)
batch_dim_name = batch_dim_name or device_mesh.axis_names[0]
super().__init__(device_mesh, batch_dim_name, auto_shard_dataset)

def shard(self, model: "keras.Model") -> "TensorParallelKeras":
"""
Applies automatic tensor parallelism to a Keras model.

This method takes a standard Keras model, analyzes its layers,
and returns a new `TensorParallelKeras` model instance where the
weights have been sharded across the devices specified in the
`DeviceMesh`.

Args:
model: The original `keras.Model` instance to be sharded.

Returns:
A `TensorParallelKeras` model instance ready for distributed
training.
"""
from keras.src.distribution.tensor_parallel.tensor_parallel_keras import (
TensorParallelKeras,
)
print(f"INFO: Sharding model `{model.name}` for Tensor Parallelism...")
Contributor (medium):

This method uses print() for logging informational messages (here and on line 649). It's better to use the logging module for consistency with other parts of the codebase (like keras/src/backend/jax/distribution_lib.py) and to allow users more control over log verbosity.

Please replace print() with logging.info(). You will need to ensure logging is imported and a logger is configured at the module level.

        import logging
        logging.info(f"Sharding model `{model.name}` for Tensor Parallelism...")

world_size = np.prod(self.device_mesh.shape)
device_ids = np.ravel(self.device_mesh.devices).tolist()

# The `TensorParallelKeras` class contains all the sharding logic.
# This distribution strategy is a clean, high-level entry point to it.
sharded_model = TensorParallelKeras(
model, world_size=world_size, device_ids=device_ids
)
print(f"INFO: Model `{model.name}` has been successfully sharded.")
Contributor (medium):

Similar to the previous comment, please use logging.info() instead of print() for consistency and better log control.

        logging.info(f"Model `{model.name}` has been successfully sharded.")

return sharded_model

def get_data_layout(self, data_shape):
"""Returns the layout for data, sharding across the batch dimension."""
data_shard_spec = [None] * len(data_shape)
if self.batch_dim_name in self.device_mesh.axis_names:
data_shard_spec[0] = self.batch_dim_name
return TensorLayout(data_shard_spec, self.device_mesh)

def get_variable_layout(self, variable):
"""Returns the layout for a variable (replicated by default)."""
# In this pattern, the sharding logic is self-contained within the
# TensorParallelKeras model. The global distribution mechanism is
# primarily for data sharding. Variables outside the model are replicated.
return TensorLayout([None] * len(variable.shape), self.device_mesh)

def get_tensor_layout(self, path):
return (
None # Not needed as communication is handled by the model's call()
)

def distribute_dataset(self, dataset):
if distribution_lib.num_processes() <= 1 or not self.auto_shard_dataset:
return dataset
from keras.src.utils.module_utils import tensorflow as tf

if not tf.available or not isinstance(dataset, tf.data.Dataset):
raise ValueError(
"Only `tf.data.Dataset` is supported for auto-sharding."
)
return dataset.with_options(tf.data.Options())
Comment on lines +671 to +680
Contributor (high):

The implementation of distribute_dataset for AutoTPDistribution seems incorrect or incomplete. Calling dataset.with_options(tf.data.Options()) with empty options is a no-op and does not actually shard the dataset across processes.

This is problematic in a multi-process setting, as each process would get the full dataset, leading to redundant computation and incorrect results.

The implementation should handle dataset sharding, similar to how DataParallel.distribute_dataset does. It needs to consider the number of processes and the data parallelism dimension of the device mesh to correctly shard the data.
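
For reference, a rough sketch of the per-process sharding the comment describes, leaning on `tf.data.Dataset.shard` and assuming `distribution_lib.process_id()` is available alongside the `num_processes()` call already used above (a simplified sketch, not the full `DataParallel` logic, which also adjusts the per-worker batch size):

```python
def distribute_dataset(self, dataset):
    num_proc = distribution_lib.num_processes()
    if num_proc <= 1 or not self.auto_shard_dataset:
        return dataset
    from keras.src.utils.module_utils import tensorflow as tf

    if not tf.available or not isinstance(dataset, tf.data.Dataset):
        raise ValueError(
            "Only `tf.data.Dataset` is supported for auto-sharding."
        )
    # Give each process a disjoint slice of the data instead of a no-op.
    dataset = dataset.shard(num_proc, distribution_lib.process_id())
    return dataset.prefetch(tf.data.AUTOTUNE)
```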



@keras_export("keras.distribution.ModelParallel")
class ModelParallel(Distribution):
"""Distribution that shards model variables.
@@ -895,4 +1041,4 @@ def set_distribution(value):
Args:
value: a `Distribution` instance.
"""
global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value)
global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value)