diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 15f1af2145d5..dc93944dfd47 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -37,6 +37,8 @@ if backend() == "tensorflow": from keras.src.backend.tensorflow import * # noqa: F403 from keras.src.backend.tensorflow.core import Variable as BackendVariable + + distributed_backend = None elif backend() == "jax": from keras.src.backend.jax import * # noqa: F403 from keras.src.backend.jax.core import Variable as BackendVariable @@ -44,17 +46,18 @@ from keras.src.backend.torch import * # noqa: F403 from keras.src.backend.torch.core import Variable as BackendVariable - distribution_lib = None elif backend() == "numpy": from keras.src.backend.numpy import * # noqa: F403 from keras.src.backend.numpy.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None elif backend() == "openvino": from keras.src.backend.openvino import * # noqa: F403 from keras.src.backend.openvino.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None else: raise ValueError(f"Unable to import backend : {backend()}") diff --git a/keras/src/backend/torch/__init__.py b/keras/src/backend/torch/__init__.py index 371a62cd0f52..a7d0405a5567 100644 --- a/keras/src/backend/torch/__init__.py +++ b/keras/src/backend/torch/__init__.py @@ -16,6 +16,8 @@ from keras.src.backend.common.name_scope import name_scope from keras.src.backend.torch import core +from keras.src.backend.torch import distributed_backend +from keras.src.backend.torch import distribution_lib from keras.src.backend.torch import image from keras.src.backend.torch import linalg from keras.src.backend.torch import math diff --git a/keras/src/backend/torch/distributed_backend.py b/keras/src/backend/torch/distributed_backend.py new file mode 100644 index 000000000000..432760339102 --- /dev/null +++ b/keras/src/backend/torch/distributed_backend.py @@ -0,0 +1,220 @@ +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Literal + +import torch +import torch.distributed as dist + + +def compute_gradients( + loss: torch.Tensor, trainable_vars: List[torch.Tensor] +) -> List[torch.Tensor]: + """Computes gradients of the loss with respect to trainable variables. + + This function leverages PyTorch's `autograd.grad` for a stateless, + functional approach similar to `jax.grad`. + + Args: + loss (torch.Tensor): The loss value for which to compute gradients. + trainable_vars (List[torch.Tensor]): A list of variables (tensors with + `requires_grad=True`) to compute gradients with respect to. + + Returns: + List[torch.Tensor]: A list of gradients corresponding to the + trainable variables. + """ + return list(torch.autograd.grad(loss, trainable_vars)) + + +def apply_gradients( + gradients: List[torch.Tensor], + trainable_vars: List[torch.Tensor], + learning_rate: float = 0.001, +) -> List[torch.Tensor]: + """Applies gradients and returns the updated variables. + + Updates are performed in-place within a `torch.no_grad()` context + to prevent the update operation from being part of the computation graph. 
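+
+    Example (minimal sketch, using `compute_gradients` defined above):
+
+        w = torch.tensor([1.0, 2.0], requires_grad=True)
+        loss = (w * w).sum()
+        grads = compute_gradients(loss, [w])
+        apply_gradients(grads, [w], learning_rate=0.1)  # updates `w` in place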
+ """ + with torch.no_grad(): + updated_vars = [] + for grad, var in zip(gradients, trainable_vars): + if grad is not None: + var.sub_(learning_rate * grad) + updated_vars.append(var) + return updated_vars + + +def create_optimizer(optimizer_class: str, **kwargs) -> Dict[str, Any]: + """Creates a configuration dictionary for a PyTorch optimizer. + + This function returns a dictionary containing the optimizer's configuration, + maintaining a consistent interface with the JAX backend. The user is + expected to instantiate the optimizer from this config. + + Args: + optimizer_class (str): The name of the optimizer to create (e.g., + `"adam"`, `"sgd"`). + **kwargs: Keyword arguments for the optimizer (e.g., `learning_rate`). + + Returns: + Dict[str, Any]: A dictionary representing the optimizer configuration. + """ + config = kwargs.copy() + config["name"] = optimizer_class.lower() + config.setdefault("learning_rate", 0.001) + return config + + +def get_device_info() -> Dict[str, Any]: + """Retrieves information about the available PyTorch devices. + + Returns: + Dict[str, Any]: A dictionary containing the backend name, a list of + available device strings, and the total device count. + """ + if torch.cuda.is_available(): + device_count = torch.cuda.device_count() + devices = [torch.cuda.get_device_name(i) for i in range(device_count)] + else: + device_count = 1 + devices = ["cpu"] + return { + "backend": "pytorch", + "devices": devices, + "device_count": device_count, + } + + +def is_multi_device_capable() -> bool: + """Checks if more than one CUDA device is available. + + Returns: + bool: `True` if PyTorch reports more than one CUDA device, `False` + otherwise. + """ + return torch.cuda.device_count() > 1 + + +def get_communication_ops() -> Dict[str, Callable]: + """Provides a dictionary of PyTorch collective communication operations. + + These operations rely on the `torch.distributed` package. They are + designed to work in a multi-process, multi-device environment. If the + distributed package is not initialized, they provide a sensible fallback + for single-device execution. + + Returns: + Dict[str, Callable]: A dictionary mapping operation names to their + PyTorch implementations. 
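+
+    Example (minimal sketch; without an initialized process group the ops
+    fall back to single-device behaviour):
+
+        ops = get_communication_ops()
+        x = torch.ones(4)
+        summed = ops["all_reduce"](x, op="sum")
+        gathered = ops["all_gather"](x, axis=0)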
+ """ + + def _is_distributed() -> bool: + """Checks if the default process group is initialized.""" + return dist.is_available() and dist.is_initialized() + + def all_reduce( + x: torch.Tensor, + op: Literal["sum", "mean"] = "sum", + axis_name: str = None, + ) -> torch.Tensor: + """Reduces a tensor across all devices.""" + if not _is_distributed(): + world_size = ( + torch.cuda.device_count() if torch.cuda.is_available() else 1 + ) + if world_size <= 1: + return x + if op == "sum": + return x * float(world_size) + elif op == "mean": + return x + else: + raise ValueError(f"Unsupported all_reduce op: {op}") + + reduce_op = {"sum": dist.ReduceOp.SUM, "mean": dist.ReduceOp.AVG}.get( + op + ) + if reduce_op is None: + raise ValueError(f"Unsupported all_reduce op: {op}") + + result = x.clone() + dist.all_reduce(result, op=reduce_op) + return result + + def all_gather( + x: torch.Tensor, axis: int = 0, axis_name: str = None + ) -> torch.Tensor: + """Gathers tensors from all devices and concatenates them.""" + if not _is_distributed(): + world_size = ( + torch.cuda.device_count() if torch.cuda.is_available() else 1 + ) + if world_size <= 1: + return x + return torch.cat([x] * world_size, dim=axis) + + world_size = dist.get_world_size() + tensor_list = [torch.empty_like(x) for _ in range(world_size)] + dist.all_gather(tensor_list, x) + return torch.cat(tensor_list, dim=axis) + + def broadcast( + x: torch.Tensor, root: int = 0, axis_name: str = None + ) -> torch.Tensor: + """Broadcasts a tensor from a root device to all other devices.""" + if not _is_distributed(): + return x + + dist.broadcast(x, src=root) + return x + + def scatter( + x: torch.Tensor, + root: int = 0, + axis: int = 0, + axis_name: str = None, + ) -> torch.Tensor: + """Scatters a tensor from a root device to all devices.""" + if not _is_distributed(): + world_size = ( + torch.cuda.device_count() if torch.cuda.is_available() else 1 + ) + if world_size <= 1: + return x + if x.shape[axis] % world_size != 0: + raise ValueError( + f"Tensor with shape {x.shape} cannot be scattered along " + f"axis {axis} across {world_size} devices." + ) + return torch.chunk(x, world_size, dim=axis)[0] + + world_size = dist.get_world_size() + rank = dist.get_rank() + + if x.shape[axis] % world_size != 0: + raise ValueError( + f"Tensor with shape {x.shape} cannot be scattered along " + f"axis {axis} across {world_size} devices." 
+ ) + + if rank == root: + scatter_list = list(torch.chunk(x, world_size, dim=axis)) + else: + scatter_list = None + + chunk_shape = list(x.shape) + chunk_shape[axis] //= world_size + local_chunk = torch.empty(chunk_shape, dtype=x.dtype, device=x.device) + + dist.scatter(local_chunk, scatter_list, src=root) + return local_chunk + + return { + "all_reduce": all_reduce, + "all_gather": all_gather, + "broadcast": broadcast, + "scatter": scatter, + } diff --git a/keras/src/backend/torch/distributed_backend_test.py b/keras/src/backend/torch/distributed_backend_test.py new file mode 100644 index 000000000000..d6dce5977b81 --- /dev/null +++ b/keras/src/backend/torch/distributed_backend_test.py @@ -0,0 +1,129 @@ +import pytest +import torch + +from keras.src import backend +from keras.src.backend import distributed_backend + + +@pytest.mark.skipif( + backend.backend() != "torch", + reason="Jax Backend specific test", +) +class TestPytorchDistributedFunctions: + """Unit tests for the PyTorch distributed backend standalone functions.""" + + def test_compute_gradients_computes_correctly(self): + """Test that compute_gradients returns correct gradients.""" + w = torch.tensor([2.0, 3.0], requires_grad=True) + b = torch.tensor(1.0, requires_grad=True) + x = torch.tensor([4.0, 5.0]) + y_true = torch.tensor(25.0) + + # loss = (w.x + b - y_true)^2 = ((2*4 + 3*5 + 1) - 25)^2 = (24-25)^2 = 1 + y_pred = torch.dot(w, x) + b + loss = (y_pred - y_true) ** 2 + + trainable_vars = [w, b] + gradients = distributed_backend.compute_gradients(loss, trainable_vars) + + # d_loss/d_w = 2*(y_pred - y_true)*x = 2*(-1)*[4, 5] = [-8, -10] + # d_loss/d_b = 2*(y_pred - y_true)*1 = 2*(-1)*1 = -2 + expected_grad_w = torch.tensor([-8.0, -10.0]) + expected_grad_b = torch.tensor(-2.0) + + assert len(gradients) == 2 + torch.testing.assert_close(gradients[0], expected_grad_w) + torch.testing.assert_close(gradients[1], expected_grad_b) + + def test_apply_gradients(self): + """Test the application of gradients to PyTorch tensors.""" + var1 = torch.tensor([1.0, 2.0], requires_grad=True) + var2 = torch.tensor(5.0, requires_grad=True) + trainable_vars = [var1, var2] + grad1 = torch.tensor([0.1, 0.2]) + grad2 = torch.tensor(0.5) + gradients = [grad1, grad2] + learning_rate = 0.1 + + original_var1 = var1.clone() + original_var2 = var2.clone() + + updated_vars = distributed_backend.apply_gradients( + gradients, trainable_vars, learning_rate + ) + + assert updated_vars[0] is var1 + assert updated_vars[1] is var2 + + expected_var1 = original_var1 - (grad1 * learning_rate) + expected_var2 = original_var2 - (grad2 * learning_rate) + torch.testing.assert_close(updated_vars[0], expected_var1) + torch.testing.assert_close(updated_vars[1], expected_var2) + + def test_create_optimizer(self): + """Test optimizer configuration creation.""" + adam_config = distributed_backend.create_optimizer( + "adam", learning_rate=0.01 + ) + assert isinstance(adam_config, dict) + assert adam_config["name"] == "adam" + assert adam_config["learning_rate"] == 0.01 + + sgd_config = distributed_backend.create_optimizer( + "sgd", learning_rate=0.1, momentum=0.9 + ) + assert isinstance(sgd_config, dict) + assert sgd_config["name"] == "sgd" + assert sgd_config["learning_rate"] == 0.1 + assert sgd_config["momentum"] == 0.9 + + def test_get_device_info(self): + """Test retrieving device information from the PyTorch backend.""" + info = distributed_backend.get_device_info() + assert info["backend"] == "pytorch" + assert isinstance(info["devices"], list) + assert 
isinstance(info["device_count"], int) + assert info["device_count"] > 0 + assert len(info["devices"]) == info["device_count"] + if torch.cuda.is_available(): + assert info["device_count"] == torch.cuda.device_count() + else: + assert info["device_count"] == 1 + assert info["devices"] == ["cpu"] + + def test_is_multi_device_capable(self): + """Test the boolean check for multi-device capability.""" + assert isinstance(distributed_backend.is_multi_device_capable(), bool) + + def test_communication_ops_simulation_logic(self): + """Test the simulated communication ops in a single-device context.""" + comm_ops = distributed_backend.get_communication_ops() + device_info = distributed_backend.get_device_info() + world_size = device_info.get("device_count", 1) + + # Test all_reduce + x_reduce = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + reduced = comm_ops["all_reduce"](x_reduce, op="sum") + expected_reduce = ( + x_reduce * float(world_size) if world_size > 1 else x_reduce + ) + torch.testing.assert_close(reduced, expected_reduce) + + # Test all_gather + x_gather = torch.tensor([[1.0, 2.0]]) + gathered = comm_ops["all_gather"](x_gather, axis=0) + expected_gather = torch.cat([x_gather] * world_size, dim=0) + torch.testing.assert_close(gathered, expected_gather) + + # Test broadcast + x_broadcast = torch.tensor([5.0, 6.0]) + broadcasted = comm_ops["broadcast"](x_broadcast) + torch.testing.assert_close(broadcasted, x_broadcast) + + # Test scatter + if world_size > 0: + scatter_data = torch.arange(world_size * 4, dtype=torch.float32) + x_scatter = scatter_data.reshape(world_size * 2, 2) + scattered = comm_ops["scatter"](x_scatter, axis=0) + expected_scatter = torch.chunk(x_scatter, world_size, dim=0)[0] + torch.testing.assert_close(scattered, expected_scatter) diff --git a/keras/src/backend/torch/distribution_lib.py b/keras/src/backend/torch/distribution_lib.py new file mode 100644 index 000000000000..0d8c18de4bf7 --- /dev/null +++ b/keras/src/backend/torch/distribution_lib.py @@ -0,0 +1,413 @@ +"""Utilities for distribution strategy with Torch backend. + +This file contains the core Torch distribution primitives from Keras, +along with higher-level device management and auto-configuration utilities. +This version does not use try-except blocks for error handling. +""" + +import logging +import os +from typing import Dict +from typing import List +from typing import Optional + +import numpy as np +import torch +import torch.distributed as dist + +from keras.src.backend.common import global_state +from keras.src.random import seed_generator +from keras.src.utils import rng_utils + +logger = logging.getLogger(__name__) + + +def list_devices(device_type=None): + """Return all the available devices based on the device type. + + Note that this should return the global devices in a distributed setting. + + Args: + device_type: string of `"cpu"`, `"gpu"`. Defaults to `"gpu"` if + available when device_type is not provided. Otherwise will return + the `"cpu"` devices. `"tpu"` is not supported by the default + torch backend. + + Return: + List of devices that are available for distribute computation. 
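+
+    Example (illustrative; the result depends on the host hardware):
+
+        list_devices()       # e.g. ["cuda:0", "cuda:1"] on a 2-GPU machine
+        list_devices("cpu")  # ["cpu:0"]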
+ """ + if device_type: + device_type = device_type.lower() + else: + device_type = "cuda" if torch.cuda.is_available() else "cpu" + + if device_type in ("gpu", "cuda"): + if not torch.cuda.is_available(): + return [] + return [f"cuda:{i}" for i in range(torch.cuda.device_count())] + elif device_type == "cpu": + return ["cpu:0"] + elif device_type == "tpu": + logger.warning( + "TPU device type is not supported by the default " + "PyTorch backend. Use the `torch_xla` package." + ) + return [] + raise ValueError(f"Unknown device type: {device_type}") + + +def get_device_info(device_id: str) -> Dict[str, any]: + """ + Get detailed information about a specific device. + + Args: + device_id: Device identifier (e.g., 'cuda:0', 'cpu:0') + + Returns: + Dictionary containing device information + """ + device_info = { + "id": device_id, + "type": None, + "index": None, + "memory": None, + "capabilities": None, + } + + device_type, device_index = device_id.split(":") + device_type_map = {"cuda": "GPU", "cpu": "CPU"} + device_info["type"] = device_type_map.get(device_type, device_type.upper()) + device_info["index"] = int(device_index) + + return device_info + + +def get_best_devices(count: int = 1) -> List[str]: + """ + Get the best available devices for tensor parallelism. + + Args: + count: Number of devices needed + + Returns: + List of best device identifiers + """ + all_devices = list_devices("cuda") + if not all_devices: + all_devices = list_devices("cpu") + + if count <= 0: + return [] + + if count > len(all_devices): + logger.warning( + f"Requested {count} devices but only {len(all_devices)} available" + ) + count = len(all_devices) + + return all_devices[:count] + + +def get_device_backend(device_type: str) -> str: + """ + Get the recommended backend for a device type. + + Args: + device_type: Device type ('tpu', 'gpu', 'cpu') + + Returns: + Recommended backend name + """ + backend_mapping = {"gpu": "torch", "cuda": "torch", "cpu": "torch"} + + return backend_mapping.get(device_type.lower(), "torch") + + +def validate_device_placement(device_id: str) -> bool: + """ + Validate if a device can be used for tensor operations. + + Args: + device_id: Device identifier + + Returns: + True if device is valid and available + """ + if ":" not in device_id: + return False + + device_type = device_id.split(":")[0] + known_device_types = ("cpu", "gpu", "cuda", "tpu") + if device_type not in known_device_types: + return False + + all_devices = list_devices(device_type) + return device_id in all_devices + + +def get_device_memory_info(device_id: str) -> Optional[Dict[str, any]]: + """ + Get memory information for a device (if available). + + Args: + device_id: Device identifier + + Returns: + Memory information dictionary or None if not available + """ + if device_id.startswith("cuda:"): + return { + "type": "GPU", + "index": int(device_id.split(":")[1]), + "memory": "Available", + } + elif device_id.startswith("cpu:"): + return { + "type": "CPU", + "index": int(device_id.split(":")[1]), + "memory": "System RAM", + } + + return None + + +def auto_configure_tensor_parallel( + world_size: int = None, backend: str = None +) -> Dict[str, any]: + """ + Automatically configure tensor parallelism with the best available devices. 
+ + Args: + world_size: Number of devices to use (if None, uses all available GPUs) + backend: Backend to use (if None, will be set to 'torch') + + Returns: + Configuration dictionary with devices, backend, and other settings + """ + all_devices = list_devices() + + if not all_devices: + raise RuntimeError("No devices available for tensor parallelism") + + if world_size is None: + world_size = len(all_devices) + else: + world_size = min(world_size, len(all_devices)) + + selected_devices = all_devices[:world_size] + + recommended_backend = "torch" + + config = { + "devices": selected_devices, + "world_size": world_size, + "backend": recommended_backend, + } + + logger.info(f"Auto-configured tensor parallelism: {config}") + return config + + +def distribute_variable(value, layout): + """Create a distributed variable for PyTorch. + + This function creates a `torch.Tensor` distributed according to the given + layout. In PyTorch, variables and tensors are unified in the `Tensor` class. + + Args: + value: The initial value of the variable as a `torch.Tensor`. + layout: `TensorLayout` for the created variable, or a PyTorch-supported + layout instance (e.g., a list of `Placement` types). + + Returns: + `torch.Tensor` which is the distributed variable. + """ + return distribute_tensor(value, layout) + + +def distribute_tensor(tensor, layout): + """Distribute the tensor based on the layout. + + Args: + tensor: `torch.Tensor` that needs to be distributed. + layout: `TensorLayout` for the created variable, or a PyTorch-supported + layout instance (e.g., a list of `Placement` types). + + Returns: + Distributed `torch.Tensor`. + """ + # Avoid circular imports. + from keras.src.distribution import TensorLayout + + if isinstance(layout, TensorLayout): + placements = layout.backend_layout + device_mesh = layout.device_mesh.backend_mesh + else: + raise ValueError( + "Directly passing backend layout is not yet supported for torch. " + "Please provide a `keras.distribution.TensorLayout` instance." + ) + + return dist.dtensor.distribute_tensor( + tensor.to("cpu"), device_mesh, placements + ) + + +def distribute_data_input(per_process_batch, layout, batch_dim_name): + """Distribute the input data with the corresponding layout. + + Note that the input here is a local worker batch. PyTorch's `from_local` + is used to construct a global DTensor from these local shards. + + Args: + per_process_batch: `torch.Tensor` that is local shard for this process. + layout: `TensorLayout` for the distribution information. + + Returns: + A global batch distributed according to `layout`. + """ + from keras.src.distribution import TensorLayout + + if not isinstance(layout, TensorLayout): + raise ValueError( + "A `keras.distribution.TensorLayout` instance is required." + ) + + placements = layout.backend_layout + device_mesh = layout.device_mesh.backend_mesh + return dist.dtensor.from_local( + per_process_batch, device_mesh, placements, run_check=True + ) + + +def initialize_rng(): + """Initializes the global random number generator across processes. + + This is required for consistent initialization in multi-host settings. + It works by generating a seed on rank 0 and broadcasting it to all other + processes. 
+ """ + global_seed = rng_utils.get_random_seed() + if global_seed is None: + if not dist.is_initialized(): + seed = seed_generator.make_default_seed() + else: + if process_id() == 0: + seed = seed_generator.make_default_seed() + seed_tensor = torch.tensor( + seed, dtype=torch.int64, device="cpu" + ) + else: + seed_tensor = torch.empty(1, dtype=torch.int64, device="cpu") + dist.broadcast(seed_tensor, src=0) + seed = seed_tensor.item() + global_seed = seed + rng_utils.set_random_seed(global_seed) + + global_seed_generator = global_state.get_global_attribute( + "global_seed_generator" + ) + if global_seed_generator is not None and global_seed_generator.seed is None: + global_state.set_global_attribute( + "global_seed_generator", + seed_generator.SeedGenerator( + seed=global_seed, + name=global_seed_generator.name, + backend=global_seed_generator.backend, + ), + ) + + +def initialize(job_addresses, num_processes, process_id): + """Initializes the distributed process group in PyTorch.""" + os.environ["RANK"] = str(process_id) + os.environ["WORLD_SIZE"] = str(num_processes) + + if "," in job_addresses: + master_addr = job_addresses.split(",")[0] + else: + master_addr = job_addresses + + if ":" not in master_addr: + raise ValueError( + "Invalid `job_addresses`. Expected format `hostname:port`, " + f"but got {master_addr}" + ) + + master_host, master_port = master_addr.split(":") + os.environ["MASTER_ADDR"] = master_host + os.environ["MASTER_PORT"] = master_port + + backend = "nccl" if torch.cuda.is_available() else "gloo" + dist.init_process_group(backend=backend) + + initialize_rng() + + +def num_processes(): + """Return the number of processes for the current distribution setting.""" + if dist.is_initialized(): + return dist.get_world_size() + return 1 + + +def process_id(): + """Return the current process ID for the distribution setting.""" + if dist.is_initialized(): + return dist.get_rank() + return 0 + + +def _to_backend_device(device_name): + if isinstance(device_name, torch.device): + return device_name + return torch.device(device_name) + + +def _to_backend_mesh(device_mesh): + """Convert the DeviceMesh to Torch backend specific Mesh. + + Args: + device_mesh: DeviceMesh instance to convert. + + Returns: + A `torch.distributed.DeviceMesh` instance. + """ + mesh_shape = device_mesh.devices.shape + mesh_devices = np.array(device_mesh.devices.flatten()).reshape(mesh_shape) + return dist.DeviceMesh( + device_type="cuda" if torch.cuda.is_available() else "cpu", + mesh=mesh_devices, + ) + + +def _to_backend_layout(tensor_layout): + """Convert the TensorLayout to Torch backend specific placement. + + Args: + tensor_layout: TensorLayout instance to convert. + + Returns: + A list of `torch.distributed.placement_types.Placement` instances. + """ + if tensor_layout.device_mesh is None: + raise ValueError( + "Cannot create sharding when device mesh is not set " + "for TensorLayout." + ) + + mesh_axes = tensor_layout.device_mesh.axis_names + placements = [] + for axis in tensor_layout.axes: + if axis is None: + placements.append(dist.Replicate()) + else: + try: + mesh_dim = mesh_axes.index(axis) + placements.append(dist.Shard(mesh_dim)) + except ValueError: + raise ValueError( + f"Tensor axis `{axis}` is not found in the " + f"device mesh axes `{mesh_axes}`." 
+ ) from None + return placements diff --git a/keras/src/backend/torch/distribution_lib_test.py b/keras/src/backend/torch/distribution_lib_test.py new file mode 100644 index 000000000000..2897b022a0d4 --- /dev/null +++ b/keras/src/backend/torch/distribution_lib_test.py @@ -0,0 +1,160 @@ +import os + +import numpy as np +import pytest +import torch +import torch.distributed as dist + +from keras.src import backend +from keras.src.backend import distribution_lib +from keras.src.distribution import DeviceMesh +from keras.src.distribution import TensorLayout + + +@pytest.mark.skipif( + backend.backend() != "torch", + reason="Backend specific test", +) +def setup_torch_distributed(): + """ + A fixture to initialize the distributed process group if not already done. + This allows test file to be run directly with `pytest` for single-process + checks, while also working correctly when launched with `torchrun`. + """ + if not dist.is_available() or dist.is_initialized(): + return + + os.environ.setdefault("MASTER_ADDR", "localhost") + os.environ.setdefault("MASTER_PORT", "29500") + os.environ.setdefault("RANK", "0") + os.environ.setdefault("WORLD_SIZE", "1") + dist.init_process_group(backend="gloo") + + +@pytest.mark.skipif( + not torch.distributed.is_available(), + reason="PyTorch distributed components are not available.", +) +class TestTorchDistributionLibLive: + """ + Tests for the Torch distribution library without using mocks. + These tests will reflect the capabilities of environment they are run in. + """ + + def test_device_listing_and_info(self): + """Tests device discovery functions against the runtime environment.""" + if torch.cuda.is_available(): + gpu_devices = distribution_lib.list_devices("gpu") + assert len(gpu_devices) == torch.cuda.device_count() + assert gpu_devices[0] == "cuda:0" + else: + assert distribution_lib.list_devices("gpu") == [] + + cpu_devices = distribution_lib.list_devices("cpu") + assert cpu_devices == ["cpu:0"] + + with pytest.raises(ValueError, match="Unknown device type"): + distribution_lib.list_devices("unsupported_device") + + def test_device_helpers(self): + """Tests validation, backend, and memory info functions.""" + device_str = "cpu:0" + if torch.cuda.is_available(): + device_str = "cuda:0" + + assert distribution_lib.validate_device_placement(device_str) is True + assert distribution_lib.validate_device_placement("invalid:0") is False + + assert distribution_lib.get_device_backend("cpu") == "torch" + assert distribution_lib.get_device_backend("gpu") == "torch" + + mem_info = distribution_lib.get_device_memory_info(device_str) + assert mem_info is not None + assert "type" in mem_info + assert mem_info["index"] == 0 + + def test_process_discovery(self): + """Tests process_id and num_processes in the live environment.""" + rank = distribution_lib.process_id() + world_size = distribution_lib.num_processes() + + if dist.is_initialized(): + assert rank == dist.get_rank() + assert world_size == dist.get_world_size() + else: + assert rank == 0 + assert world_size == 1 + + def test_backend_conversions(self): + """Tests the conversion of Keras objects to Torch backend objects.""" + world_size = distribution_lib.num_processes() + if world_size < 2: + pytest.skip( + "Skipping conversion tests in a single-process environment." 
+ ) + + devices = [f"cpu:{i}" for i in range(world_size)] + shape = (world_size,) + axis_names = ("data",) + keras_mesh = DeviceMesh(shape, axis_names, devices) + + torch_mesh = distribution_lib._to_backend_mesh(keras_mesh) + assert isinstance(torch_mesh, dist.DeviceMesh) + assert torch_mesh.mesh.shape == shape + + keras_layout = TensorLayout(axes=("data",), device_mesh=keras_mesh) + placements = distribution_lib._to_backend_layout(keras_layout) + assert isinstance(placements[0], dist.Shard) + + keras_layout_replicated = TensorLayout( + axes=(None,), device_mesh=keras_mesh + ) + placements_replicated = distribution_lib._to_backend_layout( + keras_layout_replicated + ) + assert isinstance(placements_replicated[0], dist.Replicate) + + def test_tensor_distribution(self): + """Tests the distribution of a tensor into a DTensor.""" + if not dist.is_initialized() or distribution_lib.num_processes() < 2: + pytest.skip( + "Tensor distribution test requires a multi-process environment." + ) + + world_size = distribution_lib.num_processes() + devices = np.arange(world_size) + keras_mesh = DeviceMesh((world_size,), ("batch",), devices) + keras_layout = TensorLayout(("batch", None), keras_mesh) + + local_tensor = torch.randn((10, 20)) + + dtensor = distribution_lib.distribute_tensor(local_tensor, keras_layout) + assert isinstance(dtensor, torch.distributed.dtensor.DTensor) + assert dtensor.device_mesh.mesh.shape == (world_size,) + assert isinstance(dtensor.placements[0], dist.Shard) + + dvariable = distribution_lib.distribute_variable( + local_tensor, keras_layout + ) + assert isinstance(dvariable, torch.distributed.dtensor.DTensor) + + def test_distribute_data_input(self): + """Tests the `from_local` logic for distributing input data.""" + if not dist.is_initialized() or distribution_lib.num_processes() < 2: + pytest.skip( + "Input distribution test requires a multi-process environment." 
+ ) + + world_size = distribution_lib.num_processes() + devices = np.arange(world_size) + keras_mesh = DeviceMesh((world_size,), ("batch",), devices) + keras_layout = TensorLayout(("batch", None), keras_mesh) + + per_process_batch = torch.ones((8, 16)) + + global_batch = distribution_lib.distribute_data_input( + per_process_batch, keras_layout, batch_dim_name="batch" + ) + + assert isinstance(global_batch, torch.distributed.dtensor.DTensor) + assert global_batch.shape == (world_size * 8, 16) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py new file mode 100644 index 000000000000..470c00774660 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -0,0 +1,245 @@ +import os + +import pytest + +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" + +from keras import Input +from keras import Model +from keras import layers +from keras.src import backend +from keras.src import testing +from keras.src.distribution import distributed_backend +from keras.src.distribution.tensor_parallel.autoconfig import ( + analyze_dense_layer_directly, +) +from keras.src.distribution.tensor_parallel.autoconfig import ( + get_default_config_keras, +) +from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras + + +@pytest.mark.skipif( + backend.backend() not in ("torch", "jax") + or distributed_backend.get_device_info()["device_count"] <= 1, + reason="This test is for JAX/PyTorch backends and requires > 1 device.", +) +class TestAutoConfigKeras(testing.TestCase): + def setUp(self): + """Set up the test case and common variables.""" + super().setUp() + device_info = distributed_backend.get_device_info() + self.world_size = device_info["device_count"] + self.device_ids = [f"cpu:{i}" for i in range(self.world_size)] + + self.assertGreater( + self.world_size, 1, "Distribution tests require more than 1 device." 
+ ) + + def _assert_split_keras_equal(self, rule1, rule2): + """Helper to compare two SplitKeras objects by their attributes.""" + self.assertIsInstance(rule1, SplitKeras) + self.assertIsInstance(rule2, SplitKeras) + self.assertDictEqual(vars(rule1), vars(rule2)) + + def _assert_rules_equal(self, actual_rules, expected_rules): + """Helper to compare two dictionaries of sharding rules.""" + self.assertSetEqual( + set(actual_rules.keys()), set(expected_rules.keys()) + ) + for key in expected_rules: + actual_val = actual_rules[key] + expected_val = expected_rules[key] + if isinstance(expected_val, SplitKeras): + self._assert_split_keras_equal(actual_val, expected_val) + else: + self.assertEqual(actual_val, expected_val) + + def test_analyze_dense_layer(self): + """Tests the direct analysis and classification of Dense layers.""" + up_proj_layer = layers.Dense(32) + up_proj_layer.build(input_shape=(None, 16)) + self.assertEqual( + analyze_dense_layer_directly(up_proj_layer, None, ""), + "up_projection", + ) + + down_proj_layer = layers.Dense(16) + down_proj_layer.build(input_shape=(None, 32)) + self.assertEqual( + analyze_dense_layer_directly(down_proj_layer, None, ""), + "down_projection", + ) + + generic_layer = layers.Dense(20) + generic_layer.build(input_shape=(None, 16)) + self.assertEqual( + analyze_dense_layer_directly(generic_layer, None, ""), + "generic_dense", + ) + + def test_simple_mlp_sharding(self): + """Tests a simple MLP with up and down projection layers.""" + inputs = Input(shape=(64,)) + x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) + outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)( + x + ) + model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^simple_mlp.up_projection_layer.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^simple_mlp.up_projection_layer.bias$": SplitKeras( + self.world_size, 0, "column" + ), + r"^simple_mlp.down_projection_layer.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + # Bias for down-projection is not sharded according to the new logic + } + expected_output_rules = { + r"^simple_mlp.up_projection_layer$": {0: "gather"}, + r"^simple_mlp.down_projection_layer$": {0: "allreduce"}, + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_generic_dense_sharding(self): + """Tests a generic Dense layer that isn't an up/down projection.""" + inputs = Input(shape=(64,)) + outputs = layers.Dense(80, name="generic_layer", use_bias=True)(inputs) + model = Model(inputs=inputs, outputs=outputs, name="generic_model") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^generic_model.generic_layer.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^generic_model.generic_layer.bias$": SplitKeras( + self.world_size, 0, "column" + ), + } + expected_output_rules = { + r"^generic_model.generic_layer$": {0: "gather -1"} + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_embedding_sharding(self): + """Tests an Embedding layer for vocabulary parallelism.""" + inputs = Input(shape=(10,), dtype="int32") + outputs = layers.Embedding( + input_dim=1000, output_dim=128, name="token_embedding" + )(inputs) + model = Model(inputs=inputs, 
outputs=outputs, name="embed_model") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + # FIX: Removed the incorrect backslash before ".token_embedding" + r"^embed_model.token_embedding\..*embeddings$": SplitKeras( + self.world_size, 1, "column" + ) + } + expected_output_rules = { + r"^embed_model.token_embedding$": {0: "no_comm"} + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_einsum_dense_sharding(self): + """Tests the special handling for EinsumDense layers.""" + inputs = Input(shape=(64,)) + x = layers.EinsumDense( + "bh,hd->bd", output_shape=128, name="query_proj" + )(inputs) + outputs = layers.EinsumDense( + "bd,dh->bh", output_shape=64, name="attention_output" + )(x) + model = Model(inputs=inputs, outputs=outputs, name="einsum_model") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^einsum_model.query_proj.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^einsum_model.attention_output.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + } + expected_output_rules = { + r"^einsum_model.query_proj$": {0: "gather -1"}, + r"^einsum_model.attention_output$": {0: "allreduce"}, + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_normalization_layers_ignored(self): + """Tests that normalization layers are correctly ignored.""" + inputs = Input(shape=(64,)) + x = layers.Dense(64, name="dense1", use_bias=True)(inputs) + x = layers.LayerNormalization(name="layernorm")(x) + outputs = layers.Dense(64, name="dense2", use_bias=True)(x) + model = Model(inputs=inputs, outputs=outputs, name="norm_model") + + config = get_default_config_keras(model, self.device_ids) + + for key in config.state_rules: + self.assertNotIn("layernorm", key) + for key in config.output_rules: + self.assertNotIn("layernorm", key) + + self.assertIn(r"^norm_model.dense1.kernel$", config.state_rules) + self.assertIn(r"^norm_model.dense2.kernel$", config.state_rules) + self.assertEqual(len(config.state_rules), 4) + self.assertEqual(len(config.output_rules), 2) + + def test_nested_model_sharding(self): + """Tests that the traversal logic correctly handles nested models.""" + inner_inputs = Input(shape=(32,)) + inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)( + inner_inputs + ) + inner_model = Model( + inputs=inner_inputs, outputs=inner_outputs, name="inner_block" + ) + + outer_inputs = Input(shape=(32,)) + x = inner_model(outer_inputs) + outer_outputs = layers.Dense(32, name="outer_dense", use_bias=True)(x) + outer_model = Model( + inputs=outer_inputs, outputs=outer_outputs, name="outer_model" + ) + + config = get_default_config_keras(outer_model, self.device_ids) + + expected_state_rules = { + r"^outer_model.inner_block.inner_dense.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^outer_model.inner_block.inner_dense.bias$": SplitKeras( + self.world_size, 0, "column" + ), + r"^outer_model.outer_dense.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + # Bias for down-projection is not sharded according to the new logic + } + expected_output_rules = { + r"^outer_model.inner_block.inner_dense$": {0: "gather"}, + r"^outer_model.outer_dense$": {0: "allreduce"}, + } + + self.maxDiff = None + self._assert_rules_equal(config.state_rules, expected_state_rules) + 
self._assert_rules_equal(config.output_rules, expected_output_rules) diff --git a/keras/src/distribution/tensor_parallel/communications_test.py b/keras/src/distribution/tensor_parallel/communications_test.py new file mode 100644 index 000000000000..1b6d95d37664 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/communications_test.py @@ -0,0 +1,84 @@ +import pytest + +import keras +from keras.src import backend +from keras.src import testing +from keras.src.backend import distributed_backend +from keras.src.distribution.tensor_parallel.communications import AllGatherKeras +from keras.src.distribution.tensor_parallel.communications import AllReduceKeras +from keras.src.distribution.tensor_parallel.communications import BroadcastKeras +from keras.src.distribution.tensor_parallel.communications import ( + TensorParallelCommunicator, +) + + +@pytest.mark.skipif( + backend.backend() not in ("torch", "jax"), + reason="This test is for JAX/PyTorch backends.", +) +class TestCollectiveOps(testing.TestCase): + """ + Tests collective communication ops on a JAX distributed backend. + """ + + def setUp(self): + super().setUp() + device_info = distributed_backend.get_device_info() + self.world_size = device_info.get("device_count", 1) + + if not self.world_size: + self.world_size = 1 + + self.axis_name = "data" + + def test_all_reduce(self): + """Tests the all-reduce operation.""" + all_reduce_op = AllReduceKeras(world_size=self.world_size, op="sum") + local_tensor = keras.ops.array([1.0, 2.0, 3.0]) + + result = all_reduce_op(local_tensor, axis_name=self.axis_name) + + expected_output = keras.ops.multiply( + local_tensor, float(self.world_size) + ) + self.assertAllClose(result, expected_output) + + def test_all_gather(self): + """Tests the all-gather operation.""" + all_gather_op = AllGatherKeras(world_size=self.world_size, dim=0) + local_slice = keras.ops.arange(6, dtype="float32").reshape((2, 3)) + result = all_gather_op(local_slice, axis_name=self.axis_name) + + expected_output = keras.ops.concatenate( + [local_slice] * self.world_size, axis=0 + ) + self.assertAllClose(result, expected_output) + + def test_broadcast(self): + """Tests the broadcast operation.""" + broadcast_op = BroadcastKeras( + world_size=self.world_size, src_rank=0, rank=0 + ) + tensor_to_broadcast = keras.ops.array([5.0, 10.0, 15.0]) + result = broadcast_op(tensor_to_broadcast, axis_name=self.axis_name) + + self.assertAllClose(result, tensor_to_broadcast) + + def test_tensor_parallel_communicator_forward_column_parallel(self): + """Tests the communicator's all-gather for column-parallel forward.""" + communicator = TensorParallelCommunicator( + world_size=self.world_size, rank=0 + ) + + local_slice = keras.ops.array([[0.0, 1.0], [2.0, 3.0]], dtype="float32") + + result = communicator.forward_column_parallel( + partial_outputs=[local_slice], + dim=0, + axis_name=self.axis_name, + ) + + expected_output = keras.ops.concatenate( + [local_slice] * self.world_size, axis=0 + ) + self.assertAllClose(result, expected_output) diff --git a/keras/src/distribution/tensor_parallel/config_test.py b/keras/src/distribution/tensor_parallel/config_test.py new file mode 100644 index 000000000000..cbb26e40e6db --- /dev/null +++ b/keras/src/distribution/tensor_parallel/config_test.py @@ -0,0 +1,96 @@ +import pytest + +from keras.src import backend +from keras.src import testing +from keras.src.distribution.tensor_parallel.communications import AllGatherKeras +from keras.src.distribution.tensor_parallel.communications import AllReduceKeras +from 
keras.src.distribution.tensor_parallel.communications import BroadcastKeras +from keras.src.distribution.tensor_parallel.config import ConfigKeras +from keras.src.distribution.tensor_parallel.config import _create_ops_from_rules + + +@pytest.mark.skipif( + backend.backend() not in ("torch", "jax"), + reason="This test is for JAX/PyTorch backends.", +) +class TestConfig(testing.TestCase): + """Test suite for the tensor parallel configuration.""" + + def test_create_ops_from_rules_helper(self): + """ + Tests the private _create_ops_from_rules helper function directly + to ensure it correctly parses various rule types. + """ + devices = ["/gpu:0", "/gpu:1"] + world_size = len(devices) + rules = { + "dense/kernel": {"forward": "sum", "backward": "mean"}, + "embedding/weight": { + "forward": "gather 0", + "backward": "gather -1", + }, + "attention/dense/bias": {"forward": "broadcast"}, + "passthrough": {"action": 123}, + "no_dict_action": "identity", + } + + processed_rules = _create_ops_from_rules(rules, world_size) + + sum_op = processed_rules["dense/kernel"]["forward"] + self.assertIsInstance(sum_op, AllReduceKeras) + self.assertEqual(sum_op.op, "sum") + self.assertEqual(sum_op.world_size, world_size) + + mean_op = processed_rules["dense/kernel"]["backward"] + self.assertIsInstance(mean_op, AllReduceKeras) + self.assertEqual(mean_op.op, "mean") + + gather_op_0 = processed_rules["embedding/weight"]["forward"] + self.assertIsInstance(gather_op_0, AllGatherKeras) + self.assertEqual(gather_op_0.dim, 0) + self.assertEqual(gather_op_0.world_size, world_size) + + gather_op_neg1 = processed_rules["embedding/weight"]["backward"] + self.assertIsInstance(gather_op_neg1, AllGatherKeras) + self.assertEqual(gather_op_neg1.dim, -1) + + broadcast_op = processed_rules["attention/dense/bias"]["forward"] + self.assertIsInstance(broadcast_op, BroadcastKeras) + self.assertEqual(broadcast_op.world_size, world_size) + + self.assertEqual(processed_rules["passthrough"]["action"], 123) + self.assertEqual(processed_rules["no_dict_action"], "identity") + + def test_config_keras_create_collective_ops(self): + """ + Tests the public create_collective_ops method of the ConfigKeras class. 
+ """ + devices = ["/gpu:0", "/gpu:1"] + world_size = len(devices) + + state_rules = {"some_weight": "split"} + output_rules = { + "layer_1_output": {"activation": "sum"}, + "layer_2_output": {"activation": "gather -1"}, + } + + config = ConfigKeras(state_rules=state_rules, output_rules=output_rules) + new_config = config.create_collective_ops(devices) + + self.assertIsNot(new_config, config) + + self.assertEqual(new_config.state_rules, state_rules) + + self.assertIsInstance( + config.output_rules["layer_1_output"]["activation"], str + ) + + sum_op = new_config.output_rules["layer_1_output"]["activation"] + self.assertIsInstance(sum_op, AllReduceKeras) + self.assertEqual(sum_op.op, "sum") + self.assertEqual(sum_op.world_size, world_size) + + gather_op = new_config.output_rules["layer_2_output"]["activation"] + self.assertIsInstance(gather_op, AllGatherKeras) + self.assertEqual(gather_op.dim, -1) + self.assertEqual(gather_op.world_size, world_size) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py new file mode 100644 index 000000000000..38d80a5ec258 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -0,0 +1,180 @@ +import numpy as np +import pytest + +import keras +from keras import ops +from keras.src import backend +from keras.src import optimizers +from keras.src import testing +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + CoordinatedOptimizer, +) +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + TensorParallelOptimizer, +) + + +@pytest.mark.skipif( + backend.backend() not in ("torch", "jax"), + reason="This test is for JAX/PyTorch backends.", +) +class CoordinatedOptimizerTest(testing.TestCase): + def _get_simple_model(self): + """Creates a simple, uncompiled Keras model.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(20, name="dense_1")(inputs) + outputs = keras.layers.Dense(5, name="dense_2")(x) + return keras.Model(inputs, outputs) + + def _get_mock_gradients_and_vars(self, model, world_size): + """Generates mock gradients and variables for N shards.""" + model.build(input_shape=(None, 10)) + variables = model.trainable_variables + grads_and_vars_per_shard = [] + for i in range(world_size): + multiplier = float(i + 1) + gradients = [ + ops.convert_to_tensor( + np.ones_like(v.numpy()) * multiplier, dtype="float32" + ) + for v in variables + ] + grads_and_vars_per_shard.append(list(zip(gradients, variables))) + return grads_and_vars_per_shard + + def test_initialization(self): + """Tests that the optimizer initializes with the correct defaults.""" + base_optimizer = optimizers.Adam() + coord = CoordinatedOptimizer(base_optimizer, world_size=4) + self.assertEqual(coord.base_optimizer, base_optimizer) + self.assertTrue(coord.shard_optimizer_states) + self.assertEqual(coord.sharded_states, {}) + + def test_apply_gradients_with_replicated_states(self): + """Tests that replicated gradients are averaged and applied once.""" + + class AdamWithCallCounter(optimizers.Adam): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.apply_gradients_call_count = 0 + self.received_grads = [] + + def apply_gradients(self, grads_and_vars, *args, **kwargs): + self.apply_gradients_call_count += 1 + self.received_grads = [g for g, v in grads_and_vars] + super().apply_gradients(grads_and_vars, *args, **kwargs) + + world_size = 4 + model = self._get_simple_model() 
+ optimizer = AdamWithCallCounter() + model.build((None, 10)) + mock_grads = self._get_mock_gradients_and_vars(model, world_size) + + coord = CoordinatedOptimizer( + optimizer, + world_size, + shard_optimizer_states=False, + ) + coord.apply_gradients(mock_grads, []) + + self.assertEqual(optimizer.apply_gradients_call_count, 1) + grad_numpy = ops.convert_to_numpy(optimizer.received_grads[0]) + self.assertAllClose( + grad_numpy, + np.ones_like(grad_numpy) * 2.5, + ) + + def test_init_from_string(self): + optimizer = TensorParallelOptimizer("adam", world_size=4) + self.assertIsInstance(optimizer.base_optimizer, optimizers.Adam) + + def test_apply_gradients_delegation(self): + """Tests that apply_gradients correctly delegates.""" + world_size = 4 + base_opt = optimizers.Adam() + optimizer = TensorParallelOptimizer(base_opt, world_size) + model = self._get_simple_model() + mock_grads = self._get_mock_gradients_and_vars(model, world_size) + + coord_apply_tracker = {"called": False} + + def coord_apply_mock(*args, **kwargs): + coord_apply_tracker["called"] = True + + optimizer.coordinated_optimizer.apply_gradients = coord_apply_mock + + base_apply_tracker = {"called": False} + + def base_apply_mock(*args, **kwargs): + base_apply_tracker["called"] = True + + optimizer.base_optimizer.apply_gradients = base_apply_mock + + optimizer.apply_gradients(mock_grads, shard_models=[]) + self.assertTrue(coord_apply_tracker["called"]) + self.assertFalse(base_apply_tracker["called"]) + + coord_apply_tracker["called"] = False + unsharded_grads = mock_grads[0] + optimizer.apply_gradients(unsharded_grads) + self.assertTrue(base_apply_tracker["called"]) + self.assertFalse(coord_apply_tracker["called"]) + + def test_build_and_state_sharding(self): + """Tests that the build method correctly initializes sharded states.""" + optimizer = TensorParallelOptimizer(optimizers.Adam(), world_size=4) + model = self._get_simple_model() + model.build(input_shape=(None, 10)) + + self.assertEqual(optimizer.coordinated_optimizer.sharded_states, {}) + optimizer.build(model.trainable_variables) + self.assertTrue(optimizer.built) + + sharded_states = optimizer.coordinated_optimizer.sharded_states + self.assertIn("momentum", sharded_states) + self.assertIn("velocity", sharded_states) + self.assertIn("iterations", sharded_states) + + dense_1_kernel_path = model.get_layer("dense_1").kernel.path + self.assertIn(dense_1_kernel_path, sharded_states["momentum"]) + self.assertEqual( + len(sharded_states["momentum"][dense_1_kernel_path]), 4 + ) + + def test_serialization(self): + world_size = 4 + base_opt = optimizers.Adam(learning_rate=0.1) + optimizer = TensorParallelOptimizer( + base_opt, world_size, distributed_backend=None + ) + + config = optimizer.get_config() + recreated = TensorParallelOptimizer.from_config(config) + + self.assertEqual(recreated.world_size, world_size) + self.assertIsInstance(recreated.base_optimizer, optimizers.Adam) + self.assertIsNone(recreated.distributed_backend) + self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) + + def test_sharding_with_prefixed_variable_names(self): + """Tests that state is correctly mapped with prefixed variable names.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(4, name="dense")(inputs) + outputs = keras.layers.Dense(2, name="dense_output")(x) + model = keras.Model(inputs, outputs) + model.build(input_shape=(None, 10)) + + optimizer = TensorParallelOptimizer(optimizers.Adam(), world_size=2) + optimizer.build(model.trainable_variables) + + state_to_param = 
( + optimizer.coordinated_optimizer._state_variable_to_parameter + ) + self.assertGreater(len(state_to_param), 0) + + dense_output_kernel = model.get_layer("dense_output").kernel + optimizer_name = optimizer.base_optimizer.name + kernel_path = dense_output_kernel.path.replace("/", "_") + momentum_path = f"{optimizer_name}/{kernel_path}_momentum" + + self.assertIs(state_to_param[momentum_path], dense_output_kernel) diff --git a/keras/src/distribution/tensor_parallel/parameter_sharding.py b/keras/src/distribution/tensor_parallel/parameter_sharding.py new file mode 100644 index 000000000000..30a16e9c63fe --- /dev/null +++ b/keras/src/distribution/tensor_parallel/parameter_sharding.py @@ -0,0 +1,669 @@ +import re +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np + +from keras.src.distribution.tensor_parallel.communications import ( + TensorParallelCommunicator, +) +from keras.src.distribution.tensor_parallel.config import ConfigKeras +from keras.src.distribution.tensor_parallel.state_action_keras import ( + StateActionKeras, +) + + +class ShardedWeight: + """A wrapper class for a sharded Keras Variable. + + This class holds a shard of a model weight as a `keras.Variable` and + provides an interface similar to the original variable, allowing it to be + seamlessly integrated into the Keras ecosystem. + + Args: + tensor_shard: The tensor slice (shard) of the weight. + name (str): The name for the underlying `keras.Variable`. + trainable (bool): Whether the variable is trainable. + """ + + def __init__(self, tensor_shard, name, trainable=True): + import keras + + self._variable = keras.Variable( + initializer=tensor_shard, trainable=trainable, name=name + ) + self.regularizer = None + + @property + def name(self) -> str: + """Returns the name of the underlying variable.""" + return self._variable.name + + @property + def trainable(self) -> bool: + """Returns whether the variable is trainable.""" + return self._variable.trainable + + @property + def shape(self) -> Tuple[int, ...]: + """Returns the shape of the variable.""" + return self._variable.shape + + @property + def dtype(self) -> any: + """Returns the dtype of the underlying variable.""" + return self._variable.dtype + + @property + def variable(self): + """Provides direct access to the underlying `keras.Variable`.""" + return self._variable + + def numpy(self) -> np.ndarray: + """Returns the value of the variable as a NumPy array.""" + return self._variable.numpy() + + def num_elements(self) -> int: + """Returns the total number of elements in the tensor.""" + import keras + + return keras.ops.size(self._variable) + + def __repr__(self) -> str: + """Provides a developer-friendly string representation.""" + return ( + f"" + ) + + +class ParameterShardingStrategy: + """Manages the sharding of model parameters for tensor parallelism. + + This strategy identifies weights in a Keras model based on configuration + rules, shards them, and stores the sharded weights and metadata. It's + designed to modify a model's parameters without altering its architecture. + + Args: + world_size (int): The total number of devices (workers) in the + parallel computation group. + rank (int): The unique identifier for the current device (worker), + from 0 to `world_size - 1`. 
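+
+    Example (illustrative sketch; `model`, `config`, and `communicator` are
+    assumed to be constructed elsewhere):
+
+        strategy = ParameterShardingStrategy(world_size=2, rank=0)
+        sharded_model, modified = strategy.shard_model_parameters(
+            model, config, communicator, device_id="cpu:0"
+        )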
+ """ + + def __init__(self, world_size: int, rank: int): + self.world_size = world_size + self.rank = rank + self.sharded_weights = {} # Maps param name to its sharded tensor + self.original_weights = {} # Stores a copy of original weights + self.weight_mapping = {} # Maps param name to sharding info + self.sharded_weights_by_id = {} # Maps original weight ID to shard + + def shard_model_parameters( + self, + model, + config: ConfigKeras, + communicator: TensorParallelCommunicator, + device_id: Any, + ) -> Tuple[Any, set]: + """Shards model parameters and wraps the model for tensor parallelism. + + This method iterates through the model's parameters, applies sharding + rules defined in the config, and creates a `ParameterShardedModel` + which handles the forward pass with necessary communication primitives. + + Args: + model: The original Keras model to be sharded. + config (ConfigKeras): The configuration object containing sharding + rules (`state_rules` and `output_rules`). + communicator (TensorParallelCommunicator): The communicator for + handling cross-device data transfer (e.g., all-gather). + device_id (Any): The device identifier where the model will run. + + Returns: + A tuple containing: + - ParameterShardedModel: The new model wrapped for tensor + parallelism. + - set: A set of names of the parameters that were sharded. + """ + ParameterShardedModel = _define_parameter_sharded_model() + + self._store_original_weights(model) + modified_parameters = set() + + for pattern, action in config.state_rules.items(): + if isinstance(action, StateActionKeras): + matching_params = self._find_matching_parameters(model, pattern) + + for param_name, param in matching_params: + try: + param_id = id(param.experimental_ref()) + except AttributeError: + param_id = id(param) + + if param_id in self.sharded_weights_by_id: + self.sharded_weights[param_name] = ( + self.sharded_weights_by_id[param_id] + ) + + existing_param_name = "unknown" + for name, shard in self.sharded_weights.items(): + if shard is self.sharded_weights_by_id[param_id]: + existing_param_name = name + break + + self.weight_mapping[param_name] = self.weight_mapping[ + existing_param_name + ] + modified_parameters.add(param_name) + continue + + sharded_param = action(param, self.rank) + + self.sharded_weights[param_name] = sharded_param + self.sharded_weights_by_id[param_id] = sharded_param + + self.weight_mapping[param_name] = { + "original_shape": param.shape, + "sharded_shape": sharded_param.shape, + "action": action, + } + + modified_parameters.add(param_name) + + sharded_model = ParameterShardedModel( + original_model=model, + sharding_strategy=self, + communicator=communicator, + config=config, + device_id=device_id, + ) + + return sharded_model, modified_parameters + + def _store_original_weights(self, model): + """Recursively traverses the model and stores original weights.""" + from keras.src import layers + + def find_weights_recursive( + current_layer: layers.Layer, prefix: str = "" + ): + """Helper to recursively find and store weights.""" + name = current_layer.name + full_name = f"{prefix}.{name}" if prefix else name + + if hasattr(current_layer, "weights") and current_layer.weights: + for weight in current_layer.weights: + cleaned_name = weight.name.split("/")[-1].split(":")[0] + param_name = f"{full_name}.{cleaned_name}" + self.original_weights[param_name] = weight.numpy() + + if hasattr(current_layer, "layers") and current_layer.layers: + for sub_layer in current_layer.layers: + find_weights_recursive(sub_layer, 
full_name) + + for attr_name in dir(current_layer): + if attr_name.startswith("__") and attr_name.endswith("__"): + continue + try: + attr = getattr(current_layer, attr_name) + except Exception: + continue + if isinstance(attr, layers.Layer) and attr is not current_layer: + find_weights_recursive(attr, full_name) + elif isinstance(attr, (list, tuple)): + for item in attr: + if isinstance(item, layers.Layer): + find_weights_recursive(item, full_name) + + for layer in model.layers: + find_weights_recursive(layer, prefix="") + + def _find_matching_parameters( + self, model, pattern: str + ) -> List[Tuple[str, Any]]: + """Finds model parameters whose names match a given regex pattern. + + This method recursively searches through the model's layers and + sub-layers to find all weights, then filters them based on the pattern. + + Args: + model: The Keras model to search within. + pattern (str): A regular expression to match against parameter + names. + + Returns: + A list of tuples, where each tuple contains the parameter's full + name and the parameter object itself. + """ + from keras.src import layers + + matching_params = [] + processed_layers = set() + + def search_layer_recursive( + current_layer: layers.Layer, prefix: str = "" + ): + """Helper to recursively find matching parameters.""" + if id(current_layer) in processed_layers: + return + processed_layers.add(id(current_layer)) + + name = current_layer.name + full_name = f"{prefix}.{name}" if prefix else name + + if hasattr(current_layer, "weights") and current_layer.weights: + for weight in current_layer.weights: + cleaned_weight_name = weight.name.split("/")[-1].split(":")[ + 0 + ] + param_name = f"{full_name}.{cleaned_weight_name}" + + if re.match(pattern, param_name): + matching_params.append((param_name, weight)) + + if hasattr(current_layer, "layers") and current_layer.layers: + for sub_layer in current_layer.layers: + search_layer_recursive(sub_layer, full_name) + + for attr_name in dir(current_layer): + if attr_name.startswith("__") and attr_name.endswith("__"): + continue + + try: + attr = getattr(current_layer, attr_name) + except Exception: + continue + + if isinstance(attr, layers.Layer) and attr is not current_layer: + search_layer_recursive(attr, full_name) + + elif isinstance(attr, (list, tuple)): + for item in attr: + if isinstance(item, layers.Layer): + search_layer_recursive(item, full_name) + + search_layer_recursive(model, prefix="") + + return matching_params + + def get_sharded_weight(self, param_name: str) -> Optional[np.ndarray]: + """Retrieves the sharded weight for a given parameter name. + + Args: + param_name (str): The name of the parameter. + + Returns: + The sharded weight as a NumPy array if it exists, otherwise None. + """ + if param_name in self.sharded_weights: + return self.sharded_weights[param_name].numpy() + return None + + def get_weight_info(self, param_name: str) -> Optional[Dict]: + """Retrieves sharding information for a specific parameter. + + Args: + param_name (str): The name of the parameter. + + Returns: + A dictionary containing metadata about the sharding (e.g., + original shape, sharded shape, action) if it exists, + otherwise None. + """ + return self.weight_mapping.get(param_name) + + +def _define_parameter_sharded_model(): + """Factory function to define and return the ParameterShardedModel class. + + This approach encapsulates the class definition and avoids potential + circular dependencies, while also keeping the related logic organized. 
+ + Returns: + The `ParameterShardedModel` class. + """ + from keras.src.models import Model + + class ParameterShardedModel(Model): + """A Keras Model wrapper for executing a parameter-sharded model. + + This model overrides the `call` and `train_step` methods to inject + the necessary communication operations (e.g., all-reduce, all-gather) + for tensor parallelism during the forward and backward passes. + + Args: + original_model (Model): The original, non-sharded Keras model. + sharding_strategy (ParameterShardingStrategy): The strategy + instance that holds the sharded weights and metadata. + communicator (TensorParallelCommunicator): The object responsible + for inter-device communication. + config (ConfigKeras): The configuration with sharding and + communication rules. + device_id (Any): The identifier of the device this model runs on. + """ + + def __init__( + self, + original_model: Model, + sharding_strategy: ParameterShardingStrategy, + communicator: TensorParallelCommunicator, + config: ConfigKeras, + device_id: Any, + ): + super().__init__() + + self.original_model = original_model + self.sharding_strategy = sharding_strategy + self.config = config + self.communicator = communicator + self._device = device_id + + self._build_and_cache_weights() + + if original_model.inputs: + self.build(original_model.inputs[0].shape) + + @property + def device(self): + """Returns the device identifier for this model instance.""" + return self._device + + def train_step(self, data): + """Custom training step for the parameter-sharded model. + + This method performs a standard forward and backward pass but + adds a crucial gradient synchronization step (`all_reduce`) before + applying gradients. This ensures that each device updates its + local weight shards using gradients computed from all devices. + + Args: + data: A tuple of (x, y, sample_weight) as passed by `fit()`. + + Returns: + A dictionary mapping metric names to their current values. + """ + import tensorflow as tf + + import keras + + x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data) + + with tf.GradientTape() as tape: + y_pred = self(x, training=True) + loss = self.compute_loss( + x=x, y=y, y_pred=y_pred, sample_weight=sample_weight + ) + + trainable_vars = self.trainable_variables + gradients = tape.gradient(loss, trainable_vars) + + synced_gradients = self.communicator.all_reduce( + gradients, op="sum", axis_name="model" + ) + self.optimizer.apply_gradients( + zip(synced_gradients, trainable_vars) + ) + + self.compiled_metrics.update_state(y, y_pred, sample_weight) + + return {m.name: m.result() for m in self.metrics} + + def _build_and_cache_weights(self): + """Constructs a unified list of weights for the model. + + This list includes the custom `ShardedWeight` objects for parameters + that were sharded, and the original `keras.Variable` objects for + those that were not. 
+ """ + weights_list = [] + + sharded_weight_ids = set( + self.sharding_strategy.sharded_weights_by_id.keys() + ) + + for ( + param_name, + sharded_tensor, + ) in self.sharding_strategy.sharded_weights.items(): + weights_list.append(ShardedWeight(sharded_tensor, param_name)) + + unsharded_count = 0 + for weight in self.original_model.weights: + try: + weight_id = id(weight.experimental_ref()) + except AttributeError: + weight_id = id(weight) + + if weight_id not in sharded_weight_ids: + weights_list.append(weight) + unsharded_count += 1 + + self._weights_list = weights_list + + @property + def weights(self): + """Returns the combined list of sharded and non-sharded weights.""" + return self._weights_list + + def call(self, inputs, training=None, mask=None): + """Defines the forward pass of the model. + + This method executes the layers of the original model sequentially. + After each layer's execution, it checks if an output communication + rule applies (e.g., for row-parallel or column-parallel layers) + and triggers the corresponding communication operation. + + Args: + inputs: Input tensor(s). + training (bool): Indicates if the model is in training mode. + mask: A mask or list of masks. + + Returns: + The output tensor of the model. + """ + from keras.src import layers + + tensor_cache = {} + current_tensor = inputs + + for layer in self.original_model.layers: + if isinstance(layer, layers.InputLayer): + continue + + if isinstance(layer, layers.Add): + try: + if "feedforward_output" in layer.name: + residual_source_name = layer.name.replace( + "feedforward_output", "self_attention_output" + ) + elif "self_attention_output" in layer.name: + residual_source_name = layer.name.replace( + "self_attention_output", "input_layer_norm" + ) + else: + residual_source_name = None + + if ( + residual_source_name + and residual_source_name in tensor_cache + ): + layer_inputs = [ + current_tensor, + tensor_cache[residual_source_name], + ] + else: + layer_inputs = [current_tensor, current_tensor] + except Exception: + layer_inputs = [current_tensor, current_tensor] + else: + layer_inputs = current_tensor + + if ( + "attention_output" in layer.name + or "feedforward_output" in layer.name + ): + tensor_cache[layer.name] = current_tensor + + current_tensor = layer(layer_inputs, training=training) + + layer_path = layer.path + + output_rule = None + for pattern, rule in self.config.output_rules.items(): + if re.search(pattern, layer_path): + output_rule = rule.get(0) + break + + if output_rule: + current_tensor = self._apply_communication( + current_tensor, layer.name, output_rule + ) + + return current_tensor + + def _apply_communication(self, sharded_output, layer_name, rule): + """Applies a communication primitive based on a rule. + + Args: + sharded_output: The output tensor from a layer. + layer_name (str): The name of the layer. + rule: The communication rule from the config. + + Returns: + The tensor after the communication operation has been applied. 
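+
+            Example (illustrative; mirrors the string parsing below):
+
+                # rule "allreduce" or "sum" -> all-reduce over the model axis
+                # rule "gather 1"           -> all-gather along dimension 1
+                # rule "gather"             -> all-gather along the last dim
+                # a callable rule           -> applied directly to the output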
+ """ + op_name = str(rule).lower() + + if "sum" in op_name or "allreduce" in op_name: + return self.communicator.forward_row_parallel( + sharded_output, op="sum", axis_name="model" + ) + + elif "gather" in op_name: + try: + dim = int(op_name.split(" ")[-1]) + except (ValueError, IndexError): + dim = -1 + return self.communicator.forward_column_parallel( + sharded_output, dim=dim, axis_name="model" + ) + + elif hasattr(rule, "__call__"): + return rule(sharded_output) + + else: + return sharded_output + + def get_config(self): + """Serializes the model's configuration.""" + return self.original_model.get_config() + + @classmethod + def from_config(cls, config, custom_objects=None): + """Creates a model from its configuration.""" + return cls(**config) + + return ParameterShardedModel + + +def make_parameter_sharded_model( + module, config: ConfigKeras, rank: int, world_size: int, device_id: Any +) -> Tuple[Any, set]: + """Creates a parameter-sharded version of a Keras model. + + This is a high-level factory function that orchestrates the creation of + the communicator, the sharding strategy, and the final sharded model. + + Args: + module: The Keras model to be sharded. + config (ConfigKeras): Configuration object with sharding rules. + rank (int): The rank of the current process. + world_size (int): The total number of processes. + device_id (Any): The device on which the model will be placed. + + Returns: + A tuple containing: + - The newly created `ParameterShardedModel`. + - A set of names for the parameters that were modified. + """ + communicator = TensorParallelCommunicator(world_size=world_size, rank=rank) + sharding_strategy = ParameterShardingStrategy(world_size, rank) + + sharded_model, modified_parameters = ( + sharding_strategy.shard_model_parameters( + module, config, communicator, device_id + ) + ) + + return sharded_model, modified_parameters + + +def apply_parameter_sharding_to_existing_model( + model, config: ConfigKeras, rank: int, world_size: int +): + """Applies parameter sharding directly to an existing model instance. + + This function modifies a model in-place. Instead of returning a new + wrapped model, it shards the weights and attaches the sharding strategy + to the original model object. This is useful when the model's execution + logic is handled externally. + + Args: + model: The Keras model to modify. + config (ConfigKeras): Configuration object with sharding rules. + rank (int): The rank of the current process. + world_size (int): The total number of processes. + + Returns: + The modified model with an attached `_tensor_parallel_sharding` + strategy attribute. 
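+
+    Example (illustrative sketch; `model` and `config` are assumed to be
+    built the same way as for `make_parameter_sharded_model`, and the
+    parameter name follows the test MLP defined in the tests below):
+
+        model = apply_parameter_sharding_to_existing_model(
+            model, config, rank=0, world_size=2
+        )
+        sharding = model._tensor_parallel_sharding
+        info = sharding.get_weight_info("simple_mlp.up_proj.kernel")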
+ """ + + sharding_strategy = ParameterShardingStrategy(world_size, rank) + for pattern, action in config.state_rules.items(): + if isinstance(action, StateActionKeras): + matching_params = sharding_strategy._find_matching_parameters( + model, pattern + ) + + for param_name, param in matching_params: + try: + param_id = id(param.experimental_ref()) + except AttributeError: + param_id = id(param) + + if param_id in sharding_strategy.sharded_weights_by_id: + sharding_strategy.sharded_weights[param_name] = ( + sharding_strategy.sharded_weights_by_id[param_id] + ) + existing_param_name = next( + k + for k, v in sharding_strategy.sharded_weights.items() + if v + is sharding_strategy.sharded_weights_by_id[param_id] + ) + sharding_strategy.weight_mapping[param_name] = ( + sharding_strategy.weight_mapping[existing_param_name] + ) + continue + + sharded_param = action(param, rank) + + sharding_strategy.sharded_weights[param_name] = sharded_param + sharding_strategy.sharded_weights_by_id[param_id] = ( + sharded_param + ) + + sharding_strategy.weight_mapping[param_name] = { + "original_shape": param.shape, + "sharded_shape": sharded_param.shape, + "action": action, + } + + model._tensor_parallel_sharding = sharding_strategy + + return model diff --git a/keras/src/distribution/tensor_parallel/parameter_sharding_test.py b/keras/src/distribution/tensor_parallel/parameter_sharding_test.py new file mode 100644 index 000000000000..dc686436af97 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/parameter_sharding_test.py @@ -0,0 +1,142 @@ +import os + +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" +import re + +import numpy as np +import pytest + +import keras +from keras import distribution +from keras.src import backend +from keras.src.distribution.tensor_parallel.config import ConfigKeras +from keras.src.distribution.tensor_parallel.parameter_sharding import ( + ShardedWeight, +) +from keras.src.distribution.tensor_parallel.parameter_sharding import ( + make_parameter_sharded_model, +) +from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras +from keras.src.testing import TestCase + + +@pytest.mark.skipif( + backend.backend() not in ("torch", "jax"), + reason="This test is for JAX/PyTorch backends.", +) +def _create_simple_mlp(): + """Creates a simple, unsharded Keras MLP model for testing.""" + inputs = keras.Input(shape=(16,), name="input") + x = keras.layers.Dense(32, use_bias=True, name="up_proj")(inputs) + x = keras.layers.Activation("relu")(x) + outputs = keras.layers.Dense(8, use_bias=False, name="down_proj")(x) + return keras.Model(inputs=inputs, outputs=outputs, name="simple_mlp") + + +class ParameterShardingTest(TestCase): + def setUp(self): + super().setUp() + import logging + + logging.getLogger().setLevel(logging.ERROR) + + self.world_size = 2 + all_devices = distribution.list_devices() + self.devices = all_devices[: self.world_size] + if len(self.devices) < self.world_size: + self.skipTest( + f"""Not enough devices to run TP test. 
+ Found {len(self.devices)}, need {self.world_size}""" + ) + + self.original_model = _create_simple_mlp() + self.original_model.build(input_shape=(None, 16)) + + self.tp_config = ConfigKeras( + state_rules={ + re.escape("simple_mlp.up_proj.kernel"): SplitKeras( + self.world_size, dim=1 + ), + re.escape("simple_mlp.down_proj.kernel"): SplitKeras( + self.world_size, dim=0 + ), + }, + output_rules={}, + ) + self.input_data = np.random.rand(4, 16).astype("float32") + self.labels = np.random.rand(4, 8).astype("float32") + + def test_model_sharding_creation_and_weight_counts(self): + sharded_models = [] + for rank in range(self.world_size): + with keras.device(self.devices[rank]): + sharded_model, modified_params = make_parameter_sharded_model( + self.original_model, + self.tp_config, + rank=rank, + world_size=self.world_size, + device_id=self.devices[rank], + ) + self.assertIsInstance(sharded_model, keras.Model) + self.assertIn("simple_mlp.up_proj.kernel", modified_params) + self.assertIn("simple_mlp.down_proj.kernel", modified_params) + sharded_models.append(sharded_model) + self.assertEqual( + len(self.original_model.weights), len(sharded_models[0].weights) + ) + + def test_sharded_weight_shapes(self): + rank = 0 + with keras.device(self.devices[rank]): + sharded_model, _ = make_parameter_sharded_model( + self.original_model, + self.tp_config, + rank=rank, + world_size=self.world_size, + device_id=self.devices[rank], + ) + original_weights_dict = {w.path: w for w in self.original_model.weights} + sharded_weights_dict = { + w.name if isinstance(w, ShardedWeight) else w.path: w + for w in sharded_model.weights + } + orig_up_kernel = original_weights_dict["up_proj/kernel"] + shard_up_kernel = sharded_weights_dict["simple_mlp.up_proj.kernel"] + self.assertEqual(shard_up_kernel.shape[0], orig_up_kernel.shape[0]) + self.assertEqual( + shard_up_kernel.shape[1], + orig_up_kernel.shape[1] // self.world_size, + ) + orig_down_kernel = original_weights_dict["down_proj/kernel"] + shard_down_kernel = sharded_weights_dict["simple_mlp.down_proj.kernel"] + self.assertEqual( + shard_down_kernel.shape[0], + orig_down_kernel.shape[0] // self.world_size, + ) + self.assertEqual(shard_down_kernel.shape[1], orig_down_kernel.shape[1]) + + def test_forward_pass_correctness(self): + expected_output = self.original_model(self.input_data) + sharded_outputs = [] + original_weights = self.original_model.get_weights() + for rank in range(self.world_size): + with keras.device(self.devices[rank]): + cloned_original = keras.models.clone_model(self.original_model) + cloned_original.set_weights(original_weights) + sharded_model, _ = make_parameter_sharded_model( + cloned_original, + self.tp_config, + rank=rank, + world_size=self.world_size, + device_id=self.devices[rank], + ) + output = sharded_model(self.input_data) + sharded_outputs.append(output) + reconstructed_output = ( + keras.ops.sum(keras.ops.stack(sharded_outputs), axis=0) + / self.world_size + ) + + self.assertAllClose( + expected_output, reconstructed_output, atol=1e-5, rtol=1e-5 + ) diff --git a/keras/src/distribution/tensor_parallel/state_action_keras_test.py b/keras/src/distribution/tensor_parallel/state_action_keras_test.py new file mode 100644 index 000000000000..a6947958a4aa --- /dev/null +++ b/keras/src/distribution/tensor_parallel/state_action_keras_test.py @@ -0,0 +1,109 @@ +import pytest + +import keras +from keras.src import backend +from keras.src import testing +from keras.src.distribution.tensor_parallel.state_action_keras import ( + GatherKeras, +) 
+from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras +from keras.src.distribution.tensor_parallel.state_action_keras import SumKeras + + +@pytest.mark.skipif( + backend.backend() not in ("torch", "jax"), + reason="This test is for JAX/PyTorch backends.", +) +class TestStateActions(testing.TestCase): + """Test suite for tensor distribution state actions.""" + + def test_split_keras_even_split(self): + """Tests SplitKeras with a tensor that divides evenly.""" + world_size = 4 + tensor = keras.ops.reshape( + keras.ops.arange(16, dtype="float32"), (4, 4) + ) + + action_row = SplitKeras(world_size=world_size, dim=0) + shards_row = [action_row(tensor, rank=i) for i in range(world_size)] + + self.assertEqual(shards_row[0].shape, (1, 4)) + self.assertAllClose(shards_row[0], tensor[0:1, :]) + self.assertAllClose(shards_row[3], tensor[3:4, :]) + + reconstructed_row = action_row.undo(shards_row) + self.assertAllClose(reconstructed_row, tensor) + + action_col = SplitKeras(world_size=world_size, dim=1) + shards_col = [action_col(tensor, rank=i) for i in range(world_size)] + + self.assertEqual(shards_col[0].shape, (4, 1)) + self.assertAllClose(shards_col[0], tensor[:, 0:1]) + self.assertAllClose(shards_col[2], tensor[:, 2:3]) + + reconstructed_col = action_col.undo(shards_col) + self.assertAllClose(reconstructed_col, tensor) + + def test_split_keras_uneven_split(self): + """Tests SplitKeras with a tensor that does not divide evenly.""" + world_size = 3 + tensor = keras.ops.reshape( + keras.ops.arange(40, dtype="float32"), (4, 10) + ) + + action = SplitKeras(world_size=world_size, dim=1) + shards = [action(tensor, rank=i) for i in range(world_size)] + + self.assertEqual(shards[0].shape, (4, 4)) + self.assertEqual(shards[1].shape, (4, 3)) + self.assertEqual(shards[2].shape, (4, 3)) + + self.assertAllClose(shards[0], tensor[:, 0:4]) + self.assertAllClose(shards[1], tensor[:, 4:7]) + self.assertAllClose(shards[2], tensor[:, 7:10]) + + reconstructed = action.undo(shards) + self.assertAllClose(reconstructed, tensor) + + def test_split_keras_sharding_type_inference(self): + """Tests that `sharding_type` correctly infers the split dimension.""" + action_row = SplitKeras(world_size=2, dim=-1, sharding_type="row") + self.assertEqual(action_row.dim, 0) + + action_col = SplitKeras(world_size=2, dim=-1, sharding_type="column") + self.assertEqual(action_col.dim, 1) + + def test_gather_keras(self): + """Tests the GatherKeras action.""" + world_size = 4 + action = GatherKeras(world_size=world_size, dim=0) + tensor = keras.ops.array([[1, 2], [3, 4]], dtype="float32") + + processed_tensor = action(tensor, rank=0) + self.assertAllClose(processed_tensor, tensor) + + tensors_to_gather = [ + keras.ops.ones((2, 2)), + keras.ops.zeros((2, 2)), + keras.ops.ones((2, 2)), + ] + reconstructed = action.undo(tensors_to_gather) + expected = keras.ops.concatenate(tensors_to_gather, axis=0) + self.assertAllClose(reconstructed, expected) + + def test_sum_keras(self): + """Tests the SumKeras action.""" + world_size = 2 + action = SumKeras(world_size=world_size) + tensor = keras.ops.array([[1, 2], [3, 4]], dtype="float32") + + processed_tensor = action(tensor, rank=0) + self.assertAllClose(processed_tensor, tensor) + + tensors_to_sum = [ + keras.ops.full((2, 3), 5.0), + keras.ops.full((2, 3), 10.0), + ] + reconstructed = action.undo(tensors_to_sum) + expected = keras.ops.full((2, 3), 15.0) + self.assertAllClose(reconstructed, expected) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 
11e4046c7b8a..ba4abbe1139a 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -39,7 +39,8 @@ from keras.src.backend.common.remat import get_current_remat_mode from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.backend.config import is_nnx_enabled -from keras.src.distribution import distribution_lib + +# from keras.src.distribution import distribution_lib from keras.src.dtype_policies import DTypePolicyMap from keras.src.layers import input_spec from keras.src.metrics.metric import Metric @@ -942,6 +943,8 @@ def maybe_convert(x): # Change the layout for the layer output if needed. # This is useful for relayout intermediate tensor in the model # to achieve the optimal performance. + from keras.src.distribution import distribution_lib + distribution = distribution_lib.distribution() if distribution is not None: current_layer_path = current_path()
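A minimal usage sketch, condensed from `parameter_sharding_test.py` above; it is illustrative only and assumes at least two visible devices and the JAX or PyTorch backend:

    import re

    import numpy as np

    import keras
    from keras import distribution
    from keras.src.distribution.tensor_parallel.config import ConfigKeras
    from keras.src.distribution.tensor_parallel.parameter_sharding import (
        make_parameter_sharded_model,
    )
    from keras.src.distribution.tensor_parallel.state_action_keras import (
        SplitKeras,
    )

    world_size = 2
    devices = distribution.list_devices()[:world_size]

    # Simple MLP matching the one used in parameter_sharding_test.py.
    inputs = keras.Input(shape=(16,), name="input")
    x = keras.layers.Dense(32, name="up_proj")(inputs)
    x = keras.layers.Activation("relu")(x)
    outputs = keras.layers.Dense(8, use_bias=False, name="down_proj")(x)
    model = keras.Model(inputs, outputs, name="simple_mlp")

    # Column-shard the first kernel, row-shard the second one.
    config = ConfigKeras(
        state_rules={
            re.escape("simple_mlp.up_proj.kernel"): SplitKeras(
                world_size, dim=1
            ),
            re.escape("simple_mlp.down_proj.kernel"): SplitKeras(
                world_size, dim=0
            ),
        },
        output_rules={},
    )

    x_batch = np.random.rand(4, 16).astype("float32")
    shard_outputs = []
    for rank in range(world_size):
        with keras.device(devices[rank]):
            clone = keras.models.clone_model(model)
            clone.set_weights(model.get_weights())
            sharded_model, _ = make_parameter_sharded_model(
                clone,
                config,
                rank=rank,
                world_size=world_size,
                device_id=devices[rank],
            )
            shard_outputs.append(sharded_model(x_batch))

    # With empty output_rules the per-rank outputs are averaged to recover
    # the full result, as in test_forward_pass_correctness.
    full = keras.ops.sum(keras.ops.stack(shard_outputs), axis=0) / world_size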