diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py index 6b5bf37314c0..cd0555ffc522 100644 --- a/keras/src/backend/jax/distribution_lib.py +++ b/keras/src/backend/jax/distribution_lib.py @@ -1,4 +1,14 @@ -"""Utilities for distribution strategy with JAX backend.""" +"""Utilities for distribution strategy with JAX backend. + +This file contains the core JAX distribution primitives from Keras, +along with higher-level device management and auto-configuration utilities. +This version does not use try-except blocks for error handling. +""" + +import logging +from typing import Dict +from typing import List +from typing import Optional import jax import numpy as np @@ -8,6 +18,8 @@ from keras.src.utils import jax_utils from keras.src.utils import rng_utils +logger = logging.getLogger(__name__) + def list_devices(device_type=None): """Return all the available devices based on the device type. @@ -27,6 +39,153 @@ def list_devices(device_type=None): return [f"{device.platform}:{device.id}" for device in jax_devices] +def get_device_info(device_id: str) -> Dict[str, any]: + """ + Get detailed information about a specific device. + + Args: + device_id: Device identifier (e.g., 'gpu:0', 'tpu:0', 'cpu:0') + + Returns: + Dictionary containing device information + """ + device_info = { + "id": device_id, + "type": None, + "index": None, + "memory": None, + "capabilities": None, + } + + device_type, device_index = device_id.split(":") + device_info["type"] = device_type.upper() + device_info["index"] = int(device_index) + + return device_info + + +def get_best_devices(count: int = 1) -> List[str]: + """ + Get the best available devices for tensor parallelism. + + Args: + count: Number of devices needed + + Returns: + List of best device identifiers + """ + all_devices = list_devices() + + if count <= 0: + return [] + + if count > len(all_devices): + logger.warning( + f"Requested {count} devices but only {len(all_devices)} available" + ) + count = len(all_devices) + + return all_devices[:count] + + +def get_device_backend(device_type: str) -> str: + """ + Get the recommended backend for a device type. + + Args: + device_type: Device type ('tpu', 'gpu', 'cpu') + + Returns: + Recommended backend name + """ + backend_mapping = {"tpu": "jax", "gpu": "jax", "cpu": "jax"} + + return backend_mapping.get(device_type.lower(), "jax") + + +def validate_device_placement(device_id: str) -> bool: + """ + Validate if a device can be used for tensor operations. + + Args: + device_id: Device identifier + + Returns: + True if device is valid and available + """ + all_devices = list_devices() + return device_id in all_devices + + +def get_device_memory_info(device_id: str) -> Optional[Dict[str, any]]: + """ + Get memory information for a device (if available). + + Args: + device_id: Device identifier + + Returns: + Memory information dictionary or None if not available + """ + if device_id.startswith("gpu:"): + return { + "type": "GPU", + "index": int(device_id.split(":")[1]), + "memory": "Available", + } + elif device_id.startswith("tpu:"): + return { + "type": "TPU", + "index": int(device_id.split(":")[1]), + "memory": "TPU Memory", + } + elif device_id.startswith("cpu:"): + return { + "type": "CPU", + "index": int(device_id.split(":")[1]), + "memory": "System RAM", + } + + return None + + +def auto_configure_tensor_parallel( + world_size: int = None, backend: str = None +) -> Dict[str, any]: + """ + Automatically configure tensor parallelism with the best available devices. 
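The helpers above (`get_device_info`, `get_best_devices`, `validate_device_placement`) are thin wrappers around `list_devices`. A minimal usage sketch, with illustrative device strings that depend on the local JAX runtime:

```python
# Usage sketch for the JAX device helpers defined in this file.
# Device names such as "cpu:0"/"cpu:1" are illustrative.
from keras.src.backend.jax import distribution_lib as jax_dist

devices = jax_dist.list_devices("cpu")      # e.g. ["cpu:0", "cpu:1"]
best = jax_dist.get_best_devices(count=2)   # first two available devices
info = jax_dist.get_device_info(best[0])    # {"id": "cpu:0", "type": "CPU", "index": 0, ...}

if jax_dist.validate_device_placement(best[0]):
    print(f"{best[0]} is usable:", info)
```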
+ + Args: + world_size: Number of devices to use (if None, uses all available) + backend: Backend to use (if None, will be set to 'jax') + + Returns: + Configuration dictionary with devices, backend, and other settings + """ + all_devices = list_devices() + + if not all_devices: + raise RuntimeError("No devices available for tensor parallelism") + + if world_size is None: + world_size = len(all_devices) + else: + world_size = min(world_size, len(all_devices)) + + selected_devices = all_devices[:world_size] + + recommended_backend = "jax" + + config = { + "devices": selected_devices, + "world_size": world_size, + "backend": recommended_backend, + } + + logger.info(f"Auto-configured tensor parallelism: {config}") + return config + + def distribute_variable(value, layout): """Create a distributed variable for JAX. @@ -245,4 +404,4 @@ def _to_backend_layout(tensor_layout): ) partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes) jax_mesh = tensor_layout.device_mesh.backend_mesh - return jax.sharding.NamedSharding(jax_mesh, partition_spec) + return jax.sharding.NamedSharding(jax_mesh, partition_spec) \ No newline at end of file diff --git a/keras/src/backend/jax/distribution_lib_test.py b/keras/src/backend/jax/distribution_lib_test.py index 8938c14fc50a..3e4af9bccd06 100644 --- a/keras/src/backend/jax/distribution_lib_test.py +++ b/keras/src/backend/jax/distribution_lib_test.py @@ -451,4 +451,4 @@ def call(self, inputs): return inputs def capture_input_sharding(self, sharding): - self.captured_input_sharding = sharding + self.captured_input_sharding = sharding \ No newline at end of file diff --git a/keras/src/distribution/distribution_lib.py b/keras/src/distribution/distribution_lib.py index 2daef40a2ed8..afe29515bbf9 100644 --- a/keras/src/distribution/distribution_lib.py +++ b/keras/src/distribution/distribution_lib.py @@ -17,6 +17,11 @@ from keras.src.backend import distribution_lib from keras.src.backend.common import global_state +# Add these imports at the top of keras/src/distribution/distribution_lib.py +# from keras.src.distribution.tensor_parallel.tensor_parallel_keras import ( +# TensorParallelKeras, +# ) + DEFAULT_BATCH_DIM_NAME = "batch" GLOBAL_ATTRIBUTE_NAME = "distribution" @@ -39,6 +44,24 @@ def list_devices(device_type=None): return distribution_lib.list_devices(device_type) +@keras_export("keras.distribution.get_best_devices") +def get_best_devices(count): + """Return all the available devices based on the device type. + + Note: in a distributed setting, global devices are returned. + + Args: + device_type: string, one of `"cpu"`, `"gpu"` or `"tpu"`. + Defaults to `"gpu"` or `"tpu"` if available when + `device_type` is not provided. Otherwise + will return the `"cpu"` devices. + + Return: + List of devices that are available for distribute computation. + """ + return distribution_lib.get_best_devices(count) + + @keras_export("keras.distribution.initialize") def initialize(job_addresses=None, num_processes=None, process_id=None): """Initialize the distribution system for multi-host/process setting. @@ -534,6 +557,129 @@ def distribute_dataset(self, dataset): return distributed_dataset.prefetch(tf.data.AUTOTUNE) +# Place this in keras/src/distribution/distribution_lib.py + + +@keras_export("keras.distribution.AutoTPDistribution") +class AutoTPDistribution(Distribution): + """Distribution for automatic tensor parallelism. + + This strategy uses a set of heuristics to automatically analyze a model and + apply tensor parallelism. 
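`auto_configure_tensor_parallel` above returns a plain configuration dict rather than applying any state, and the same device selection is delegated to by the `get_best_devices` wrapper added to `keras/src/distribution/distribution_lib.py` earlier in this diff. A sketch of both entry points; the JAX backend is assumed and the resulting device list is machine dependent:

```python
# Sketch of the auto-configuration helpers added in this diff.
from keras.src.backend.jax import distribution_lib as jax_dist
from keras.src.distribution import distribution_lib as dist_lib

config = jax_dist.auto_configure_tensor_parallel(world_size=2)
# e.g. {"devices": ["cpu:0", "cpu:1"], "world_size": 2, "backend": "jax"}

devices = dist_lib.get_best_devices(2)   # public wrapper over the same selection
```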
+ + This distribution acts as a factory to create a sharded version of a + Keras model. The standard workflow is to: + 1. Create an instance of this distribution with a `DeviceMesh`. + 2. Pass your original model to the `shard()` method. + 3. Compile and train the new, sharded model that is returned. + + Example: + ```python + # Define the hardware topology (e.g., 4 devices for model parallelism) + device_mesh = DeviceMesh(shape=(4,), axis_names=('model',)) + + # Create an instance of the strategy + distribution = AutoTPDistribution(device_mesh=device_mesh) + + # Define the original model + model = keras.applications.ResNet50() + + # Use the distribution to create the sharded, tensor-parallel model + sharded_model = distribution.shard(model) + + # Compile and fit the new sharded model + sharded_model.compile(...) + sharded_model.fit(...) + ``` + + Args: + device_mesh: `DeviceMesh` instance that describes the hardware + topology. + batch_dim_name: Optional string, the axis name in the `device_mesh` + that will be used for data parallelism. Defaults to the first + axis name in the mesh. + """ + + def __init__( + self, + device_mesh=None, + batch_dim_name=None, + auto_shard_dataset=True, + ): + if device_mesh is None: + # Auto-create a 1D mesh with all available devices + devices = list_devices() + device_mesh = DeviceMesh( + shape=(len(devices),), + axis_names=("model",), + devices=devices, + ) + batch_dim_name = batch_dim_name or device_mesh.axis_names[0] + super().__init__(device_mesh, batch_dim_name, auto_shard_dataset) + + def shard(self, model: "keras.Model") -> "TensorParallelKeras": + from keras.src.distribution.tensor_parallel.tensor_parallel_keras import ( + TensorParallelKeras, + ) + + """ + Applies automatic tensor parallelism to a Keras model. + + This method takes a standard Keras model, analyzes its layers, + and returns a new `TensorParallelKeras` model instance where the + weights have been sharded across the devices specified in the + `DeviceMesh`. + + Args: + model: The original `keras.Model` instance to be sharded. + + Returns: + A `TensorParallelKeras` model instance ready for distributed + training. + """ + print(f"INFO: Sharding model `{model.name}` for Tensor Parallelism...") + world_size = np.prod(self.device_mesh.shape) + device_ids = np.ravel(self.device_mesh.devices).tolist() + + # The `TensorParallelKeras` class contains all the sharding logic. + # This distribution strategy is a clean, high-level entry point to it. + sharded_model = TensorParallelKeras( + model, world_size=world_size, device_ids=device_ids + ) + print(f"INFO: Model `{model.name}` has been successfully sharded.") + return sharded_model + + def get_data_layout(self, data_shape): + """Returns the layout for data, sharding across the batch dimension.""" + data_shard_spec = [None] * len(data_shape) + if self.batch_dim_name in self.device_mesh.axis_names: + data_shard_spec[0] = self.batch_dim_name + return TensorLayout(data_shard_spec, self.device_mesh) + + def get_variable_layout(self, variable): + """Returns the layout for a variable (replicated by default).""" + # In this pattern, the sharding logic is self-contained within the + # TensorParallelKeras model. The global distribution mechanism is + # primarily for data sharding. Variables outside the model are replicated. 
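The layout methods here keep the division of labor simple: `get_data_layout` shards only the batch dimension, and variables outside the wrapped model stay replicated. A small sketch, assuming a two-device 1-D "model" mesh (the same setup used by the tests later in this diff):

```python
# Layout sketch for AutoTPDistribution; assumes two CPU devices are visible.
import numpy as np
import keras
from keras.src.distribution.distribution_lib import AutoTPDistribution, DeviceMesh

mesh = DeviceMesh(shape=(2,), axis_names=("model",), devices=["cpu:0", "cpu:1"])
dist = AutoTPDistribution(device_mesh=mesh)

data_layout = dist.get_data_layout((32, 20))
print(data_layout.axes)                     # ("model", None): only the batch dim is sharded

var = keras.Variable(np.zeros((20, 16)))
print(dist.get_variable_layout(var).axes)   # (None, None): replicated
```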
+ return TensorLayout([None] * len(variable.shape), self.device_mesh) + + def get_tensor_layout(self, path): + return ( + None # Not needed as communication is handled by the model's call() + ) + + def distribute_dataset(self, dataset): + if distribution_lib.num_processes() <= 1 or not self.auto_shard_dataset: + return dataset + from keras.src.utils.module_utils import tensorflow as tf + + if not tf.available or not isinstance(dataset, tf.data.Dataset): + raise ValueError( + "Only `tf.data.Dataset` is supported for auto-sharding." + ) + return dataset.with_options(tf.data.Options()) + + @keras_export("keras.distribution.ModelParallel") class ModelParallel(Distribution): """Distribution that shards model variables. @@ -895,4 +1041,4 @@ def set_distribution(value): Args: value: a `Distribution` instance. """ - global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value) + global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value) \ No newline at end of file diff --git a/keras/src/distribution/distribution_lib_test.py b/keras/src/distribution/distribution_lib_test.py index 66f996b3fb68..6bc616aa7253 100644 --- a/keras/src/distribution/distribution_lib_test.py +++ b/keras/src/distribution/distribution_lib_test.py @@ -1,16 +1,36 @@ """Test for distribution_lib.py.""" import os + +# FILE: keras/src/distribution/distribution_lib_test.py + + +# --- TOP-LEVEL ENVIRONMENT SETUP --- +# This MUST be at the top of the file, before any Keras/TF imports. +# It configures the environment for all tests in this file. +os.environ["KERAS_BACKEND"] = "jax" +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" + +# --- Now continue with the rest of the imports --- +# ... and so on from unittest import mock import numpy as np import pytest import tensorflow as tf +import keras from keras.src import backend from keras.src import testing from keras.src.backend import distribution_lib as backend_dlib from keras.src.distribution import distribution_lib +from keras.src.distribution.distribution_lib import AutoTPDistribution + +try: + import keras_hub +except ImportError: + keras_hub = None @pytest.mark.skipif( @@ -535,3 +555,129 @@ def test_iter(self): # ValueError, "Cannot create sharding when device mesh is not set" # ): # backend_dlib._to_dtensor_layout(layout) + + +# Add this test class to the end of: +# keras/src/distribution/distribution_lib_test.py + +from keras.src import layers +from keras.src import testing + +# Import your new distribution class and the other necessary components +from keras.src.distribution.distribution_lib import DeviceMesh +from keras.src.distribution.tensor_parallel.tensor_parallel_keras import ( + TensorParallelKeras, +) + +# Import your new distribution class and the other necessary components + + +class AutoTPDistributionTest(testing.TestCase): + def test_sharding_correctness_for_all_param_types(self): + """ + Tests that all parameter types (column-parallel, row-parallel, + and replicated) are sharded correctly. + """ + # 1. ARRANGE + devices = ["cpu:0", "cpu:1"] + device_mesh = DeviceMesh( + shape=(2,), axis_names=("model",), devices=devices + ) + distribution = AutoTPDistribution(device_mesh=device_mesh) + + original_model = keras.Sequential( + [ + layers.Input(shape=(20,)), + layers.Dense(16, name="dense_1"), # Column-parallel + layers.Dense(8, name="dense_2"), # Row-parallel + ], + name="my_model", + ) + original_model.build(input_shape=(None, 20)) + + # 2. 
ACT + sharded_model = distribution.shard(original_model) + + # 3. ASSERT + self.assertIsInstance(sharded_model, TensorParallelKeras) + self.assertEqual(sharded_model.world_size, 2) + shard_strategy = sharded_model.model_shards[0].sharding_strategy + + # --- Check Column-Parallel Layer (dense_1) --- + # Kernel should be sharded on the output dim (1) + orig_k1_shape = original_model.get_layer("dense_1").kernel.shape + shard_k1_info = shard_strategy.get_weight_info( + "my_model.dense_1.kernel" + ) + self.assertIsNotNone(shard_k1_info) + self.assertEqual(shard_k1_info["sharded_shape"][0], orig_k1_shape[0]) + self.assertEqual( + shard_k1_info["sharded_shape"][1], orig_k1_shape[1] // 2 + ) + + # Bias should also be sharded + orig_b1_shape = original_model.get_layer("dense_1").bias.shape + shard_b1_info = shard_strategy.get_weight_info("my_model.dense_1.bias") + self.assertIsNotNone(shard_b1_info) + self.assertEqual( + shard_b1_info["sharded_shape"][0], orig_b1_shape[0] // 2 + ) + + # --- Check Row-Parallel Layer (dense_2) --- + # Kernel should be sharded on the input dim (0) + orig_k2_shape = original_model.get_layer("dense_2").kernel.shape + shard_k2_info = shard_strategy.get_weight_info( + "my_model.dense_2.kernel" + ) + self.assertIsNotNone(shard_k2_info) + self.assertEqual( + shard_k2_info["sharded_shape"][0], orig_k2_shape[0] // 2 + ) + self.assertEqual(shard_k2_info["sharded_shape"][1], orig_k2_shape[1]) + + # Bias should be replicated (not sharded) + shard_b2_info = shard_strategy.get_weight_info("my_model.dense_2.bias") + self.assertIsNone(shard_b2_info) # Correctly not found in sharded map + + def test_uneven_sharding_splits_correctly(self): + """ + Tests that weights are sharded correctly when the dimension is not + perfectly divisible by the number of devices. + """ + # 1. ARRANGE: Use 3 devices for an uneven split + devices = ["cpu:0", "cpu:1", "cpu:2"] + device_mesh = DeviceMesh( + shape=(3,), axis_names=("model",), devices=devices + ) + distribution = AutoTPDistribution(device_mesh=device_mesh) + + # Create a model with a dimension not divisible by 3 (e.g., 17) + original_model = keras.Sequential( + [layers.Dense(17, input_shape=(10,), name="dense_uneven")], + name="uneven_model", + ) + original_model.build() + + # 2. ACT + sharded_model = distribution.shard(original_model) + + # 3. ASSERT + # For a dimension of 17 split across 3 devices, the expected + # sharded shapes are (6, 5, 5). 
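The split sizes asserted below follow a ceil-first convention (an assumption about the sharding implementation; note the resulting sizes are (6, 6, 5), as the assertions check). NumPy's `array_split` uses the same rule, so the expected shapes can be verified directly:

```python
# A dimension of 17 across 3 shards under a ceil-first split: 17 = 6 + 6 + 5,
# giving sharded kernel shapes (10, 6), (10, 6), (10, 5).
import numpy as np

sizes = [part.shape[1] for part in np.array_split(np.zeros((10, 17)), 3, axis=1)]
print(sizes)   # [6, 6, 5]
```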
+ strategy_shard0 = sharded_model.model_shards[0].sharding_strategy + strategy_shard1 = sharded_model.model_shards[1].sharding_strategy + strategy_shard2 = sharded_model.model_shards[2].sharding_strategy + + shape_shard0 = strategy_shard0.get_weight_info( + "uneven_model.dense_uneven.kernel" + )["sharded_shape"] + shape_shard1 = strategy_shard1.get_weight_info( + "uneven_model.dense_uneven.kernel" + )["sharded_shape"] + shape_shard2 = strategy_shard2.get_weight_info( + "uneven_model.dense_uneven.kernel" + )["sharded_shape"] + + self.assertEqual(shape_shard0, (10, 6)) + self.assertEqual(shape_shard1, (10, 6)) # āœ… CORRECTED + self.assertEqual(shape_shard2, (10, 5)) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/tensor_parallel.py b/keras/src/distribution/tensor_parallel/tensor_parallel.py new file mode 100644 index 000000000000..f0603eaa71d4 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/tensor_parallel.py @@ -0,0 +1,801 @@ +import re +from typing import Collection, Optional, Sequence, Union + +import numpy as np +import tensorflow as tf +import keras +from keras import ops +from keras.src.distribution.tensor_parallel.autoconfig import ( + get_default_config_keras, +) +from keras.src.distribution.tensor_parallel.parameter_sharding import ( + make_parameter_sharded_model, +) +from keras.src.distribution.tensor_parallel.sharding_keras import ShardedKeras + +from keras.src.distribution.tensor_parallel.coordinated_optimizer import TensorParallelOptimizer + + +from keras.src.models import Model + + +class TensorParallelKeras(Model): + """A Keras Model wrapper for implementing tensor parallelism. + + This class takes a standard Keras model and shards its weights across + multiple devices (`world_size`). It automatically handles the sharding of + parameters, communication between devices, and construction of a unified + computational graph. The result is a model that can be trained and used + like a regular Keras model but leverages multiple accelerators to fit + larger models into memory. + + Args: + model (keras.Model): The Keras model to be parallelized. + world_size (int, optional): The total number of devices to shard the + model across. If `None`, it will be auto-detected. Defaults to `None`. + device_ids (Sequence[str], optional): A sequence of specific device IDs + (e.g., `['/gpu:0', '/gpu:1']`) to use. If `None`, devices will be + auto-configured. Defaults to `None`. + distributed_backend (str, optional): The backend to use for distributed + communication. Defaults to "auto". + **kwargs: Additional keyword arguments passed to the `keras.Model` + base class constructor. 
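The docstring above says the wrapper shards weights and then stitches the per-device partial results back together; later in this file, `build_assembled_model` describes the two combine rules it relies on (concatenation for column-parallel layers, summation for row-parallel ones). The underlying identities can be checked with plain NumPy, independent of the Keras classes here:

```python
# Why column-parallel partial outputs are concatenated while row-parallel
# partial outputs are summed (plain NumPy identity check; shapes arbitrary).
import numpy as np

x = np.random.rand(4, 6)
w = np.random.rand(6, 8)
y = x @ w                                   # reference un-sharded output

# Column parallel: each shard holds a slice of the output columns.
w_a, w_b = np.split(w, 2, axis=1)
np.testing.assert_allclose(np.concatenate([x @ w_a, x @ w_b], axis=-1), y, rtol=1e-6)

# Row parallel: each shard holds a slice of the input rows (and of x).
w_a, w_b = np.split(w, 2, axis=0)
x_a, x_b = np.split(x, 2, axis=1)
np.testing.assert_allclose(x_a @ w_a + x_b @ w_b, y, rtol=1e-6)
```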
+ """ + def __init__( + self, + model, + world_size=None, + device_ids=None, + distributed_backend="auto", + **kwargs, + ): + super().__init__(**kwargs) + + self._original_model = model + + if world_size is None: + world_size, device_ids = self._auto_detect_parallelism() + elif device_ids is None: + device_ids = self._auto_configure_devices( + world_size, distributed_backend + ) + + self.world_size = world_size + self.device_ids = device_ids + self.sharding_strategy = "auto" + self.distributed_backend = distributed_backend + + self.tensor_parallel_config = None + self.distributed = True + + self.sharded_models = [self._original_model] + original_params = 0 + for p in model.weights: + if hasattr(p, "shape") and hasattr(p.shape, "num_elements"): + original_params += p.shape.num_elements() + elif hasattr(p, "shape") and hasattr(p.shape, "__iter__"): + original_params += np.prod(p.shape) + else: + original_params += np.prod(p.shape) + + device_ids = list(self.check_device_ids(device_ids)) + + accel_devices = self._discover_devices() + + if accel_devices: + backend_name = keras.backend.backend() + + if len(accel_devices) >= world_size: + device_ids = accel_devices[:world_size] + else: + world_size = len(accel_devices) + device_ids = accel_devices[:world_size] + + if not device_ids: + device_ids = self._auto_configure_devices( + world_size, distributed_backend + ) + + if len(device_ids) != world_size: + device_ids = self._adjust_device_list(device_ids, world_size) + + self.devices = device_ids + self.world_size = world_size + self.sharding_manager = None + + if self.world_size <= 1: + self.model_shards = [model] + self.distributed = False + if len(self.devices) == 1: + from keras import device + with device(self.devices[0]): + self.model_shards[0] = model + self.built = True + self.assembled_model = self._original_model + return + + if self.tensor_parallel_config is None: + device_names = [str(d) for d in self.devices] + self.tensor_parallel_config = get_default_config_keras( + model, device_names + ) + config_with_ops = self.tensor_parallel_config.create_collective_ops( + self.devices + ) + + self._is_multi_layer_model = len(model.layers) > 2 + + self.model_shards = [] + self.modified_parameters_names = set() + + + for rank, device_id in enumerate(self.devices): + shard, modified_parameters_names = make_parameter_sharded_model( + model, + config_with_ops, + rank=rank, + world_size=self.world_size, + device_id=device_id, + ) + self.model_shards.append(shard) + self.modified_parameters_names.update(modified_parameters_names) + + params_per_shard = [] + for i, shard in enumerate(self.model_shards): + total_params = 0 + for p in shard.weights: + if hasattr(p, "num_elements"): + total_params += p.num_elements() + elif hasattr(p, "numel"): + total_params += p.numel() + elif hasattr(p.shape, "num_elements"): + total_params += p.shape.num_elements() + else: + total_params += np.prod(p.shape) + + params_per_shard.append(int(total_params)) + + self.distributed_backend_name = distributed_backend + from keras.src.distribution import distributed_backend + + self.distributed_backend = distributed_backend + self.built = True + if self.distributed: + self.assembled_model = self.build_assembled_model() + else: + self.assembled_model = self._original_model + + @property + def variables(self): + """Returns a unique list of all variables from all model shards.""" + unique_vars = {} + for shard in self.model_shards: + for var in shard.variables: + if id(var) not in unique_vars: + unique_vars[id(var)] = var + return 
list(unique_vars.values()) + + @property + def trainable_variables(self): + """Returns a unique list of all trainable variables from all model shards.""" + unique_vars = {} + for shard in self.model_shards: + for var in shard.trainable_variables: + if id(var) not in unique_vars: + unique_vars[id(var)] = var + return list(unique_vars.values()) + + @property + def non_trainable_variables(self): + """Returns a unique list of all non-trainable variables from all model shards.""" + unique_vars = {} + for shard in self.model_shards: + for var in shard.non_trainable_variables: + if id(var) not in unique_vars: + unique_vars[id(var)] = var + return list(unique_vars.values()) + + @property + def weights(self): + """Returns a unique list of all weights from all model shards.""" + unique_vars = {} + for shard in self.model_shards: + for var in shard.weights: + if id(var) not in unique_vars: + unique_vars[id(var)] = var + return list(unique_vars.values()) + + @property + def trainable_weights(self): + """Returns a unique list of all trainable weights from all model shards.""" + unique_vars = {} + for shard in self.model_shards: + for var in shard.trainable_weights: + if id(var) not in unique_vars: + unique_vars[id(var)] = var + return list(unique_vars.values()) + + @property + def non_trainable_weights(self): + """Returns a unique list of all non-trainable weights from all model shards.""" + unique_vars = {} + for shard in self.model_shards: + for var in shard.non_trainable_weights: + if id(var) not in unique_vars: + unique_vars[id(var)] = var + return list(unique_vars.values()) + + def _discover_devices(self): + """Discovers available accelerator devices for the current backend. + + Returns: + list: A list of strings representing the available device names. + """ + backend = keras.backend.backend() + devices = [] + + if backend == "jax": + import jax + all_devices = jax.devices() + for platform in ("tpu", "gpu", "cpu"): + platform_devices = [ + d for d in all_devices if d.platform == platform + ] + if platform_devices: + devices = platform_devices + break + elif backend == "tensorflow": + import tensorflow as tf + gpus = tf.config.list_logical_devices("GPU") + if gpus: + devices = [d.name for d in gpus] + else: + cpus = tf.config.list_logical_devices("CPU") + devices = [d.name for d in cpus] + elif backend == "torch": + import torch + if torch.cuda.is_available(): + devices = [ + f"cuda:{i}" for i in range(torch.cuda.device_count()) + ] + elif torch.backends.mps.is_available(): + devices = ["mps"] + else: + devices = ["cpu"] + + return devices + + def _auto_detect_parallelism(self): + """Auto-detects world_size and device_ids based on available hardware. + + Returns: + tuple: A tuple containing the world size (int) and a list of + device IDs (list[str]). + """ + from keras.src.distribution import get_best_devices + from keras.src.distribution import list_devices + + available_devices = list_devices() + world_size = len(available_devices) + + device_ids = get_best_devices(world_size) + + return world_size, device_ids + + def _adjust_device_list(self, device_ids, target_world_size): + """Adjusts the device list to match the target world size. + + If the list is longer, it's truncated. If it's shorter, it's + extended, attempting to follow the pattern of existing devices or + falling back to CPUs. + + Args: + device_ids (list): The current list of device IDs. + target_world_size (int): The desired number of devices. + + Returns: + list: The adjusted list of device IDs. 
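The adjustment rule described above is easy to restate standalone; the following is an illustration-only re-implementation (it omits the non-numeric index fallback that `_adjust_device_list` itself handles):

```python
# Illustration of the device-list adjustment rule: truncate when too long,
# pad by following the first device's "type:index" pattern, fall back to CPUs.
def adjust(device_ids, target):
    if len(device_ids) >= target:
        return device_ids[:target]                      # truncate
    if not device_ids:
        return [f"cpu:{i}" for i in range(target)]      # CPU fallback
    dev_type, _ = device_ids[0].rsplit(":", 1)
    extra = [f"{dev_type}:{len(device_ids) + i}" for i in range(target - len(device_ids))]
    return device_ids + extra                           # extend by pattern

assert adjust(["gpu:0", "gpu:1", "gpu:2"], 2) == ["gpu:0", "gpu:1"]
assert adjust(["gpu:0"], 3) == ["gpu:0", "gpu:1", "gpu:2"]
assert adjust([], 2) == ["cpu:0", "cpu:1"]
```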
+ """ + current_size = len(device_ids) + if current_size >= target_world_size: + return device_ids[:target_world_size] + + num_to_add = target_world_size - current_size + + if not device_ids: + return [f"cpu:{i}" for i in range(target_world_size)] + + base_device = device_ids[0] + if isinstance(base_device, str) and ":" in base_device: + device_type, index_str = base_device.rsplit(":", 1) + if index_str.isdigit(): + additional_devices = [ + f"{device_type}:{current_size + i}" for i in range(num_to_add) + ] + return device_ids + additional_devices + + additional_devices = [f"cpu:{current_size + i}" for i in range(num_to_add)] + return device_ids + additional_devices + + def _auto_configure_devices(self, world_size, distributed_backend): + """Automatically configures a list of devices to use. + + It prioritizes available accelerators. + + Args: + world_size (int): The number of devices to configure. + distributed_backend (str): The name of the distributed backend. + + Returns: + list: A list of device ID strings. + """ + from keras.src.distribution import list_devices + + available_devices = list_devices() + + if available_devices: + devices = available_devices[:world_size] + return devices + else: + return ["cpu:0"] + + def check_device_ids( + self, device_ids: Optional[Sequence[str]] + ) -> Sequence[str]: + """Validates and normalizes a sequence of device IDs. + + Args: + device_ids (Sequence[str], optional): The input device IDs. + + Returns: + Sequence[str]: A tuple of canonicalized device ID strings. + """ + if device_ids is None: + device_ids = self._get_all_device_indices() + + device_ids = list(device_ids) + + canonical_ids = [] + for device_id in device_ids: + if isinstance(device_id, str): + canonical_ids.append(self.canonicalize_device(device_id)) + else: + canonical_ids.append(device_id) + + return tuple(canonical_ids) + + def _get_all_device_indices(self) -> Sequence[str]: + """Gets all available device identifiers from the distribution backend. + + Returns: + Sequence[str]: A sequence of available device names. + """ + from keras.src.distribution import list_devices + + devices = list_devices() + return devices + + def build_assembled_model(self): + """Builds a single Keras Functional model that encapsulates the parallel logic. + + This method creates a unified computation graph that takes the user's + inputs, passes them to each model shard in parallel, and then correctly + combines the outputs from each shard based on the sharding strategy of + the final layer (e.g., concatenation for column-parallel, summation for + row-parallel). + + This approach provides a simple, high-level interface for both inference + and training and is more amenable to JIT compilation. + + Returns: + keras.Model: The assembled functional model representing the entire + tensor-parallel computation. 
+ """ + if not self.distributed: + return self._original_model + + input_layers = { + inp.name.split(":")[0]: keras.Input( + shape=inp.shape[1:], + dtype=inp.dtype, + name=inp.name.split(":")[0], + ) + for inp in self._original_model.inputs + } + + partial_outputs = [model(input_layers) for model in self.sharded_models] + + final_layer = self._original_model.layers[-1] + sharding_type = "unknown" + final_kernel_name = f"{final_layer.name}.kernel" + if hasattr(self._original_model, "name") and self._original_model.name: + final_kernel_name = ( + f"{self._original_model.name}.{final_kernel_name}" + ) + + for pattern, action in self.tensor_parallel_config.state_rules.items(): + if re.search(pattern, final_kernel_name): + if hasattr(action, "sharding_type"): + sharding_type = action.sharding_type + break + + if sharding_type == "column": + final_output = ops.concatenate(partial_outputs, axis=-1) + original_output_dim = self._original_model.output_shape[-1] + if final_output.shape[-1] != original_output_dim: + final_output = keras.layers.Lambda( + lambda x: x[..., :original_output_dim] + )(final_output) + elif sharding_type == "row": + if len(partial_outputs) > 1: + summed_output = keras.layers.Add()(partial_outputs) + else: + summed_output = partial_outputs[0] + + if final_layer.use_bias: + bias = final_layer.bias + final_output = keras.layers.Lambda( + lambda x: x - bias * (self.world_size - 1) + )(summed_output) + else: + final_output = summed_output + else: + final_output = partial_outputs[0] + + assembled_model = keras.Model( + inputs=list(input_layers.values()), outputs=final_output + ) + return assembled_model + + def canonicalize_device(self, device_spec: Union[str, int]) -> str: + """Converts a device specification to a canonical string format. + + For example, `1` -> `"gpu:1"`, `"cuda:1"` -> `"gpu:1"`. + + Args: + device_spec (Union[str, int]): The device identifier. + + Returns: + str: The canonical device string. + """ + if isinstance(device_spec, int): + if device_spec == -1: + return "cpu" + else: + return f"gpu:{device_spec}" + elif isinstance(device_spec, str): + if device_spec == "cpu": + return "cpu" + elif device_spec.startswith("gpu:"): + return device_spec + elif device_spec.startswith("cuda:"): + return f"gpu:{device_spec.split(':')[1]}" + else: + return device_spec + else: + return "cpu" + + def apply_sharding( + self, replicated_param_names: Optional[Collection[str]] = None + ): + """Applies the sharding strategy to the model parameters. + + This method is typically called internally but can be used to manually + trigger the sharding process. + + Args: + replicated_param_names (Collection[str], optional): A collection of + parameter names that should be replicated across all devices + instead of sharded. Defaults to `self.modified_parameters_names`. + """ + if replicated_param_names is None: + replicated_param_names = self.modified_parameters_names + + self.sharding_manager = ShardedKeras( + self.model_shards, + replicated_param_names, + self.tensor_parallel_config, + self.devices, + 0, + ) + + def call(self, inputs, training=None, **kwargs): + """Defines the forward pass of the tensor-parallel model. + + This method delegates the call to the `assembled_model`, which contains + the complete, unified computation graph for the parallel execution. + + Args: + inputs: The input tensor(s). + training (bool, optional): Indicates whether the model is in + training mode. Defaults to None. + **kwargs: Additional arguments for the forward pass. 
+ + Returns: + The output tensor(s) of the model. + """ + return self.assembled_model(inputs, training=training, **kwargs) + + def _apply_forward_communication(self, inputs, training=None, **kwargs): + """ + (Internal) Applies forward pass communication based on the conjugate rule. + + Note: This method's logic is typically encapsulated within the + `assembled_model` and may not be called directly during a standard + forward pass. + + Args: + inputs: Input tensors. + training (bool, optional): Training mode flag. + **kwargs: Additional arguments. + + Returns: + The combined output tensor after communication. + """ + if ( + not hasattr(self, "tensor_parallel_config") + or self.tensor_parallel_config is None + ): + return self.shard_outputs[0] + + output_rules = self.tensor_parallel_config.output_rules + + if not output_rules: + return self.shard_outputs[0] + + from keras.src.distribution.tensor_parallel.communications import ( + TensorParallelCommunicator, + ) + + communicator = TensorParallelCommunicator(self.world_size, rank=0) + + if hasattr(self, "_is_mlp_model") and self._is_mlp_model: + return self._handle_mlp_forward_communication(communicator) + else: + return self._handle_single_layer_forward_communication( + communicator, output_rules + ) + + def _handle_mlp_forward_communication(self, communicator): + """ + (Internal) Handles MLP-specific forward communication with handshake optimization. + + Args: + communicator (TensorParallelCommunicator): The communication handler. + + Returns: + The final output tensor. + """ + up_outputs = [] + down_outputs = [] + + for i in range(self.world_size): + if i in self.shard_outputs: + up_outputs.append(self.shard_outputs[i]) + down_outputs.append(self.shard_outputs[i]) + + final_up, final_down = communicator.handle_mlp_handshake( + up_outputs, down_outputs + ) + + return final_down[0] if isinstance(final_down, list) else final_down + + def _handle_single_layer_forward_communication( + self, communicator, output_rules + ): + """ + (Internal) Handles forward communication for a single sharded layer. + + Args: + communicator (TensorParallelCommunicator): The communication handler. + output_rules (dict): Rules defining how to handle outputs. + + Returns: + The final output tensor. + """ + first_output = self.shard_outputs[0] + if hasattr(first_output, "shape") and len(first_output.shape) >= 2: + if ( + hasattr(self, "_is_multi_layer_model") + and self._is_multi_layer_model + ): + return first_output + + partial_outputs = [] + for i in range(self.world_size): + if i in self.shard_outputs: + partial_outputs.append(self.shard_outputs[i]) + return first_output + + return self.shard_outputs[0] + + def compile(self, optimizer=None, loss=None, metrics=None, **kwargs): + """Compiles the tensor-parallel model. + + This method overrides the standard `compile`. If the model is distributed + (`world_size > 1`), it wraps the provided optimizer in a + `TensorParallelOptimizer`. This specialized optimizer is responsible for + coordinating gradient computation and application across all devices + during training. + + Args: + optimizer: The optimizer instance. + loss: The loss function. + metrics: A list of metrics to be evaluated by the model. + **kwargs: Additional arguments passed to `super().compile()`. 
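A minimal compile-and-fit sketch for the wrapper follows; it assumes at least two visible devices so that `world_size > 1` (with a single device the plain optimizer is used unchanged). Whether it runs end to end depends on the rest of the tensor-parallel machinery in this package:

```python
# Compile/fit sketch for TensorParallelKeras (assumes >= 2 visible devices;
# with more than one shard, compile() wraps Adam in a TensorParallelOptimizer).
import numpy as np
import keras
from keras.src.distribution.tensor_parallel.tensor_parallel import TensorParallelKeras

base = keras.Sequential([keras.Input(shape=(20,)), keras.layers.Dense(8)])
tp_model = TensorParallelKeras(base, world_size=2)

tp_model.compile(optimizer=keras.optimizers.Adam(1e-3), loss="mse")
tp_model.fit(np.random.rand(16, 20), np.random.rand(16, 8), epochs=1, verbose=0)
```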
+ """ + if len(self.model_shards) > 1 and optimizer is not None: + backend_name = getattr(self, "distributed_backend_name", "auto") + + self.coordinated_optimizer = TensorParallelOptimizer( + optimizer, + self.world_size, + distributed_backend=backend_name, + tensor_parallel_config=self.tensor_parallel_config, + ) + + super().compile( + optimizer=self.coordinated_optimizer, + loss=loss, + metrics=metrics, + **kwargs, + ) + + else: + super().compile(optimizer, loss, metrics, **kwargs) + + def _apply_backward_communication(self, gradients, layer_type="unknown"): + """ + (Internal) Applies backward pass communication based on the conjugate rule. + + Note: This logic is typically handled by the `TensorParallelOptimizer` + and may not be called directly during standard training. + + Args: + gradients: The gradients to be communicated. + layer_type (str, optional): The type of layer sharding + (e.g., 'column', 'row'). Defaults to "unknown". + + Returns: + The communicated gradients. + """ + if len(self.model_shards) <= 1: + return gradients + + from keras.src.distribution.tensor_parallel.communications import ( + TensorParallelCommunicator, + ) + + communicator = TensorParallelCommunicator(self.world_size, rank=0) + + if ( + "column" in layer_type.lower() + or "up_projection" in layer_type.lower() + ): + return communicator.backward_column_parallel( + gradients, op="sum" + ) + elif ( + "row" in layer_type.lower() + or "down_projection" in layer_type.lower() + ): + gathered = communicator.backward_row_parallel(gradients, dim=-1) + return [gathered] * self.world_size + else: + return gradients + + def _slice_upstream_gradients_for_backward( + self, full_gradients, sharding_type="unknown" + ): + """ + (Internal) Slices the upstream gradients to match each device's shard. + + Note: This logic is typically handled by the `TensorParallelOptimizer`. + + Args: + full_gradients: The complete upstream gradients. + sharding_type (str, optional): The sharding type of the layer that + produced the gradients. Defaults to "unknown". + + Returns: + list: A list of sliced gradients, one for each device shard. + """ + if len(self.model_shards) <= 1: + return [full_gradients] + + from keras.src.distribution.tensor_parallel.communications import ( + TensorParallelCommunicator, + ) + + communicator = TensorParallelCommunicator(self.world_size, rank=0) + + sliced_gradients = [] + + for rank in range(self.world_size): + if sharding_type == "column_parallel": + sliced_grad = communicator.slice_upstream_gradient_for_column_parallel( + full_gradients, rank, self.world_size, dim=-1 + ) + elif sharding_type == "row_parallel": + sliced_grad = ( + communicator.slice_upstream_gradient_for_row_parallel( + full_gradients, rank, self.world_size, dim=0 + ) + ) + else: + sliced_grad = full_gradients + + sliced_gradients.append(sliced_grad) + + return sliced_gradients + + def _compute_shard_gradients_with_sliced_upstream( + self, shard, sliced_upstream_grad, inputs, training=True + ): + """ + (Internal) Computes gradients for a single shard using its sliced upstream gradient. + + Note: This logic is typically handled by the `TensorParallelOptimizer`. + + Args: + shard (keras.Model): The model shard. + sliced_upstream_grad: The corresponding slice of the upstream gradient. + inputs: The inputs to the shard. + training (bool, optional): Training mode flag. Defaults to True. + + Returns: + list: The computed gradients for the shard's trainable variables. 
+ """ + with tf.GradientTape() as tape: + shard_output = shard(inputs, training=training) + loss = self._compute_shard_loss( + shard_output, sliced_upstream_grad + ) + + gradients = tape.gradient(loss, shard.trainable_variables) + return gradients + + def _compute_shard_loss(self, shard_output, sliced_upstream_grad): + """ + (Internal) Computes a pseudo-loss to generate correct gradients. + + This function creates a loss whose gradient with respect to the + `shard_output` is equal to the `sliced_upstream_grad`. A common way + is to use the mean squared error between the output and the gradient. + + Args: + shard_output: The output tensor from the model shard. + sliced_upstream_grad: The target gradient for the shard's output. + + Returns: + A scalar loss tensor. + """ + if hasattr(sliced_upstream_grad, "shape") and hasattr( + shard_output, "shape" + ): + target = sliced_upstream_grad + loss = tf.reduce_mean(tf.square(shard_output - target)) + return loss + else: + return tf.reduce_mean(tf.square(shard_output)) + + def fit(self, x=None, y=None, **kwargs): + """Trains the model for a fixed number of epochs (iterations on a dataset). + + This method leverages the standard `keras.Model.fit()` training loop. + The custom logic provided in `compile()` and the `assembled_model` + ensures that each training step is executed in a tensor-parallel manner + correctly. + + Args: + x: Input data. + y: Target data. + **kwargs: Other arguments supported by `keras.Model.fit()`. + + Returns: + A `History` object. Its `History.history` attribute is + a record of training loss values and metrics values + at successive epochs, as well as validation loss values + and validation metrics values (if applicable). + """ + return super().fit(x, y, **kwargs) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/tensor_parallel_test.py b/keras/src/distribution/tensor_parallel/tensor_parallel_test.py new file mode 100644 index 000000000000..f838a123763d --- /dev/null +++ b/keras/src/distribution/tensor_parallel/tensor_parallel_test.py @@ -0,0 +1,153 @@ +import pytest +import numpy as np +import keras +import tensorflow as tf +from unittest.mock import patch, MagicMock + +from keras.src.distribution.tensor_parallel.tensor_parallel import TensorParallelKeras + + +INPUT_DIM = 10 +HIDDEN_DIM = 17 +OUTPUT_DIM = 5 +BATCH_SIZE = 8 + +def create_simple_mlp(): + """Creates a simple Keras Sequential model for testing.""" + return keras.Sequential( + [ + keras.Input(shape=(INPUT_DIM,), name="input_layer"), + keras.layers.Dense(HIDDEN_DIM, activation="relu", name="dense_1"), + keras.layers.Dense(OUTPUT_DIM, name="dense_2"), + ] + ) + +def count_params(model_or_weights): + """Helper function to count the total number of parameters.""" + weights = [] + if isinstance(model_or_weights, keras.Model): + weights = model_or_weights.weights + else: + weights = model_or_weights + + total_params = 0 + for p in weights: + if hasattr(p, "shape"): + total_params += np.prod(p.shape) + return int(total_params) + + +class TestTensorParallelKeras: + """Test suite for the TensorParallelKeras wrapper.""" + + @pytest.fixture + def mock_devices(self): + """A pytest fixture to mock device discovery for a predictable environment.""" + with patch.object( + TensorParallelKeras, + "_discover_devices", + return_value=["cpu:0", "cpu:1"], + ) as mock: + yield mock + + def test_initialization_and_sharding(self, mock_devices): + """ + Tests if the model is correctly initialized and sharded for world_size > 1. 
+ """ + print("šŸš€ Testing model initialization and sharding...") + original_model = create_simple_mlp() + original_params = count_params(original_model) + + tp_model = TensorParallelKeras(model=original_model, world_size=2) + + assert tp_model.world_size == 2 + assert tp_model.distributed is True + assert len(tp_model.model_shards) == 2, "Model should be split into 2 shards" + + shard1_params = count_params(tp_model.model_shards[0]) + shard2_params = count_params(tp_model.model_shards[1]) + + assert shard1_params < original_params, "Shard 1 should have fewer params than the original" + assert shard2_params < original_params, "Shard 2 should have fewer params than the original" + assert shard1_params != shard2_params, "Shards should have different param counts" + print("āœ… Initialization and sharding successful.") + + + def test_non_distributed_case_world_size_one(self, mock_devices): + """ + Tests if the model behaves like a standard Keras model when world_size is 1. + """ + print("\nšŸš€ Testing non-distributed case (world_size=1)...") + original_model = create_simple_mlp() + original_params = count_params(original_model) + + tp_model = TensorParallelKeras(model=original_model, world_size=1) + + assert tp_model.world_size == 1 + assert tp_model.distributed is False + assert len(tp_model.model_shards) == 1 + assert tp_model.model_shards[0] == original_model + assert count_params(tp_model.model_shards[0]) == original_params + print("āœ… Non-distributed case handled correctly.") + + + def test_forward_pass_output_shape(self, mock_devices): + """ + Tests if the forward pass of the sharded model executes and returns the correct shape. + """ + print("\nšŸš€ Testing forward pass output shape...") + original_model = create_simple_mlp() + dummy_input = np.random.rand(BATCH_SIZE, INPUT_DIM).astype("float32") + + tp_model = TensorParallelKeras(model=original_model, world_size=2) + output = tp_model(dummy_input) + + assert output is not None + assert output.shape == (BATCH_SIZE, OUTPUT_DIM), "Output shape is incorrect" + print("āœ… Forward pass successful with correct output shape.") + + + @patch("keras.src.distribution.tensor_parallel.communications.TensorParallelCommunicator") + def test_gradient_slicing_logic(self, MockCommunicator, mock_devices): + """ + Verifies that the correct upstream gradient slicing methods are called. + """ + print("\nšŸš€ Testing upstream gradient slicing logic...") + mock_communicator_instance = MagicMock() + MockCommunicator.return_value = mock_communicator_instance + + original_model = create_simple_mlp() + tp_model = TensorParallelKeras(model=original_model, world_size=2) + + dummy_full_gradients = tf.ones((BATCH_SIZE, OUTPUT_DIM)) + + tp_model._slice_upstream_gradients_for_backward(dummy_full_gradients, "column_parallel") + assert mock_communicator_instance.slice_upstream_gradient_for_column_parallel.call_count == 2 + + tp_model._slice_upstream_gradients_for_backward(dummy_full_gradients, "row_parallel") + assert mock_communicator_instance.slice_upstream_gradient_for_row_parallel.call_count == 2 + + print("āœ… Upstream gradient slicing calls verified.") + + + @patch("keras.src.distribution.tensor_parallel.communications.TensorParallelCommunicator") + def test_backward_communication_logic(self, MockCommunicator, mock_devices): + """ + Verifies that the correct backward communication primitives (AllReduce/AllGather) are called. 
+ """ + print("\nšŸš€ Testing backward pass communication logic...") + mock_communicator_instance = MagicMock() + MockCommunicator.return_value = mock_communicator_instance + + original_model = create_simple_mlp() + tp_model = TensorParallelKeras(model=original_model, world_size=2) + + dummy_gradients = [tf.ones((INPUT_DIM, HIDDEN_DIM)), tf.ones((HIDDEN_DIM,))] + + tp_model._apply_backward_communication(dummy_gradients, layer_type="column") + mock_communicator_instance.backward_column_parallel.assert_called_once() + + tp_model._apply_backward_communication(dummy_gradients, layer_type="row") + mock_communicator_instance.backward_row_parallel.assert_called_once() + + print("āœ… Backward communication calls verified.") \ No newline at end of file