diff --git a/keras/api/_tf_keras/keras/distribution/__init__.py b/keras/api/_tf_keras/keras/distribution/__init__.py index 66fed24c761d..cb947b863cf1 100644 --- a/keras/api/_tf_keras/keras/distribution/__init__.py +++ b/keras/api/_tf_keras/keras/distribution/__init__.py @@ -4,6 +4,21 @@ since your modifications would be overwritten. """ +from keras.src.distribution.distributed_backend import ( + apply_gradients as apply_gradients, +) +from keras.src.distribution.distributed_backend import ( + create_optimizer as create_optimizer, +) +from keras.src.distribution.distributed_backend import ( + get_communication_ops as get_communication_ops, +) +from keras.src.distribution.distributed_backend import ( + get_device_info as get_device_info, +) +from keras.src.distribution.distributed_backend import ( + is_multi_device_capable as is_multi_device_capable, +) from keras.src.distribution.distribution_lib import DataParallel as DataParallel from keras.src.distribution.distribution_lib import DeviceMesh as DeviceMesh from keras.src.distribution.distribution_lib import LayoutMap as LayoutMap diff --git a/keras/api/distribution/__init__.py b/keras/api/distribution/__init__.py index 66fed24c761d..cb947b863cf1 100644 --- a/keras/api/distribution/__init__.py +++ b/keras/api/distribution/__init__.py @@ -4,6 +4,21 @@ since your modifications would be overwritten. """ +from keras.src.distribution.distributed_backend import ( + apply_gradients as apply_gradients, +) +from keras.src.distribution.distributed_backend import ( + create_optimizer as create_optimizer, +) +from keras.src.distribution.distributed_backend import ( + get_communication_ops as get_communication_ops, +) +from keras.src.distribution.distributed_backend import ( + get_device_info as get_device_info, +) +from keras.src.distribution.distributed_backend import ( + is_multi_device_capable as is_multi_device_capable, +) from keras.src.distribution.distribution_lib import DataParallel as DataParallel from keras.src.distribution.distribution_lib import DeviceMesh as DeviceMesh from keras.src.distribution.distribution_lib import LayoutMap as LayoutMap diff --git a/keras/src/backend/__init__.py b/keras/src/backend/__init__.py index 15f1af2145d5..b22ea22547bb 100644 --- a/keras/src/backend/__init__.py +++ b/keras/src/backend/__init__.py @@ -37,6 +37,8 @@ if backend() == "tensorflow": from keras.src.backend.tensorflow import * # noqa: F403 from keras.src.backend.tensorflow.core import Variable as BackendVariable + + distributed_backend = None elif backend() == "jax": from keras.src.backend.jax import * # noqa: F403 from keras.src.backend.jax.core import Variable as BackendVariable @@ -44,17 +46,20 @@ from keras.src.backend.torch import * # noqa: F403 from keras.src.backend.torch.core import Variable as BackendVariable + distributed_backend = None distribution_lib = None elif backend() == "numpy": from keras.src.backend.numpy import * # noqa: F403 from keras.src.backend.numpy.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None elif backend() == "openvino": from keras.src.backend.openvino import * # noqa: F403 from keras.src.backend.openvino.core import Variable as BackendVariable distribution_lib = None + distributed_backend = None else: raise ValueError(f"Unable to import backend : {backend()}") diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py index 89ac0fa71c8c..0a275fb70cf1 100644 --- a/keras/src/backend/jax/__init__.py +++ b/keras/src/backend/jax/__init__.py @@ -1,5 +1,6 @@ 
from keras.src.backend.config import is_nnx_enabled from keras.src.backend.jax import core +from keras.src.backend.jax import distributed_backend from keras.src.backend.jax import distribution_lib from keras.src.backend.jax import image from keras.src.backend.jax import linalg diff --git a/keras/src/backend/jax/distributed_backend.py b/keras/src/backend/jax/distributed_backend.py new file mode 100644 index 000000000000..e04a38f26497 --- /dev/null +++ b/keras/src/backend/jax/distributed_backend.py @@ -0,0 +1,248 @@ +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Literal + +import jax +import jax.lax as lax +import jax.numpy as jnp + +import keras + + +def compute_gradients( + _loss: jnp.ndarray, trainable_vars: List[jnp.ndarray] +) -> List[jnp.ndarray]: + """Computes gradients of the loss with respect to trainable variables. + + Note: This is a placeholder implementation that returns zeros. A real + implementation would use `jax.grad`. + + Args: + _loss (jnp.ndarray): The loss value for which to compute gradients. + trainable_vars (List[jnp.ndarray]): A list of variables to compute + gradients with respect to. + + Returns: + List[jnp.ndarray]: A list of gradients corresponding to the + trainable variables. + """ + return [jnp.zeros_like(var) for var in trainable_vars] + + +def apply_gradients( + gradients: List[jnp.ndarray], + trainable_vars: List[jnp.ndarray], + learning_rate: float = 0.001, +) -> List[jnp.ndarray]: + """Applies gradients and returns the updated variables.""" + updated_vars = [] + for grad, var in zip(gradients, trainable_vars): + if grad is not None: + new_var = var - (learning_rate * grad) + updated_vars.append(new_var) + else: + updated_vars.append(var) + return updated_vars + + +def create_optimizer(optimizer_class: str, **kwargs) -> Dict[str, Any]: + """Creates a configuration dictionary for an optimizer. + + This function returns a dictionary containing the optimizer's configuration, + removing the need for a specific optimizer library like Optax. + + Args: + optimizer_class (str): The name of the optimizer to create (e.g., + `"adam"`, `"sgd"`). + **kwargs: Keyword arguments to be passed to the optimizer's + constructor (e.g., `learning_rate`). + + Returns: + Dict[str, Any]: A dictionary representing the optimizer configuration. + """ + config = kwargs.copy() + config["name"] = optimizer_class.lower() + config.setdefault("learning_rate", 0.001) + return config + + +def get_device_info() -> Dict[str, Any]: + """Retrieves information about the available JAX devices. + + Returns: + Dict[str, Any]: A dictionary containing the backend name, a list of + available device strings, and the total device count. + """ + available_devices = jax.devices() + return { + "backend": "jax", + "devices": [str(d) for d in available_devices], + "device_count": len(available_devices), + } + + +def is_multi_device_capable() -> bool: + """Checks if more than one JAX device is available. + + Returns: + bool: `True` if JAX reports more than one local device, `False` + otherwise. + """ + return jax.local_device_count() > 1 + + +def get_communication_ops() -> Dict[str, Callable]: + """Provides a dictionary of JAX collective communication operations. + + These operations are designed to work within a `jax.pmap` context for + multi-device computation. If not in a `pmap` context, they generally + behave as no-ops or simulate the operation on the single local device. 
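As an illustrative sketch (not part of this diff), the ops returned by this function can be driven through `jax.pmap`, which binds the "data" axis name they probe for; outside of `pmap` they fall back to the single-device simulations described above. The array values below are arbitrary, and the snippet assumes JAX is installed with this backend module available.

import jax
import jax.numpy as jnp

from keras.src.backend.jax import distributed_backend as jax_dist

comm_ops = jax_dist.get_communication_ops()
num_devices = jax.local_device_count()

# One row per local device; `pmap` binds the "data" axis name that the ops
# look up via `lax.axis_index`.
per_device_x = jnp.stack([jnp.arange(4.0)] * num_devices)
mean_fn = jax.pmap(
    lambda x: comm_ops["all_reduce"](x, op="mean", axis_name="data"),
    axis_name="data",
)
print(mean_fn(per_device_x))  # every row now holds the cross-device mean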
+ + Returns: + Dict[str, Callable]: A dictionary mapping operation names to their + JAX implementations. + """ + + def _is_in_pmap(axis_name: str = "data") -> bool: + """Checks if currently inside a pmap by probing the axis name.""" + try: + lax.axis_index(axis_name) + return True + except NameError: + return False + + def all_reduce( + x: jnp.ndarray, + op: Literal["sum", "mean"] = "sum", + axis_name: str = "data", + ) -> jnp.ndarray: + """Reduces a tensor across all devices in a `pmap`. + + Args: + x (jnp.ndarray): The tensor to reduce. + op (Literal["sum", "mean"], optional): The reduction operation. + Defaults to "sum". + axis_name (str, optional): The name of the `pmap` axis. + Defaults to "data". + + Returns: + jnp.ndarray: The reduced tensor. Returns the input tensor `x` if + not in a `pmap` context. + """ + if _is_in_pmap(axis_name): + reduce_ops = { + "sum": lax.psum, + "mean": lax.pmean, + } + reduce_fn = reduce_ops.get(op) + + if reduce_fn is None: + raise ValueError(f"Unsupported all_reduce op: {op}") + return reduce_fn(x, axis_name=axis_name) + else: + world_size = jax.local_device_count() + if world_size <= 1: + return x + if op == "sum": + return keras.ops.multiply(x, float(world_size)) + elif op == "mean": + return x + else: + raise ValueError(f"Unsupported all_reduce op: {op}") + + def all_gather( + x: jnp.ndarray, axis: int = 0, axis_name: str = "data" + ) -> jnp.ndarray: + """Gathers tensors from all devices and concatenates them. + + Args: + x (jnp.ndarray): The local tensor to gather. + axis (int, optional): The axis along which to concatenate the + gathered tensors. Defaults to 0. + axis_name (str, optional): The name of the `pmap` axis. + Defaults to "data". + + Returns: + jnp.ndarray: The concatenated tensor from all devices. + """ + if _is_in_pmap(axis_name): + return lax.all_gather(x, axis_name=axis_name, axis=axis) + else: + world_size = jax.local_device_count() + if world_size <= 1: + return x + return keras.ops.concatenate([x] * world_size, axis=axis) + + def broadcast( + x: jnp.ndarray, root: int = 0, axis_name: str = "data" + ) -> jnp.ndarray: + """Broadcasts a tensor from a root device to all other devices. + + Args: + x (jnp.ndarray): The tensor to broadcast. On the root device, this + is the tensor to be sent. + root (int, optional): The rank of the device from which to + broadcast. Defaults to 0. + axis_name (str, optional): The name of the `pmap` axis. + Defaults to "data". + + Returns: + jnp.ndarray: The tensor received from the root device. + """ + if _is_in_pmap(axis_name): + return lax.all_gather(x, axis_name=axis_name, axis=0)[root] + else: + return x + + def scatter( + x: jnp.ndarray, + root: int = 0, + axis: int = 0, + axis_name: str = "data", + ) -> jnp.ndarray: + """Scatters a tensor from a root device to all devices. + + Args: + x (jnp.ndarray): The tensor on the root device to be scattered. + root (int, optional): The rank of the device that holds the full + tensor. Defaults to 0. + axis (int, optional): The axis along which to split the tensor. + Defaults to 0. + axis_name (str, optional): The name of the `pmap` axis. + Defaults to "data". + + Returns: + jnp.ndarray: The chunk of the tensor for the local device. 
+ """ + if _is_in_pmap(axis_name): + full_tensor = lax.all_gather(x, axis_name=axis_name, axis=0)[root] + device_id = lax.axis_index(axis_name=axis_name) + num_devices = lax.psum(1, axis_name=axis_name) + chunk_size = full_tensor.shape[axis] // num_devices + start_index = device_id * chunk_size + return lax.dynamic_slice_in_dim( + operand=full_tensor, + start_index=start_index, + slice_size=chunk_size, + axis=axis, + ) + else: + world_size = jax.local_device_count() + if world_size <= 1: + return x + if x.shape[axis] % world_size != 0: + raise ValueError( + f"Tensor with shape {x.shape} cannot be scattered along " + f"axis {axis} across {world_size} devices." + ) + chunks = keras.ops.split(x, world_size, axis=axis) + return chunks[0] + + return { + "all_reduce": all_reduce, + "all_gather": all_gather, + "broadcast": broadcast, + "scatter": scatter, + } diff --git a/keras/src/backend/jax/distributed_backend_test.py b/keras/src/backend/jax/distributed_backend_test.py new file mode 100644 index 000000000000..e57286e8bf47 --- /dev/null +++ b/keras/src/backend/jax/distributed_backend_test.py @@ -0,0 +1,121 @@ +import os + +os.environ["JAX_PLATFORM_NAME"] = "cpu" + +import pytest + +import keras +from keras.src import backend +from keras.src import ops +from keras.src import testing +from keras.src.backend import distributed_backend + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="Jax Backend specific test", +) +class TestJaxDistributedFunctions(testing.TestCase): + """Unit tests for the JAX distributed backend standalone functions.""" + + def test_compute_gradients_returns_zeros(self): + """Test that compute_gradients returns correctly shaped zero tensors.""" + loss = ops.array(10.0) + trainable_vars = [ops.array([1.0, 2.0]), ops.array(3.0)] + gradients = distributed_backend.compute_gradients(loss, trainable_vars) + self.assertEqual(len(gradients), 2) + self.assertAllClose(gradients[0], ops.zeros_like(trainable_vars[0])) + self.assertAllClose(gradients[1], ops.zeros_like(trainable_vars[1])) + + def test_apply_gradients(self): + """Test the application of gradients to Keras variables.""" + var1 = keras.Variable([1.0, 2.0]) + var2 = keras.Variable(5.0) + trainable_vars = [var1, var2] + grad1 = ops.array([0.1, 0.2]) + grad2 = ops.array(0.5) + gradients = [grad1, grad2] + learning_rate = 0.1 + + updated_vars = distributed_backend.apply_gradients( + gradients, trainable_vars, learning_rate + ) + expected_var1 = ops.array([1.0, 2.0]) - ops.multiply( + ops.array([0.1, 0.2]), learning_rate + ) + expected_var2 = 5.0 - (0.5 * learning_rate) + self.assertAllClose(updated_vars[0], expected_var1) + self.assertAllClose(updated_vars[1], expected_var2) + + def test_create_optimizer(self): + """Test optimizer configuration creation.""" + adam_config = distributed_backend.create_optimizer( + "adam", learning_rate=0.01 + ) + self.assertIsInstance(adam_config, dict) + self.assertEqual(adam_config["name"], "adam") + self.assertEqual(adam_config["learning_rate"], 0.01) + + sgd_config = distributed_backend.create_optimizer( + "sgd", learning_rate=0.1, momentum=0.9 + ) + self.assertIsInstance(sgd_config, dict) + self.assertEqual(sgd_config["name"], "sgd") + self.assertEqual(sgd_config["learning_rate"], 0.1) + self.assertEqual(sgd_config["momentum"], 0.9) + + unknown_config = distributed_backend.create_optimizer( + "some_unknown_optimizer" + ) + self.assertIsInstance(unknown_config, dict) + self.assertEqual(unknown_config["name"], "some_unknown_optimizer") + 
self.assertEqual(unknown_config["learning_rate"], 0.001) + + def test_get_device_info(self): + """Test retrieving device information from the JAX backend.""" + info = distributed_backend.get_device_info() + self.assertEqual(info["backend"], "jax") + self.assertIsInstance(info["devices"], list) + self.assertIsInstance(info["device_count"], int) + self.assertGreater(info["device_count"], 0) + self.assertEqual(len(info["devices"]), info["device_count"]) + + def test_is_multi_device_capable(self): + """Test the boolean check for multi-device capability.""" + self.assertIsInstance( + distributed_backend.is_multi_device_capable(), bool + ) + + def test_communication_ops_simulation_logic(self): + """Test the simulated communication ops in a single-device context.""" + comm_ops = distributed_backend.get_communication_ops() + device_info = distributed_backend.get_device_info() + world_size = device_info.get("device_count", 1) + + # Test all_reduce + x_reduce = ops.array([[1.0, 2.0], [3.0, 4.0]]) + reduced = comm_ops["all_reduce"](x_reduce, op="sum") + if world_size > 1: + expected_reduce = ops.multiply(x_reduce, float(world_size)) + else: + expected_reduce = x_reduce + self.assertAllClose(reduced, expected_reduce) + + # Test all_gather + x_gather = ops.array([[1.0, 2.0]]) + gathered = comm_ops["all_gather"](x_gather, axis=0) + expected_gather = ops.concatenate([x_gather] * world_size, axis=0) + self.assertAllClose(gathered, expected_gather) + + # Test broadcast + x_broadcast = ops.array([5.0, 6.0]) + broadcasted = comm_ops["broadcast"](x_broadcast) + self.assertAllClose(broadcasted, x_broadcast) + + # Test scatter + if world_size > 0: + scatter_data = ops.arange(world_size * 2, dtype="float32") + x_scatter = ops.reshape(scatter_data, (world_size, 2)) + scattered = comm_ops["scatter"](x_scatter) + expected_scatter = ops.split(x_scatter, world_size, axis=0)[0] + self.assertAllClose(scattered, expected_scatter) diff --git a/keras/src/distribution/__init__.py b/keras/src/distribution/__init__.py index 04d907f35697..9670743bd3ed 100644 --- a/keras/src/distribution/__init__.py +++ b/keras/src/distribution/__init__.py @@ -1,3 +1,8 @@ +from keras.src.distribution.distributed_backend import apply_gradients +from keras.src.distribution.distributed_backend import create_optimizer +from keras.src.distribution.distributed_backend import get_communication_ops +from keras.src.distribution.distributed_backend import get_device_info +from keras.src.distribution.distributed_backend import is_multi_device_capable from keras.src.distribution.distribution_lib import DataParallel from keras.src.distribution.distribution_lib import DeviceMesh from keras.src.distribution.distribution_lib import Distribution diff --git a/keras/src/distribution/distributed_backend.py b/keras/src/distribution/distributed_backend.py new file mode 100644 index 000000000000..7b54d25b7f09 --- /dev/null +++ b/keras/src/distribution/distributed_backend.py @@ -0,0 +1,87 @@ +from typing import Any +from typing import List + +from keras.src.api_export import keras_export +from keras.src.backend import distributed_backend + + +@keras_export("keras.distribution.apply_gradients") +def apply_gradients( + gradients: List[Any], + trainable_vars: List[Any], + learning_rate: float = 0.001, +) -> None: + """Applies gradients to trainable variables. + + This function is a distribution-aware wrapper that delegates the gradient + application to the current backend's implementation. + + Args: + gradients (List[Any]): A list of gradients to be applied. 
+ trainable_vars (List[Any]): A list of trainable variables to be updated. + learning_rate (float, optional): The learning rate to use for the + update. Defaults to 0.001. + """ + return distributed_backend.apply_gradients( + gradients, trainable_vars, learning_rate + ) + + +@keras_export("keras.distribution.create_optimizer") +def create_optimizer(optimizer_class: str, **kwargs): + """Creates a backend-specific optimizer instance. + + This function instantiates an optimizer suitable for the current distributed + backend, forwarding all keyword arguments to the optimizer's constructor. + + Args: + optimizer_class (str): The class name of the optimizer to create (e.g., + `"Adam"`). + **kwargs: Additional keyword arguments to be passed to the optimizer's + constructor. + + Returns: + An instance of the requested optimizer. + """ + return distributed_backend.create_optimizer(optimizer_class, **kwargs) + + +@keras_export("keras.distribution.get_device_info") +def get_device_info() -> dict: + """Gets information about available computational devices. + + Retrieves details about the devices (e.g., CPU, GPU) that are visible + to the current backend. + + Returns: + dict: A dictionary containing information about the available devices. + """ + return distributed_backend.get_device_info() + + +@keras_export("keras.distribution.is_multi_device_capable") +def is_multi_device_capable() -> bool: + """Checks if the backend supports multi-device operations. + + This function determines if the underlying backend is configured and + capable of running computations across multiple devices. + + Returns: + bool: `True` if the backend supports multi-device training, + `False` otherwise. + """ + return distributed_backend.is_multi_device_capable() + + +@keras_export("keras.distribution.get_communication_ops") +def get_communication_ops() -> dict: + """Gets collective communication operations for the backend. + + This function returns a dictionary of collective ops (e.g., `all_reduce`, + `all_gather`) that can be used for distributed communication. + + Returns: + dict: A dictionary mapping the names of communication operations + (str) to their callable implementations. + """ + return distributed_backend.get_communication_ops() diff --git a/keras/src/distribution/tensor_parallel/communications.py b/keras/src/distribution/tensor_parallel/communications.py new file mode 100644 index 000000000000..fc0ca19e457d --- /dev/null +++ b/keras/src/distribution/tensor_parallel/communications.py @@ -0,0 +1,465 @@ +from typing import Any +from typing import List +from typing import Tuple + +from keras.src.distribution import distributed_backend + + +class CollectiveOpKeras: + """Base class for Keras collective communication operations. + + This class provides a common interface for various collective communication + primitives like AllReduce, AllGather, and Broadcast. Subclasses must + implement the `__call__` method. + + Args: + world_size (int): The total number of participating processes or devices + in the communication group. + rank (int, optional): The rank of the current process. Defaults to 0. + """ + + def __init__(self, world_size: int, rank: int = 0): + """Initializes the collective operation. + + Args: + world_size (int): The total number of participating processes or + devices in the communication group. + rank (int, optional): The rank of the current process. Defaults + to 0. 
+ """ + self.world_size = world_size + self.rank = rank + + def __call__(self, *args, **kwargs): + """Executes the collective operation.""" + raise NotImplementedError + + +class AllReduceKeras(CollectiveOpKeras): + """Performs an AllReduce collective operation. + + AllReduce reduces the input tensor across all devices and distributes the + final result back to all devices. + + Args: + world_size (int): The total number of participating processes. + op (str, optional): The reduction operation. Supported values are + "sum" and "mean". Defaults to "sum". + rank (int, optional): The rank of the current process. Defaults to 0. + + Raises: + NotImplementedError: If the current backend does not support the + AllReduce operation. + """ + + def __init__(self, world_size: int, op: str = "sum", rank: int = 0): + """Initializes the AllReduce operation. + + Args: + world_size (int): The total number of participating processes. + op (str, optional): The reduction operation. Supported values are + "sum" and "mean". Defaults to "sum". + rank (int, optional): The rank of current process. Defaults to 0. + """ + super().__init__(world_size, rank) + self.op = op + self.all_reduce_fn = distributed_backend.get_communication_ops().get( + "all_reduce" + ) + if self.all_reduce_fn is None: + raise NotImplementedError( + "AllReduce is not supported by the current backend." + ) + + def __call__(self, local_tensor: Any, axis_name: str) -> Any: + """Executes the AllReduce operation. + + Args: + local_tensor (Any): The tensor on the local device to be reduced. + axis_name (str): The name of the axis to reduce over, used by the + backend for identifying the device group. + + Returns: + Any: The reduced tensor, which is identical on all devices. + """ + return self.all_reduce_fn(local_tensor, op=self.op, axis_name=axis_name) + + +class AllGatherKeras(CollectiveOpKeras): + """Performs an AllGather collective operation. + + AllGather gathers tensors from all devices and concatenates them along a + specified dimension. The final concatenated tensor is available on all + devices. + + Args: + world_size (int): The total number of participating processes. + dim (int, optional): The dimension along which to concatenate the + gathered tensors. Defaults to -1. + rank (int, optional): The rank of the current process. Defaults to 0. + + Raises: + NotImplementedError: If the current backend does not support the + AllGather operation. + """ + + def __init__(self, world_size: int, dim: int = -1, rank: int = 0): + """Initializes the AllGather operation. + + Args: + world_size (int): The total number of participating processes. + dim (int, optional): The dimension along which to concatenate the + gathered tensors. Defaults to -1. + rank (int, optional): The rank of current process. Defaults to 0. + """ + super().__init__(world_size, rank) + self.dim = dim + self.all_gather_fn = distributed_backend.get_communication_ops().get( + "all_gather" + ) + if self.all_gather_fn is None: + raise NotImplementedError( + "AllGather is not supported by the current backend." + ) + + def __call__(self, local_tensor: Any, axis_name: str) -> Any: + """Executes the AllGather operation. + + Args: + local_tensor (Any): The tensor on the local device to be gathered. + axis_name (str): The name of the axis for the device group, used by + the backend for communication. + + Returns: + Any: The concatenated tensor, containing data from all devices. 
+ """ + return self.all_gather_fn( + local_tensor, axis=self.dim, axis_name=axis_name + ) + + +class BroadcastKeras(CollectiveOpKeras): + """Performs a Broadcast collective operation. + + Broadcast sends a tensor from a single source device to all other devices + in the group. + + Args: + world_size (int): The total number of participating processes. + src_rank (int, optional): The rank of the source process that is + broadcasting the tensor. Defaults to 0. + rank (int, optional): The rank of the current process. Defaults to 0. + + Raises: + NotImplementedError: If the current backend does not support the + Broadcast operation. + """ + + def __init__(self, world_size: int, src_rank: int = 0, rank: int = 0): + """Initializes the Broadcast operation. + + Args: + world_size (int): The total number of participating processes. + src_rank (int, optional): The rank of the source process that is + broadcasting the tensor. Defaults to 0. + rank (int, optional): The rank of current process. Defaults to 0. + """ + super().__init__(world_size, rank) + self.src_rank = src_rank + self.broadcast_fn = distributed_backend.get_communication_ops().get( + "broadcast" + ) + if self.broadcast_fn is None: + raise NotImplementedError( + "Broadcast is not supported by the current backend." + ) + + def __call__(self, tensor: Any, axis_name: str) -> Any: + """Executes the Broadcast operation. + + Args: + tensor (Any): The tensor to be broadcasted (on the source device) or + received (on other devices). + axis_name (str): The name of the axis for the device group, used by + the backend for communication. + + Returns: + Any: The broadcasted tensor from the source device. + """ + return self.broadcast_fn( + tensor, root=self.src_rank, axis_name=axis_name + ) + + +class TensorParallelCommunicator: + """Manages communication operations for tensor parallelism. + + This class abstracts the collective communication logic required for + implementing tensor-parallel models, providing specific methods for + column-parallel and row-parallel layers. + + Args: + world_size (int): The total number of devices in the group. + rank (int, optional): The rank of the current device. Defaults to 0. + """ + + def __init__(self, world_size: int, rank: int = 0): + """Initializes the communicator. + + Args: + world_size (int): The total number of devices in the group. + rank (int, optional): The rank of the current device. Defaults to 0. + """ + self.world_size = world_size + self.rank = rank + self.allreduce = AllReduceKeras(world_size, rank=rank) + self.allgather = AllGatherKeras(world_size, rank=rank) + self.broadcast = BroadcastKeras(world_size, rank=rank) + + def forward_column_parallel( + self, partial_outputs: List, dim: int = -1, axis_name: str = "batch" + ): + """Gathers output shards in a column-parallel forward pass. + + In a column-parallel layer, the output activations are sharded across + devices. This function collects all shards using an AllGather operation + to form the full output tensor. + + Args: + partial_outputs (List): A list of output shards, with one tensor + from each device in the communication group. + dim (int, optional): The dimension along which to concatenate the + gathered tensors. Defaults to -1. + axis_name (str, optional): The name of the communication axis used + by the backend. Defaults to "batch". + + Returns: + Any: The full, gathered output tensor, which is identical on all + devices. 
+ """ + self.allgather.dim = dim + return self.allgather(partial_outputs[self.rank], axis_name=axis_name) + + def backward_column_parallel( + self, + partial_gradients: List, + op: str = "sum", + axis_name: str = "batch", + ) -> List: + """Reduces weight gradients in a column-parallel backward pass. + + This is the conjugate operation to `forward_column_parallel`. It uses an + AllReduce operation to sum the gradients computed on each device for + the weight matrix. + + Args: + partial_gradients (List): A list of local weight gradients, with + one tensor from each device. + op (str, optional): The reduction operation, either "sum" or "mean". + Defaults to "sum". + axis_name (str, optional): The name of the communication axis. + Defaults to "batch". + + Returns: + Any: The reduced gradient tensor, identical on all devices. + """ + self.allreduce.op = op + return self.allreduce(partial_gradients[self.rank], axis_name=axis_name) + + def forward_row_parallel( + self, partial_outputs: List, op: str = "sum", axis_name: str = "batch" + ) -> List: + """Reduces output shards in a row-parallel forward pass. + + In a row-parallel layer, each device computes a partial output. This + function uses an AllReduce operation to sum these partial outputs into + the final, correct output tensor. + + Args: + partial_outputs (List): A list of partial outputs, one from each + device. + op (str, optional): The reduction operation, either "sum" or "mean". + Defaults to "sum". + axis_name (str, optional): The name of the communication axis. + Defaults to "batch". + + Returns: + Any: The final, reduced output tensor. + """ + self.allreduce.op = op + return self.allreduce(partial_outputs[self.rank], axis_name=axis_name) + + def backward_row_parallel( + self, partial_gradients: List, dim: int = -1, axis_name: str = "batch" + ): + """Gathers input gradients in a row-parallel backward pass. + + This is the conjugate operation to `forward_row_parallel`. It uses an + AllGather operation to collect the sharded input gradients from all + devices to reconstruct the full gradient tensor. + + Args: + partial_gradients (List): A list of local input gradients, one + from each device. + dim (int, optional): The dimension along which to concatenate the + gradients. Defaults to -1. + axis_name (str, optional): The name of the communication axis. + Defaults to "batch". + + Returns: + Any: The full, gathered gradient tensor. + """ + self.allgather.dim = dim + return self.allgather(partial_gradients[self.rank], axis_name=axis_name) + + def handle_mlp_handshake( + self, up_projection_outputs: List, down_projection_inputs: List + ) -> Tuple: + """Manages communication between two MLP layers for tensor parallelism. + + This is a specialized function for a common pattern where a + column-parallel layer (`up_projection`) is followed by a row-parallel + layer (`down_projection`). It combines their forward communication. + + Args: + up_projection_outputs (List): A list of local output tensors from + the `up_projection` layer on each device. + down_projection_inputs (List): A list of local input tensors for + the `down_projection` layer on each device. + + Returns: + tuple: A tuple with the gathered output from `up_projection` and + the reduced input for `down_projection`. 
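A brief usage sketch (not part of the diff) of how the communicator might be driven for column- and row-parallel layers. The tensor shapes are illustrative, the example assumes the JAX backend, and on a single device the underlying ops run in their simulation mode.

from keras import ops
from keras.src.distribution import distributed_backend
from keras.src.distribution.tensor_parallel.communications import (
    TensorParallelCommunicator,
)

world_size = distributed_backend.get_device_info()["device_count"]
comm = TensorParallelCommunicator(world_size=world_size, rank=0)

# Rank 0 holds one shard of a column-parallel layer's output; gather it.
local_shard = ops.ones((2, 4))
full_output = comm.forward_column_parallel(
    partial_outputs=[local_shard], dim=-1, axis_name="data"
)

# A row-parallel layer instead sums the partial results from each rank.
partial_result = ops.ones((2, 8))
summed = comm.forward_row_parallel(
    partial_outputs=[partial_result], op="sum", axis_name="data"
)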
+ """ + up_output = self.forward_column_parallel( + up_projection_outputs[self.rank], dim=-1 + ) + down_inputs = self.forward_row_parallel( + down_projection_inputs[self.rank], op="sum" + ) + return up_output, down_inputs + + def slice_upstream_gradient_for_column_parallel( + self, full_gradient: Any, rank: int, world_size: int, dim: int = -1 + ) -> Any: + """Slices the gradient for a column-parallel layer's backward pass. + + Before the backward pass of a column-parallel layer, the full upstream + gradient must be sliced so that each device receives the portion + corresponding to its output shard. It handles uneven sharding. + + Args: + full_gradient (Any): The complete upstream gradient tensor. + rank (int): The rank of the current device. + world_size (int): The total number of devices. + dim (int, optional): The dimension to slice along. Defaults to -1. + + Returns: + Any: The sliced portion of the gradient for the current device. + """ + shape = getattr(full_gradient, "shape", None) + if shape is None or not (-len(shape) <= dim < len(shape)): + return full_gradient + + total_size = shape[dim] + slice_size = total_size // world_size + remainder = total_size % world_size + start_idx = rank * slice_size + min(rank, remainder) + end_idx = start_idx + slice_size + (1 if rank < remainder else 0) + slices = [slice(None)] * len(shape) + slices[dim] = slice(start_idx, end_idx) + return full_gradient[tuple(slices)] + + def slice_upstream_gradient_for_row_parallel( + self, full_gradient: Any, rank: int, world_size: int, dim: int = 0 + ) -> Any: + """Slices the gradient for a row-parallel layer's backward pass. + + Before the backward pass of a row-parallel layer, the full upstream + gradient must be sliced so each device gets the part + corresponding to its input shard. + + Args: + full_gradient (Any): The complete upstream gradient tensor. + rank (int): The rank of the current device. + world_size (int): The total number of devices. + dim (int, optional): The dimension to slice along. Defaults to 0. + + Returns: + Any: The sliced portion of the gradient for the current device. + """ + shape = getattr(full_gradient, "shape", None) + if shape is None or not (-len(shape) <= dim < len(shape)): + return full_gradient + + total_size = shape[dim] + slice_size = total_size // world_size + start_idx = rank * slice_size + end_idx = (rank + 1) * slice_size + if rank == world_size - 1: + end_idx = total_size + slices = [slice(None)] * len(shape) + slices[dim] = slice(start_idx, end_idx) + return full_gradient[tuple(slices)] + + +def allreduce_gradients(gradients: Any, world_size: int) -> Any: + """Utility function to perform a mean AllReduce operation on gradients. + + This is commonly used in data parallelism to average gradients across all + devices before applying the optimizer step. + + Args: + gradients (Any): A tensor or list of tensors representing the gradients + on the local device. + world_size (int): The total number of devices. + + Returns: + Any: The averaged gradient tensor. + """ + allreduce_op = AllReduceKeras(world_size, op="mean") + local_gradient = gradients[0] if isinstance(gradients, list) else gradients + return allreduce_op(local_gradient, axis_name="batch") + + +def allgather_outputs(outputs: Any, world_size: int, dim: int = -1) -> Any: + """Utility function to perform an AllGather operation on model outputs. + + This can be used to collect the final outputs from all devices when running + inference in a distributed manner. 
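As an illustrative sketch (not part of the diff), the `allreduce_gradients` helper defined above can average data-parallel gradients before an optimizer step. The two-rank world size and gradient values are hypothetical; with a single physical device the mean reduction passes the tensor through unchanged.

from keras import ops
from keras.src.distribution.tensor_parallel.communications import (
    allreduce_gradients,
)

local_grads = ops.array([0.5, 1.5, 2.5])  # gradients on this rank
averaged = allreduce_gradients(local_grads, world_size=2)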
+ + Args: + outputs (Any): A tensor or list of tensors representing the model's + output on the local device. + world_size (int): The total number of devices. + dim (int, optional): The dimension along which to concatenate the + outputs. Defaults to -1. + + Returns: + Any: The gathered, full output tensor. + """ + allgather_op = AllGatherKeras(world_size, dim=dim) + local_output = outputs[0] if isinstance(outputs, list) else outputs + return allgather_op(local_output, axis_name="batch") + + +def broadcast_parameters( + parameters: List[Any], world_size: int, src_rank: int = 0 +) -> Any: + """Utility function to broadcast model parameters from a source device. + + This is typically used at the beginning of training to ensure all devices + start with the same initial model weights. + + Args: + parameters (List[Any]): A list of model parameters, where each element + corresponds to the parameters on a device. + world_size (int): The total number of devices. + src_rank (int, optional): The rank of the source device to broadcast + from. Defaults to 0. + + Returns: + Any: The broadcasted parameters. + """ + broadcast_op = BroadcastKeras(world_size, src_rank=src_rank) + return broadcast_op(parameters[src_rank], axis_name="batch") diff --git a/keras/src/distribution/tensor_parallel/communications_test.py b/keras/src/distribution/tensor_parallel/communications_test.py new file mode 100644 index 000000000000..3e89eacd6df3 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/communications_test.py @@ -0,0 +1,83 @@ +import pytest + +import keras +from keras.src import testing +from keras.src.backend import distributed_backend +from keras.src.distribution.tensor_parallel.communications import AllGatherKeras +from keras.src.distribution.tensor_parallel.communications import AllReduceKeras +from keras.src.distribution.tensor_parallel.communications import BroadcastKeras +from keras.src.distribution.tensor_parallel.communications import ( + TensorParallelCommunicator, +) + + +@pytest.mark.skipif( + keras.backend.backend() != "jax", + reason="This test suite requires a real JAX distributed backend.", +) +class TestCollectiveOps(testing.TestCase): + """ + Tests collective communication ops on a JAX distributed backend. 
+ """ + + def setUp(self): + super().setUp() + device_info = distributed_backend.get_device_info() + self.world_size = device_info.get("device_count", 1) + + if not self.world_size: + self.world_size = 1 + + self.axis_name = "data" + + def test_all_reduce(self): + """Tests the all-reduce operation.""" + all_reduce_op = AllReduceKeras(world_size=self.world_size, op="sum") + local_tensor = keras.ops.array([1.0, 2.0, 3.0]) + + result = all_reduce_op(local_tensor, axis_name=self.axis_name) + + expected_output = keras.ops.multiply( + local_tensor, float(self.world_size) + ) + self.assertAllClose(result, expected_output) + + def test_all_gather(self): + """Tests the all-gather operation.""" + all_gather_op = AllGatherKeras(world_size=self.world_size, dim=0) + local_slice = keras.ops.arange(6, dtype="float32").reshape((2, 3)) + result = all_gather_op(local_slice, axis_name=self.axis_name) + + expected_output = keras.ops.concatenate( + [local_slice] * self.world_size, axis=0 + ) + self.assertAllClose(result, expected_output) + + def test_broadcast(self): + """Tests the broadcast operation.""" + broadcast_op = BroadcastKeras( + world_size=self.world_size, src_rank=0, rank=0 + ) + tensor_to_broadcast = keras.ops.array([5.0, 10.0, 15.0]) + result = broadcast_op(tensor_to_broadcast, axis_name=self.axis_name) + + self.assertAllClose(result, tensor_to_broadcast) + + def test_tensor_parallel_communicator_forward_column_parallel(self): + """Tests the communicator's all-gather for column-parallel forward.""" + communicator = TensorParallelCommunicator( + world_size=self.world_size, rank=0 + ) + + local_slice = keras.ops.array([[0.0, 1.0], [2.0, 3.0]], dtype="float32") + + result = communicator.forward_column_parallel( + partial_outputs=[local_slice], + dim=0, + axis_name=self.axis_name, + ) + + expected_output = keras.ops.concatenate( + [local_slice] * self.world_size, axis=0 + ) + self.assertAllClose(result, expected_output) diff --git a/keras/src/distribution/tensor_parallel/config.py b/keras/src/distribution/tensor_parallel/config.py new file mode 100644 index 000000000000..8a6b89613b12 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/config.py @@ -0,0 +1,92 @@ +import dataclasses +from typing import Any +from typing import Dict +from typing import Sequence + +from keras.src.distribution.tensor_parallel.communications import AllGatherKeras +from keras.src.distribution.tensor_parallel.communications import AllReduceKeras +from keras.src.distribution.tensor_parallel.communications import BroadcastKeras + + +def _create_ops_from_rules( + rules: Dict[str, Any], world_size: int +) -> Dict[str, Any]: + """Parses a rules dictionary to create collective op instances. + + This function iterates through a dictionary of rules. If it encounters a + string identifier for a collective operation (e.g., "sum", "mean", + "gather -1"), it replaces it with an instantiated Keras collective op + object. Other values are passed through unchanged. + + Args: + rules (Dict[str, Any]): The dictionary of rules to process. + world_size (int): The total number of devices in the distributed setup. + + Returns: + Dict[str, Any]: A new dictionary with string identifiers replaced by + collective op instances. 
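A short sketch (not part of the diff) of declaring string-based output rules and resolving them into collective op instances. The rule patterns and device names are hypothetical, and the example assumes the JAX backend so the collective ops can be constructed.

from keras.src.distribution.tensor_parallel.config import ConfigKeras

config = ConfigKeras(
    state_rules={"dense.*kernel": "split"},
    output_rules={"dense.*output": {"activation": "gather -1"}},
)
resolved = config.create_collective_ops(devices=["gpu:0", "gpu:1"])
# The string "gather -1" is now an AllGatherKeras instance with dim=-1.
gather_op = resolved.output_rules["dense.*output"]["activation"]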
+ """ + processed_rules = {} + for pattern, actions in rules.items(): + if not isinstance(actions, dict): + processed_rules[pattern] = actions + continue + + processed_rules[pattern] = {} + for key, action in actions.items(): + if not isinstance(action, str): + processed_rules[pattern][key] = action + continue + + if action == "sum": + op = AllReduceKeras(world_size, op="sum") + elif action == "mean": + op = AllReduceKeras(world_size, op="mean") + elif action.startswith("gather"): + dim = int(action.split(" ")[1]) if " " in action else -1 + op = AllGatherKeras(world_size, dim=dim) + elif action == "broadcast": + op = BroadcastKeras(world_size) + else: + op = action + processed_rules[pattern][key] = op + return processed_rules + + +@dataclasses.dataclass +class ConfigKeras: + """A dataclass holding configuration for tensor parallelism in Keras. + + Attributes: + state_rules (Dict[str, Any]): Rules governing how model state variables + (e.g., weights) are handled across devices. + output_rules (Dict[str, Any]): Rules governing how layer outputs are + handled. These rules are processed by `create_collective_ops` to + instantiate the necessary communication operations. + """ + + state_rules: Dict[str, Any] + output_rules: Dict[str, Any] + + def create_collective_ops(self, devices: Sequence[str]): + """Creates a new ConfigKeras instance with collective ops. + + This method processes the `output_rules` of the current instance, + replacing string-based rule definitions with actual collective + communication op objects required for distributed execution. + + Args: + devices (Sequence[str]): A sequence of device strings (e.g., + ["/gpu:0", "/gpu:1"]), used to determine the world size. + + Returns: + ConfigKeras: A new `ConfigKeras` object with the `output_rules` + populated with instantiated collective op objects. + """ + world_size = len(devices) + new_output_rules = _create_ops_from_rules(self.output_rules, world_size) + + return dataclasses.replace( + self, + output_rules=new_output_rules, + ) diff --git a/keras/src/distribution/tensor_parallel/config_test.py b/keras/src/distribution/tensor_parallel/config_test.py new file mode 100644 index 000000000000..16258e917ad1 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/config_test.py @@ -0,0 +1,96 @@ +import pytest + +import keras +from keras.src import testing +from keras.src.distribution.tensor_parallel.communications import AllGatherKeras +from keras.src.distribution.tensor_parallel.communications import AllReduceKeras +from keras.src.distribution.tensor_parallel.communications import BroadcastKeras +from keras.src.distribution.tensor_parallel.config import ConfigKeras +from keras.src.distribution.tensor_parallel.config import _create_ops_from_rules + + +@pytest.mark.skipif( + keras.backend.backend() != "jax", + reason="This test suite requires a real JAX distributed backend.", +) +class TestConfig(testing.TestCase): + """Test suite for the tensor parallel configuration.""" + + def test_create_ops_from_rules_helper(self): + """ + Tests the private _create_ops_from_rules helper function directly + to ensure it correctly parses various rule types. 
+ """ + devices = ["/gpu:0", "/gpu:1"] + world_size = len(devices) + rules = { + "dense/kernel": {"forward": "sum", "backward": "mean"}, + "embedding/weight": { + "forward": "gather 0", + "backward": "gather -1", + }, + "attention/dense/bias": {"forward": "broadcast"}, + "passthrough": {"action": 123}, + "no_dict_action": "identity", + } + + processed_rules = _create_ops_from_rules(rules, world_size) + + sum_op = processed_rules["dense/kernel"]["forward"] + self.assertIsInstance(sum_op, AllReduceKeras) + self.assertEqual(sum_op.op, "sum") + self.assertEqual(sum_op.world_size, world_size) + + mean_op = processed_rules["dense/kernel"]["backward"] + self.assertIsInstance(mean_op, AllReduceKeras) + self.assertEqual(mean_op.op, "mean") + + gather_op_0 = processed_rules["embedding/weight"]["forward"] + self.assertIsInstance(gather_op_0, AllGatherKeras) + self.assertEqual(gather_op_0.dim, 0) + self.assertEqual(gather_op_0.world_size, world_size) + + gather_op_neg1 = processed_rules["embedding/weight"]["backward"] + self.assertIsInstance(gather_op_neg1, AllGatherKeras) + self.assertEqual(gather_op_neg1.dim, -1) + + broadcast_op = processed_rules["attention/dense/bias"]["forward"] + self.assertIsInstance(broadcast_op, BroadcastKeras) + self.assertEqual(broadcast_op.world_size, world_size) + + self.assertEqual(processed_rules["passthrough"]["action"], 123) + self.assertEqual(processed_rules["no_dict_action"], "identity") + + def test_config_keras_create_collective_ops(self): + """ + Tests the public create_collective_ops method of the ConfigKeras class. + """ + devices = ["/gpu:0", "/gpu:1"] + world_size = len(devices) + + state_rules = {"some_weight": "split"} + output_rules = { + "layer_1_output": {"activation": "sum"}, + "layer_2_output": {"activation": "gather -1"}, + } + + config = ConfigKeras(state_rules=state_rules, output_rules=output_rules) + new_config = config.create_collective_ops(devices) + + self.assertIsNot(new_config, config) + + self.assertEqual(new_config.state_rules, state_rules) + + self.assertIsInstance( + config.output_rules["layer_1_output"]["activation"], str + ) + + sum_op = new_config.output_rules["layer_1_output"]["activation"] + self.assertIsInstance(sum_op, AllReduceKeras) + self.assertEqual(sum_op.op, "sum") + self.assertEqual(sum_op.world_size, world_size) + + gather_op = new_config.output_rules["layer_2_output"]["activation"] + self.assertIsInstance(gather_op, AllGatherKeras) + self.assertEqual(gather_op.dim, -1) + self.assertEqual(gather_op.world_size, world_size) diff --git a/keras/src/distribution/tensor_parallel/state_action_keras.py b/keras/src/distribution/tensor_parallel/state_action_keras.py new file mode 100644 index 000000000000..e670020b9db7 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/state_action_keras.py @@ -0,0 +1,146 @@ +from typing import Any +from typing import Sequence + +import keras + + +class StateActionKeras: + """ + Abstract base class for actions that transform tensors for distribution. + + An action defines how a tensor should be processed for a specific worker + (rank) and how to reverse that action to reconstruct the original tensor. + """ + + def __call__(self, tensor: Any, rank: int) -> Any: + """ + Apply the state action to a tensor for a given worker rank. + + Args: + tensor: The input tensor to transform. + rank: The rank of the worker process. + + Returns: + The transformed tensor shard for the specified rank. 
+ """ + raise NotImplementedError + + def undo(self, tensors: Sequence[Any]) -> Any: + """ + Reverse the action to reconstruct the original tensor from its parts. + + Args: + tensors: A sequence of tensor shards from all worker processes. + + Returns: + The reconstructed, original tensor. + """ + raise NotImplementedError + + +class _ConcatenateMixin: + """A mixin class that provides a common `undo` method via concatenation.""" + + def undo(self, tensors: Sequence[Any]) -> Any: + """Concatenate a sequence of tensors along the specified dimension.""" + if self.dim == -1: + dim = keras.ops.ndim(tensors[0]) - 1 + else: + dim = self.dim + return keras.ops.concatenate(tensors, axis=dim) + + +class SplitKeras(_ConcatenateMixin, StateActionKeras): + """ + Splits a tensor into shards along a specified dimension for each worker. + + Args: + world_size: The total number of workers/shards. + dim: The dimension along which to split the tensor. If -1, the last + dimension is used. + sharding_type: If `dim` is -1, this can be 'row' (dim=0) or 'column' + (dim=1) to infer the split axis. + """ + + def __init__(self, world_size: int, dim: int, sharding_type: str = "auto"): + self.world_size = world_size + self.dim = dim + self.sharding_type = sharding_type + + if dim == -1 and sharding_type != "auto": + if sharding_type == "row": + self.dim = 0 + elif sharding_type == "column": + self.dim = 1 + + def __call__(self, tensor: Any, rank: int) -> Any: + """Splits the tensor and returns the shard corresponding to the rank.""" + if self.dim == -1: + dim = keras.ops.ndim(tensor) - 1 + else: + dim = self.dim + + total_size = tensor.shape[dim] + split_size = total_size // self.world_size + remainder = total_size % self.world_size + + start_idx = rank * split_size + min(rank, remainder) + end_idx = start_idx + split_size + (1 if rank < remainder else 0) + + slices = [slice(None)] * keras.ops.ndim(tensor) + slices[dim] = slice(start_idx, end_idx) + return tensor[tuple(slices)] + + +class GatherKeras(_ConcatenateMixin, StateActionKeras): + """ + Represents a gather operation, where tensors are collected from all ranks. + + The actual collective communication is handled by a different layer; this + class primarily serves as a placeholder to trigger that communication and + define how to undo it. + + Args: + world_size: The total number of workers. + dim: The dimension along which tensors will be concatenated in the + `undo` operation. + """ + + def __init__(self, world_size: int, dim: int): + self.world_size = world_size + self.dim = dim + + def __call__(self, tensor: Any, rank: int) -> Any: + """ + Returns the tensor as-is. + + The actual gathering is performed by the communication backend. + """ + return tensor + + +class SumKeras(StateActionKeras): + """ + Represents a sum operation, where tensors are summed across all ranks. + + The actual collective communication (AllReduce) is handled by a different + layer. This class triggers that operation and defines the `undo` logic. + + Args: + world_size: The total number of workers. + """ + + def __init__(self, world_size: int): + self.world_size = world_size + + def __call__(self, tensor: Any, rank: int) -> Any: + """ + Returns the tensor as-is. + + The actual summing is performed by the communication backend. 
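A minimal sketch (not part of the diff) of the shard/undo round trip provided by `SplitKeras`; the two-rank setup and tensor shape are illustrative.

from keras import ops
from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras

kernel = ops.reshape(ops.arange(12, dtype="float32"), (3, 4))
split = SplitKeras(world_size=2, dim=1)             # shard column-wise
shards = [split(kernel, rank=r) for r in range(2)]  # two (3, 2) shards
restored = split.undo(shards)                       # equal to `kernel`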
+ """ + return tensor + + def undo(self, tensors: Sequence[Any]) -> Any: + """Sums the collected tensors from all workers.""" + return sum(tensors) diff --git a/keras/src/distribution/tensor_parallel/state_action_keras_test.py b/keras/src/distribution/tensor_parallel/state_action_keras_test.py new file mode 100644 index 000000000000..4db0c035041a --- /dev/null +++ b/keras/src/distribution/tensor_parallel/state_action_keras_test.py @@ -0,0 +1,108 @@ +import pytest + +import keras +from keras.src import testing +from keras.src.distribution.tensor_parallel.state_action_keras import ( + GatherKeras, +) +from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras +from keras.src.distribution.tensor_parallel.state_action_keras import SumKeras + + +@pytest.mark.skipif( + keras.backend.backend() != "jax", + reason="This test suite requires a real JAX distributed backend.", +) +class TestStateActions(testing.TestCase): + """Test suite for tensor distribution state actions.""" + + def test_split_keras_even_split(self): + """Tests SplitKeras with a tensor that divides evenly.""" + world_size = 4 + tensor = keras.ops.reshape( + keras.ops.arange(16, dtype="float32"), (4, 4) + ) + + action_row = SplitKeras(world_size=world_size, dim=0) + shards_row = [action_row(tensor, rank=i) for i in range(world_size)] + + self.assertEqual(shards_row[0].shape, (1, 4)) + self.assertAllClose(shards_row[0], tensor[0:1, :]) + self.assertAllClose(shards_row[3], tensor[3:4, :]) + + reconstructed_row = action_row.undo(shards_row) + self.assertAllClose(reconstructed_row, tensor) + + action_col = SplitKeras(world_size=world_size, dim=1) + shards_col = [action_col(tensor, rank=i) for i in range(world_size)] + + self.assertEqual(shards_col[0].shape, (4, 1)) + self.assertAllClose(shards_col[0], tensor[:, 0:1]) + self.assertAllClose(shards_col[2], tensor[:, 2:3]) + + reconstructed_col = action_col.undo(shards_col) + self.assertAllClose(reconstructed_col, tensor) + + def test_split_keras_uneven_split(self): + """Tests SplitKeras with a tensor that does not divide evenly.""" + world_size = 3 + tensor = keras.ops.reshape( + keras.ops.arange(40, dtype="float32"), (4, 10) + ) + + action = SplitKeras(world_size=world_size, dim=1) + shards = [action(tensor, rank=i) for i in range(world_size)] + + self.assertEqual(shards[0].shape, (4, 4)) + self.assertEqual(shards[1].shape, (4, 3)) + self.assertEqual(shards[2].shape, (4, 3)) + + self.assertAllClose(shards[0], tensor[:, 0:4]) + self.assertAllClose(shards[1], tensor[:, 4:7]) + self.assertAllClose(shards[2], tensor[:, 7:10]) + + reconstructed = action.undo(shards) + self.assertAllClose(reconstructed, tensor) + + def test_split_keras_sharding_type_inference(self): + """Tests that `sharding_type` correctly infers the split dimension.""" + action_row = SplitKeras(world_size=2, dim=-1, sharding_type="row") + self.assertEqual(action_row.dim, 0) + + action_col = SplitKeras(world_size=2, dim=-1, sharding_type="column") + self.assertEqual(action_col.dim, 1) + + def test_gather_keras(self): + """Tests the GatherKeras action.""" + world_size = 4 + action = GatherKeras(world_size=world_size, dim=0) + tensor = keras.ops.array([[1, 2], [3, 4]], dtype="float32") + + processed_tensor = action(tensor, rank=0) + self.assertAllClose(processed_tensor, tensor) + + tensors_to_gather = [ + keras.ops.ones((2, 2)), + keras.ops.zeros((2, 2)), + keras.ops.ones((2, 2)), + ] + reconstructed = action.undo(tensors_to_gather) + expected = keras.ops.concatenate(tensors_to_gather, axis=0) + 
self.assertAllClose(reconstructed, expected) + + def test_sum_keras(self): + """Tests the SumKeras action.""" + world_size = 2 + action = SumKeras(world_size=world_size) + tensor = keras.ops.array([[1, 2], [3, 4]], dtype="float32") + + processed_tensor = action(tensor, rank=0) + self.assertAllClose(processed_tensor, tensor) + + tensors_to_sum = [ + keras.ops.full((2, 3), 5.0), + keras.ops.full((2, 3), 10.0), + ] + reconstructed = action.undo(tensors_to_sum) + expected = keras.ops.full((2, 3), 15.0) + self.assertAllClose(reconstructed, expected)
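Taken together, the new public entry points introduced by this change might be exercised as in the following sketch (not part of the diff). It assumes the JAX backend is active, since the other backends currently expose `distributed_backend = None`; the tensor values and learning rate are arbitrary.

import keras
from keras import ops

info = keras.distribution.get_device_info()
print(info["backend"], info["device_count"])
print(keras.distribution.is_multi_device_capable())

optimizer_config = keras.distribution.create_optimizer("adam", learning_rate=0.01)

variables = [ops.array([1.0, 2.0])]
gradients = [ops.array([0.1, 0.1])]
updated = keras.distribution.apply_gradients(
    gradients, variables, learning_rate=0.01
)

comm = keras.distribution.get_communication_ops()
reduced = comm["all_reduce"](ops.array([3.0, 4.0]), op="mean")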