diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py new file mode 100644 index 000000000000..32d6734860cc --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -0,0 +1,291 @@ +from typing import Any +from typing import Dict +from typing import Sequence +from typing import Set + +from keras.src.distribution.tensor_parallel.config import ConfigKeras +from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras + + +def analyze_dense_layer_directly(layer, module, prefix: str) -> str: + """Analyzes a Keras Dense layer to classify its sharding strategy. + + This function inspects the input and output dimensions of a Dense layer + to determine if it functions as an expansion layer ("up-projection"), a + contraction layer ("down-projection"), or neither ("generic_dense"). This + classification is a heuristic commonly used to apply tensor parallelism + in Transformer-based models, such as in an MLP block where an up-projection + is followed by a down-projection. + + Args: + layer: The Keras `layers.Dense` instance to analyze. + module: The parent module containing the layer (currently unused). + prefix (str): The name prefix for the layer in the model hierarchy + (currently unused). + + Returns: + str: A string classifying the layer as 'up_projection', + 'down_projection', or 'generic_dense'. + """ + from keras.src import layers + + if not isinstance(layer, layers.Dense): + return "generic_dense" + + input_dim = None + output_dim = None + + if hasattr(layer, "kernel") and layer.kernel is not None: + kernel_shape = layer.kernel.shape + if len(kernel_shape) == 2: + input_dim = kernel_shape[0] + output_dim = kernel_shape[1] + + if input_dim is None or output_dim is None: + if hasattr(layer, "units"): + output_dim = layer.units + else: + return "generic_dense" + + if ( + hasattr(layer, "input_shape") + and layer.input_shape + and len(layer.input_shape) > 1 + ): + input_dim = layer.input_shape[-1] + else: + return "generic_dense" + + if not input_dim or not output_dim: + return "generic_dense" + + expansion_threshold = 1.5 + is_expansion = output_dim > input_dim * expansion_threshold + is_contraction = input_dim > output_dim * expansion_threshold + + if is_expansion: + return "up_projection" + elif is_contraction: + return "down_projection" + else: + return "generic_dense" + + +def _find_and_shard_layers( + current_layer, + prefix: str, + module, + world_size: int, + state_rules: Dict[str, Any], + output_rules: Dict[str, Any], + processed_layers: Set[int], +): + """Recursively traverses the model graph to apply sharding rules. + + This function walks through all nested layers of a given Keras model or + layer. For each encountered layer, it determines an appropriate tensor + parallelism strategy and populates the `state_rules` and `output_rules` + dictionaries with the corresponding sharding actions. It uses a set of + processed layer IDs to avoid redundant processing of shared layers. + + The sharding logic is as follows: + - `Dense` layers are sharded based on their classification (up/down proj). + - Up-projections are split along the column axis (output features). + - Down-projections are split along the row axis (input features). + - `EinsumDense` layers in attention blocks are sharded similarly. + - `Embedding` layers are sharded column-wise for vocabulary parallelism. + - Normalization layers are ignored (replicated on all devices). 
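For reference, the up/down-projection classification used by `analyze_dense_layer_directly` reduces to a ratio test against the 1.5x `expansion_threshold`. Below is a minimal standalone sketch of that heuristic; `classify_dense` is a hypothetical helper name and the dimensions are illustrative, not taken from any particular model:

```python
def classify_dense(input_dim: int, output_dim: int, threshold: float = 1.5) -> str:
    # Mirrors the ratio test in analyze_dense_layer_directly.
    if output_dim > input_dim * threshold:
        return "up_projection"    # kernel split along columns (output features)
    if input_dim > output_dim * threshold:
        return "down_projection"  # kernel split along rows (input features)
    return "generic_dense"

assert classify_dense(64, 256) == "up_projection"    # e.g. an MLP expansion
assert classify_dense(256, 64) == "down_projection"  # the matching contraction
assert classify_dense(64, 80) == "generic_dense"     # ratio under 1.5x either way
```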
+ + Args: + current_layer: The Keras layer currently being processed. + prefix (str): The hierarchical name prefix for the `current_layer`. + module: The top-level Keras model or layer being configured. + world_size (int): The total number of devices for sharding. + state_rules (Dict[str, Any]): A dictionary with rules for + sharding layer weights (state). Keys are regex patterns matching + weight names, values are `SplitKeras` actions. + output_rules (Dict[str, Any]): A dictionary with rules + for handling layer outputs. Keys are regex patterns matching layer + names, values describe the communication op (e.g., 'allreduce'). + processed_layers (Set[int]): A set of `id()`s of layers that have + already been processed to prevent cycles and redundant work. + """ + from keras.src import layers + + if id(current_layer) in processed_layers: + return + processed_layers.add(id(current_layer)) + + name = current_layer.name + full_name = f"{prefix}.{name}" if prefix else name + + if isinstance(current_layer, layers.Dense): + mlp_type = analyze_dense_layer_directly( + current_layer, module, full_name + ) + + if mlp_type == "up_projection": + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) + if current_layer.use_bias: + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) + output_rules[f"^{full_name}$"] = {0: "gather"} + + elif mlp_type == "down_projection": + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 0, "row" + ) + output_rules[f"^{full_name}$"] = {0: "allreduce"} + + else: + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) + if current_layer.use_bias: + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) + output_rules[f"^{full_name}$"] = {0: "gather -1"} + return + + elif isinstance(current_layer, layers.EinsumDense): + if "attention_output" in full_name: + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 0, "row" + ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): + pass + output_rules[f"^{full_name}$"] = {0: "allreduce"} + else: + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) + output_rules[f"^{full_name}$"] = {0: "gather -1"} + return + + elif isinstance(current_layer, (layers.Embedding,)): + if hasattr(current_layer, "token_embedding") or hasattr( + current_layer, "position_embedding" + ): + pass + else: + weight_name = None + if hasattr(current_layer, "embeddings"): + weight_name = "embeddings" + elif hasattr(current_layer, "position_embeddings"): + weight_name = "position_embeddings" + + if weight_name: + state_rules[f"^{full_name}\\..*{weight_name}$"] = SplitKeras( + world_size, 1, "column" + ) + output_rules[f"^{full_name}$"] = {0: "no_comm"} + return + + elif isinstance( + current_layer, + ( + layers.LayerNormalization, + layers.BatchNormalization, + layers.GroupNormalization, + ), + ): + return + + if hasattr(current_layer, "layers") and current_layer.layers: + for sub_layer in current_layer.layers: + _find_and_shard_layers( + sub_layer, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, + ) + + for attr_name in dir(current_layer): + if attr_name.startswith("__") and attr_name.endswith("__"): + continue + if hasattr(current_layer, attr_name): + attr = 
getattr(current_layer, attr_name) + + if isinstance(attr, layers.Layer) and attr is not current_layer: + _find_and_shard_layers( + attr, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, + ) + elif isinstance(attr, (list, tuple)): + for item in attr: + if isinstance(item, layers.Layer): + _find_and_shard_layers( + item, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, + ) + + +def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: + """Generates a default tensor parallel sharding configuration for a model. + + This function serves as entry point for automatically creating a sharding + plan for a given Keras model or layer. It initializes the rule dictionaries + and starts the recursive layer traversal to populate them based on a default + set of heuristics for common architectures like Transformers. + + Example: + ```python + model = MyTransformerModel() + device_ids = ["gpu:0", "gpu:1"] + sharding_config = get_default_config_keras(model, device_ids) + # sharding_config can now be used to distribute the model + ``` + + Args: + module: The Keras `Model` or `Layer` to generate a config for. + device_ids (Sequence[str]): A sequence of device IDs (e.g., + ["gpu:0", "gpu:1"]) across which the model will be sharded. + + Returns: + ConfigKeras: A configuration object containing the generated sharding + rules for model weights (`state_rules`) and layer outputs + (`output_rules`). + """ + world_size = len(device_ids) + state_rules = {} + output_rules = {} + processed_layers = set() + + _find_and_shard_layers( + current_layer=module, + prefix="", + module=module, + world_size=world_size, + state_rules=state_rules, + output_rules=output_rules, + processed_layers=processed_layers, + ) + + return ConfigKeras(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py new file mode 100644 index 000000000000..228a2b184569 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -0,0 +1,241 @@ +import os + +import pytest + +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" + +from keras import Input +from keras import Model +from keras import layers +from keras.src import backend +from keras.src import testing +from keras.src.distribution import distributed_backend +from keras.src.distribution.tensor_parallel.autoconfig import ( + analyze_dense_layer_directly, +) +from keras.src.distribution.tensor_parallel.autoconfig import ( + get_default_config_keras, +) +from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras + + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="Tensor Parallelism autoconfig tests are only for the JAX backend.", +) +class TestAutoConfigKeras(testing.TestCase): + def setUp(self): + """Set up the test case and common variables.""" + super().setUp() + device_info = distributed_backend.get_device_info() + self.world_size = device_info["device_count"] + self.device_ids = [f"cpu:{i}" for i in range(self.world_size)] + + self.assertGreater( + self.world_size, 1, "Distribution tests require more than 1 device." 
+ ) + + def _assert_split_keras_equal(self, rule1, rule2): + """Helper to compare two SplitKeras objects by their attributes.""" + self.assertIsInstance(rule1, SplitKeras) + self.assertIsInstance(rule2, SplitKeras) + self.assertDictEqual(vars(rule1), vars(rule2)) + + def _assert_rules_equal(self, actual_rules, expected_rules): + """Helper to compare two dictionaries of sharding rules.""" + self.assertSetEqual( + set(actual_rules.keys()), set(expected_rules.keys()) + ) + for key in expected_rules: + actual_val = actual_rules[key] + expected_val = expected_rules[key] + if isinstance(expected_val, SplitKeras): + self._assert_split_keras_equal(actual_val, expected_val) + else: + self.assertEqual(actual_val, expected_val) + + def test_analyze_dense_layer(self): + """Tests the direct analysis and classification of Dense layers.""" + up_proj_layer = layers.Dense(32) + up_proj_layer.build(input_shape=(None, 16)) + self.assertEqual( + analyze_dense_layer_directly(up_proj_layer, None, ""), + "up_projection", + ) + + down_proj_layer = layers.Dense(16) + down_proj_layer.build(input_shape=(None, 32)) + self.assertEqual( + analyze_dense_layer_directly(down_proj_layer, None, ""), + "down_projection", + ) + + generic_layer = layers.Dense(20) + generic_layer.build(input_shape=(None, 16)) + self.assertEqual( + analyze_dense_layer_directly(generic_layer, None, ""), + "generic_dense", + ) + + def test_simple_mlp_sharding(self): + """Tests a simple MLP with up and down projection layers.""" + inputs = Input(shape=(64,)) + x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) + outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)( + x + ) + model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^simple_mlp.up_projection_layer.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^simple_mlp.up_projection_layer.bias$": SplitKeras( + self.world_size, 0, "column" + ), + r"^simple_mlp.down_projection_layer.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + } + expected_output_rules = { + r"^simple_mlp.up_projection_layer$": {0: "gather"}, + r"^simple_mlp.down_projection_layer$": {0: "allreduce"}, + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_generic_dense_sharding(self): + """Tests a generic Dense layer that isn't an up/down projection.""" + inputs = Input(shape=(64,)) + outputs = layers.Dense(80, name="generic_layer", use_bias=True)(inputs) + model = Model(inputs=inputs, outputs=outputs, name="generic_model") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^generic_model.generic_layer.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^generic_model.generic_layer.bias$": SplitKeras( + self.world_size, 0, "column" + ), + } + expected_output_rules = { + r"^generic_model.generic_layer$": {0: "gather -1"} + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_embedding_sharding(self): + """Tests an Embedding layer for vocabulary parallelism.""" + inputs = Input(shape=(10,), dtype="int32") + outputs = layers.Embedding( + input_dim=1000, output_dim=128, name="token_embedding" + )(inputs) + model = Model(inputs=inputs, outputs=outputs, name="embed_model") + + config = 
get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^embed_model.token_embedding\..*embeddings$": SplitKeras( + self.world_size, 1, "column" + ) + } + expected_output_rules = { + r"^embed_model.token_embedding$": {0: "no_comm"} + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_einsum_dense_sharding(self): + """Tests the special handling for EinsumDense layers.""" + inputs = Input(shape=(64,)) + x = layers.EinsumDense( + "bh,hd->bd", output_shape=128, name="query_proj" + )(inputs) + outputs = layers.EinsumDense( + "bd,dh->bh", output_shape=64, name="attention_output" + )(x) + model = Model(inputs=inputs, outputs=outputs, name="einsum_model") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^einsum_model.query_proj.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^einsum_model.attention_output.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + } + expected_output_rules = { + r"^einsum_model.query_proj$": {0: "gather -1"}, + r"^einsum_model.attention_output$": {0: "allreduce"}, + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_normalization_layers_ignored(self): + """Tests that normalization layers are correctly ignored.""" + inputs = Input(shape=(64,)) + x = layers.Dense(64, name="dense1", use_bias=True)(inputs) + x = layers.LayerNormalization(name="layernorm")(x) + outputs = layers.Dense(64, name="dense2", use_bias=True)(x) + model = Model(inputs=inputs, outputs=outputs, name="norm_model") + + config = get_default_config_keras(model, self.device_ids) + + for key in config.state_rules: + self.assertNotIn("layernorm", key) + for key in config.output_rules: + self.assertNotIn("layernorm", key) + + self.assertIn(r"^norm_model.dense1.kernel$", config.state_rules) + self.assertIn(r"^norm_model.dense2.kernel$", config.state_rules) + self.assertEqual(len(config.state_rules), 4) + self.assertEqual(len(config.output_rules), 2) + + def test_nested_model_sharding(self): + """Tests that the traversal logic correctly handles nested models.""" + inner_inputs = Input(shape=(32,)) + inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)( + inner_inputs + ) + inner_model = Model( + inputs=inner_inputs, outputs=inner_outputs, name="inner_block" + ) + + outer_inputs = Input(shape=(32,)) + x = inner_model(outer_inputs) + outer_outputs = layers.Dense(32, name="outer_dense", use_bias=True)(x) + outer_model = Model( + inputs=outer_inputs, outputs=outer_outputs, name="outer_model" + ) + + config = get_default_config_keras(outer_model, self.device_ids) + + expected_state_rules = { + r"^outer_model.inner_block.inner_dense.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^outer_model.inner_block.inner_dense.bias$": SplitKeras( + self.world_size, 0, "column" + ), + r"^outer_model.outer_dense.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + } + expected_output_rules = { + r"^outer_model.inner_block.inner_dense$": {0: "gather"}, + r"^outer_model.outer_dense$": {0: "allreduce"}, + } + + self.maxDiff = None + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py 
b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py new file mode 100644 index 000000000000..ca7f8e5d5fcc --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -0,0 +1,614 @@ +import re +from typing import Any + +import numpy as np + +import keras +from keras.src import ops +from keras.src import optimizers +from keras.src.distribution import distributed_backend + + +class CoordinatedOptimizer: + """Manages an optimizer's state for distributed training. + + This class is an internal coordinator that handles the complexities of + sharding optimizer states across multiple devices (shards) and + synchronizing gradients according to tensor parallelism rules. It is not + intended to be used directly by the end-user but is a core component of + the `TensorParallelOptimizer`. + + Args: + base_optimizer: The Keras optimizer instance + (e.g., `keras.optimizers.Adam`) whose state will be managed. + world_size: The total number of devices/processes in the distributed + setup. + distributed_backend: The distributed communication backend to use. + Defaults to "auto". + rank: The rank of the current process. Defaults to 0. + shard_optimizer_states: If `True`, the optimizer's state variables + (e.g., momentum, velocity) will be partitioned across `world_size` + devices. Defaults to `True`. + tensor_parallel_config: An optional configuration object that defines + rules for tensor parallelism, such as which gradients to + all-reduce. Defaults to `None`. + """ + + def __init__( + self, + base_optimizer: optimizers.Optimizer, + world_size: int, + distributed_backend: str = "auto", + rank: int = 0, + shard_optimizer_states: bool = True, + tensor_parallel_config=None, + ): + self.base_optimizer = base_optimizer + self.world_size = world_size + self.shard_optimizer_states = shard_optimizer_states + self.tensor_parallel_config = tensor_parallel_config + self.sharded_states = {} + self._state_variable_to_parameter = {} + self._variables = None + self._variable_to_slot_name = {} + + def _initialize_sharded_states(self): + """ + Partitions the optimizer's state variables across shards by inspecting + the variables created by the base optimizer. 
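To make the state-to-parameter mapping concrete: optimizer slot variables are matched back to model parameters by comparing the state path's suffix against each parameter path with `/` normalized to `_`, longest match first. The following is a minimal standalone sketch of that matching; `match_state_to_param` is a hypothetical helper and the paths are illustrative (real paths depend on the optimizer and layer names in use):

```python
def match_state_to_param(state_path, opt_name, param_paths):
    """Return (param_path, slot_name) for an optimizer state variable path."""
    prefix, _, suffix = state_path.partition("/")
    if prefix != opt_name:
        return None
    # Longest normalized parameter path wins, so a longer path is not
    # shadowed by a shorter one that happens to be a prefix of the suffix.
    for param_path in sorted(param_paths, key=len, reverse=True):
        normalized = param_path.replace("/", "_")
        if suffix.startswith(normalized):
            return param_path, suffix[len(normalized):].strip("_")
    return None

# e.g. Adam's momentum slot for a Dense kernel:
result = match_state_to_param(
    "adam/dense_1_kernel_momentum", "adam", ["dense_1/kernel", "dense_1/bias"]
)
assert result == ("dense_1/kernel", "momentum")
```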
+ """ + if not self.shard_optimizer_states or not self.base_optimizer.built: + return + + self.sharded_states = {} + self._state_variable_to_parameter = {} + self._variable_to_slot_name = {} + opt_name = self.base_optimizer.name + + normalized_params = sorted( + [(p.path.replace("/", "_"), p) for p in self._variables], + key=lambda x: len(x[0]), + reverse=True, + ) + + for state_var in self.base_optimizer.variables: + if state_var is self.base_optimizer.iterations: + continue + + path_parts = state_var.path.split("/") + if len(path_parts) != 2 or path_parts[0] != opt_name: + continue + + state_suffix = path_parts[1] + + found_param = None + slot_name = None + for norm_param_path, param in normalized_params: + if state_suffix.startswith(norm_param_path): + found_param = param + slot_suffix = state_suffix[len(norm_param_path) :] + slot_name = slot_suffix.strip("_") + break + + if found_param is not None and slot_name is not None: + self._state_variable_to_parameter[state_var.path] = found_param + self._variable_to_slot_name[state_var.path] = slot_name + + sharding_dim = 0 + if self.tensor_parallel_config: + norm_param_name = found_param.path.replace("/", ".") + for p, a in self.tensor_parallel_config.state_rules.items(): + if re.search(p, norm_param_name) and hasattr(a, "dim"): + sharding_dim = a.dim + break + + partitioned_state = self._partition_state( + state_var, dim=sharding_dim + ) + self.sharded_states.setdefault(slot_name, {})[ + found_param.path + ] = partitioned_state + + if self.base_optimizer.iterations is not None: + self.sharded_states["iterations"] = self._partition_state( + self.base_optimizer.iterations, dim=0 + ) + + def _partition_state( + self, state_variable: any, dim: int + ) -> list[np.ndarray]: + """Splits a single state variable numpy array into chunks. + + If the variable cannot be split along the given dimension, it is + replicated across all shards. + + Args: + state_variable: The optimizer state variable. + dim: The dimension along which to partition the variable. + + Returns: + A list of NumPy arrays, where each array is a partition of the + original state variable for a specific shard. + """ + state_array = ops.convert_to_numpy(state_variable) + if state_array.ndim > dim and state_array.shape[dim] >= self.world_size: + return np.array_split(state_array, self.world_size, axis=dim) + else: + return [np.copy(state_array) for _ in range(self.world_size)] + + def get_config(self) -> dict[str, Any]: + return { + "base_optimizer": self.base_optimizer.get_config(), + "world_size": self.world_size, + "shard_optimizer_states": self.shard_optimizer_states, + } + + def apply_gradients( + self, gradients_and_vars: list[list[tuple]], shard_models: list + ): + """Coordinates gradient synchronization and application. + + This method first synchronizes gradients across all shards based on + tensor parallelism rules. Then, it applies the gradients using either + sharded optimizer states or replicated states. + + Args: + gradients_and_vars: A list of lists, where each inner list contains + (gradient, variable) tuples for a specific model shard. + shard_models: A list of the sharded model instances. + + Raises: + ValueError: If the number of gradient sets does not match the + world size. 
+ """ + if len(gradients_and_vars) != self.world_size: + error_msg = ( + f"Expected {self.world_size} gradient sets, " + f"got {len(gradients_and_vars)}" + ) + raise ValueError(error_msg) + + synchronized_gradients = self._synchronize_gradients(gradients_and_vars) + + if self.shard_optimizer_states and self.sharded_states: + self._apply_gradients_with_sharded_states( + synchronized_gradients, shard_models + ) + else: + self._apply_gradients_with_replicated_states( + synchronized_gradients, shard_models + ) + + def _apply_gradients_with_replicated_states( + self, synchronized_gradients: list[list[tuple]], shard_models: list + ): + """Averages gradients across all shards and applies them once. + + This method is used when optimizer state sharding is disabled. It + calculates the average of the gradients for each variable across all + shards and applies the averaged gradients using the single, replicated + optimizer state. + + Args: + synchronized_gradients: The gradients after synchronization. + shard_models: The list of sharded models. + """ + num_vars = len(synchronized_gradients[0]) + averaged_grads_and_vars = [] + + for i in range(num_vars): + variable = synchronized_gradients[0][i][1] + grads_for_var = [ + shard_grads[i][0] + for shard_grads in synchronized_gradients + if shard_grads[i][0] is not None + ] + + if not grads_for_var: + continue + + if len(grads_for_var) > 1: + stacked_grads = ops.stack(grads_for_var, axis=0) + averaged_grad = ops.mean(stacked_grads, axis=0) + else: + averaged_grad = grads_for_var[0] + + averaged_grads_and_vars.append((averaged_grad, variable)) + + if averaged_grads_and_vars: + self.base_optimizer.apply_gradients(averaged_grads_and_vars) + + def _apply_gradients_with_sharded_states( + self, synchronized_gradients: list[list[tuple]], shard_models: list + ): + """Applies gradients to each shard using its local optimizer state.""" + for shard_idx in range(self.world_size): + local_states = self._get_local_optimizer_states(shard_idx) + shard_optimizer = shard_models[shard_idx].optimizer + + self._update_optimizer_internal_state(shard_optimizer, local_states) + + shard_grads_and_vars = synchronized_gradients[shard_idx] + shard_optimizer.apply_gradients(shard_grads_and_vars) + + self._update_global_sharded_states(shard_optimizer, shard_idx) + + def _get_local_optimizer_states(self, shard_idx: int) -> dict[str, Any]: + """Constructs the state dictionary for a single shard.""" + local_states = {} + for state_name, state_value in self.sharded_states.items(): + if isinstance(state_value, dict): + local_states[state_name] = {} + for param_name, param_states in state_value.items(): + local_states[state_name][param_name] = param_states[ + shard_idx + ] + else: + local_states[state_name] = state_value[shard_idx] + return local_states + + def _update_optimizer_internal_state(self, optimizer, local_states: dict): + """Assigns local sharded state values to the optimizer's variables.""" + if not optimizer.built: + return + + for var in optimizer.variables: + if var is optimizer.iterations: + if "iterations" in local_states: + ops.assign(var, local_states["iterations"]) + continue + + param = self._state_variable_to_parameter.get(var.path, None) + slot_name = self._variable_to_slot_name.get(var.path) + + if ( + param + and slot_name + and slot_name in local_states + and param.path in local_states[slot_name] + ): + local_param_state = local_states[slot_name][param.path] + if var.shape == local_param_state.shape: + ops.assign(var, local_param_state) + + def 
_update_global_sharded_states(self, optimizer, shard_idx: int): + """Updates the main sharded_states dictionary after a gradient step.""" + if not optimizer.built: + return + + for var in optimizer.variables: + if var is optimizer.iterations: + self.sharded_states["iterations"][shard_idx] = ( + ops.convert_to_numpy(var) + ) + continue + + param = self._state_variable_to_parameter.get(var.path, None) + slot_name = self._variable_to_slot_name.get(var.path) + + if ( + param + and slot_name + and slot_name in self.sharded_states + and param.path in self.sharded_states[slot_name] + ): + self.sharded_states[slot_name][param.path][shard_idx] = ( + ops.convert_to_numpy(var) + ) + + def _synchronize_gradients( + self, gradients_and_vars: list[list[tuple]] + ) -> list[list[tuple]]: + """Synchronizes gradients across shards based on tensor parallel rules. + + Specifically, it performs an all-reduce operation on gradients of + weights that are split along a "column" dimension in tensor parallelism. + Other gradients are passed through unchanged. + + Args: + gradients_and_vars: The list of (gradient, variable) lists from + all shards. + + Returns: + The list of (gradient, variable) lists after synchronization. + """ + if not self.tensor_parallel_config: + return gradients_and_vars + + rules = self.tensor_parallel_config.state_rules.items() + column_parallel_patterns = { + pattern + for pattern, action in rules + if hasattr(action, "sharding_type") + and action.sharding_type == "column" + } + + if not column_parallel_patterns: + return gradients_and_vars + + num_weights = len(gradients_and_vars[0]) + for i in range(num_weights): + variable = gradients_and_vars[0][i][1] + var_name = getattr(variable, "path", getattr(variable, "name", "")) + + if any( + re.search(pattern, var_name) + for pattern in column_parallel_patterns + ): + grads_to_reduce = [ + g_and_v[i][0] + for g_and_v in gradients_and_vars + if g_and_v[i][0] is not None + ] + if grads_to_reduce: + synced_grad = self._allreduce_gradients(grads_to_reduce)[0] + for shard_idx in range(self.world_size): + gradients_and_vars[shard_idx][i] = ( + synced_grad, + variable, + ) + return gradients_and_vars + + def _allreduce_gradients(self, gradients: list[Any]) -> list[Any]: + """Performs a mean all-reduce operation on a list of gradients. + + If a distributed backend is available, it uses it. Otherwise, it + falls back to a local mean calculation. + + Args: + gradients: A list of gradients (one from each shard) to be averaged. + + Returns: + A list where each element is the mean of the input gradients. 
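The local fallback path of the mean all-reduce (used when no multi-device backend is available) is just a stack-and-average over the per-shard gradients. A short sketch with illustrative values:

```python
from keras import ops

shard_grads = [ops.ones((2, 2)), ops.ones((2, 2)) * 3.0]   # gradients from 2 shards
mean_grad = ops.mean(ops.stack(shard_grads, axis=0), axis=0)
# Every shard then receives the same averaged gradient (all entries 2.0 here).
synced = [mean_grad for _ in shard_grads]
```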
+ """ + if not gradients: + return [] + + if distributed_backend.is_multi_device_capable(): + all_reduce_fn = distributed_backend.get_communication_ops()[ + "all_reduce" + ] + numpy_grad = ops.convert_to_numpy(gradients[0]) + synced_numpy = all_reduce_fn(numpy_grad, op="mean") + synced_tensor = ops.convert_to_tensor(synced_numpy) + return [synced_tensor for _ in range(self.world_size)] + + stacked_grads = keras.ops.stack( + [ops.convert_to_tensor(g) for g in gradients], axis=0 + ) + mean_grad = ops.mean(stacked_grads, axis=0) + return [mean_grad for _ in range(len(gradients))] + + def get_weights(self) -> list[np.ndarray]: + """Returns the weights of the base optimizer.""" + return [ + ops.convert_to_numpy(var) for var in self.base_optimizer.variables + ] + + def set_weights(self, weights: list[np.ndarray]): + """Sets the weights of the base optimizer.""" + self.base_optimizer.set_weights(weights) + + def enable_optimizer_state_sharding(self, variables: list): + """Enables and initializes optimizer state sharding. + + This method is called from `build()`, which is guarded from running + multiple times. We can assume this should always execute. + """ + self.shard_optimizer_states = True + self._variables = variables + self._initialize_sharded_states() + + def disable_optimizer_state_sharding(self): + """Disables sharding and clears any sharded states. + + This reverts the optimizer to using a single, replicated state. + """ + if self.shard_optimizer_states: + self.shard_optimizer_states = False + self.sharded_states = {} + + +class TensorParallelOptimizer(optimizers.Optimizer): + """A Keras Optimizer wrapper for tensor-parallel distributed training. + + This optimizer wraps a standard Keras optimizer (e.g., Adam, SGD) and + delegates the complex tasks of state management and gradient synchronization + to a `CoordinatedOptimizer` instance. It is designed to work with models + that have been sharded for tensor parallelism. + + When `apply_gradients` is called with a list of gradient lists (one for each + model shard), it uses the `CoordinatedOptimizer` to handle synchronization + and state sharding. Otherwise, it behaves like the base optimizer. + + Args: + base_optimizer: A Keras optimizer instance or a string identifier + (e.g., 'adam', 'sgd'). + world_size: The total number of devices/processes in the distributed + setup. + distributed_backend: The distributed communication backend to use. + Defaults to "auto". + tensor_parallel_config: An optional configuration object that defines + rules for tensor parallelism. Defaults to `None`. + + Example: + + ```python + import keras + + # Assume model variables and gradients from 4 shards exist. + # The structure is: list[list[tuple[gradient, variable]]] + trainable_vars = [keras.Variable(1.0), keras.Variable(2.0)] + sharded_grads_and_vars = [ + [(keras.ops.ones_like(v), v) for v in trainable_vars] + for _ in range(4) # 4 shards + ] + + # 1. Wrap a standard Keras optimizer. + base_optimizer = keras.optimizers.Adam() + optimizer = TensorParallelOptimizer(base_optimizer, world_size=4) + optimizer.build(trainable_vars) + + # 2. Apply the sharded gradients. + # The optimizer will handle synchronization (e.g., all-reduce) internally. 
+ optimizer.apply_gradients(sharded_grads_and_vars) + ``` + """ + + def __init__( + self, + base_optimizer: optimizers.Optimizer, + world_size: int, + distributed_backend: str = "auto", + tensor_parallel_config=None, + ): + if isinstance(base_optimizer, str): + base_optimizer_instance = optimizers.get(base_optimizer) + else: + base_optimizer_instance = base_optimizer + + learning_rate = base_optimizer_instance.learning_rate + if callable(learning_rate): + lr_value = float(ops.convert_to_numpy(learning_rate(0))) + else: + lr_value = float(ops.convert_to_numpy(learning_rate)) + + super().__init__( + learning_rate=lr_value, + name=f"TensorParallel_{base_optimizer_instance.name}", + ) + + self.base_optimizer = base_optimizer_instance + self.world_size = world_size + self.distributed_backend = distributed_backend + self.coordinated_optimizer = CoordinatedOptimizer( + self.base_optimizer, + world_size, + distributed_backend=distributed_backend, + tensor_parallel_config=tensor_parallel_config, + ) + + def apply_gradients(self, grads_and_vars: list, **kwargs): + """Applies gradients to the model variables. + + If `grads_and_vars` is a list of lists, it's assumed to be from + sharded models, and the `CoordinatedOptimizer` is used. Otherwise, + it calls the `base_optimizer`'s `apply_gradients` directly. + + Args: + grads_and_vars: A list of (gradient, variable) tuples, or a list + of such lists if running in a sharded context. + **kwargs: Additional arguments. `shard_models` can be passed to + provide the list of model shards. + """ + is_sharded_grads = ( + isinstance(grads_and_vars, list) + and grads_and_vars + and isinstance(grads_and_vars[0], list) + ) + if is_sharded_grads: + shard_models = kwargs.get("shard_models", []) + self.coordinated_optimizer.apply_gradients( + grads_and_vars, shard_models + ) + else: + self.base_optimizer.apply_gradients(grads_and_vars) + + def get_config(self) -> dict[str, Any]: + from keras.src import saving + + config = super().get_config() + config.pop("learning_rate", None) + config.pop("name", None) + + config.update( + { + "base_optimizer": saving.serialize_keras_object( + self.base_optimizer + ), + "world_size": self.world_size, + "distributed_backend": self.distributed_backend, + } + ) + return config + + def update_step(self, gradient, variable, *args, **kwargs): + if hasattr(self.base_optimizer, "update_step"): + try: + return self.base_optimizer.update_step( + gradient, variable, *args, **kwargs + ) + except TypeError: + return self.base_optimizer.update_step(gradient, variable) + try: + return super().update_step(gradient, variable, *args, **kwargs) + except TypeError: + return super().update_step(gradient, variable) + + @classmethod + def from_config(cls, config: dict[str, Any]) -> "TensorParallelOptimizer": + from keras.src import saving + + base_optimizer_config = config.pop("base_optimizer") + base_optimizer = saving.deserialize_keras_object(base_optimizer_config) + + init_kwargs = { + "world_size": config.get("world_size"), + "distributed_backend": config.get("distributed_backend", "auto"), + "tensor_parallel_config": config.get("tensor_parallel_config"), + } + + return cls(base_optimizer=base_optimizer, **init_kwargs) + + def build(self, variables: list): + """Builds the optimizer and initializes sharded states. + + This method is called the first time the optimizer is used. It builds + the base optimizer and then triggers the `CoordinatedOptimizer` to + initialize its sharded states. + + Args: + variables: A list of model variables to be optimized. 
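A small usage sketch of the build-time behavior described here, assuming the JAX backend used by the tests in this PR: `build()` runs one zero-gradient warm-up step so the base optimizer materializes its slot variables before they are sharded, and the `iterations` property subtracts that step back out.

```python
import numpy as np
import keras
from keras.src.distribution.tensor_parallel.coordinated_optimizer import (
    TensorParallelOptimizer,
)

# Illustrative variables; names "w" and "b" are arbitrary.
variables = [
    keras.Variable(np.ones((4, 2), dtype="float32"), name="w"),
    keras.Variable(np.zeros((2,), dtype="float32"), name="b"),
]
opt = TensorParallelOptimizer(keras.optimizers.Adam(), world_size=2)
opt.build(variables)  # applies one zero-gradient warm-up step internally
print(int(keras.ops.convert_to_numpy(opt.iterations)))  # expected: 0
```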
+ """ + if self.built: + return + + self.base_optimizer.build(variables) + if variables: + zero_grads = [ops.zeros_like(v) for v in variables] + self.base_optimizer.apply_gradients(zip(zero_grads, variables)) + + self.coordinated_optimizer.enable_optimizer_state_sharding(variables) + super().build(variables) + + def get_weights(self) -> list[np.ndarray]: + """Returns the weights of the base optimizer.""" + return self.coordinated_optimizer.get_weights() + + def set_weights(self, weights: list[np.ndarray]): + """Sets the weights of the base optimizer.""" + self.coordinated_optimizer.set_weights(weights) + + @property + def variables(self) -> list: + """Returns the list of variables from the base optimizer.""" + return self.base_optimizer.variables + + @property + def learning_rate(self) -> Any: + """Provides access to the learning rate of the base optimizer.""" + return self.base_optimizer.learning_rate + + @learning_rate.setter + def learning_rate(self, value): + self.base_optimizer.learning_rate = value + + @property + def iterations(self): + """ + Returns the training iteration count, compensating for the initial + dummy step in the build method. + """ + if self.base_optimizer.iterations is None: + return None + return self.base_optimizer.iterations - 1 diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py new file mode 100644 index 000000000000..39cce46de72c --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -0,0 +1,180 @@ +import numpy as np +import pytest + +import keras +from keras import ops +from keras.src import optimizers +from keras.src import testing + +if keras.backend.backend() == "jax": + from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + CoordinatedOptimizer, + ) + from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + TensorParallelOptimizer, + ) + + +@pytest.mark.skipif( + keras.backend.backend() != "jax", + reason="This test is JAX-specific.", +) +class CoordinatedOptimizerTest(testing.TestCase): + def _get_simple_model(self): + """Creates a simple, uncompiled Keras model.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(20, name="dense_1")(inputs) + outputs = keras.layers.Dense(5, name="dense_2")(x) + return keras.Model(inputs, outputs) + + def _get_mock_gradients_and_vars(self, model, world_size): + """Generates mock gradients and variables for N shards.""" + model.build(input_shape=(None, 10)) + variables = model.trainable_variables + grads_and_vars_per_shard = [] + for i in range(world_size): + multiplier = float(i + 1) + gradients = [ + ops.convert_to_tensor( + np.ones_like(v.numpy()) * multiplier, dtype="float32" + ) + for v in variables + ] + grads_and_vars_per_shard.append(list(zip(gradients, variables))) + return grads_and_vars_per_shard + + def test_initialization(self): + """Tests that the optimizer initializes with the correct defaults.""" + base_optimizer = optimizers.Adam() + coord = CoordinatedOptimizer(base_optimizer, world_size=4) + self.assertEqual(coord.base_optimizer, base_optimizer) + self.assertTrue(coord.shard_optimizer_states) + self.assertEqual(coord.sharded_states, {}) + + def test_apply_gradients_with_replicated_states(self): + """Tests that replicated gradients are averaged and applied once.""" + + class AdamWithCallCounter(optimizers.Adam): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + 
self.apply_gradients_call_count = 0 + self.received_grads = [] + + def apply_gradients(self, grads_and_vars, *args, **kwargs): + self.apply_gradients_call_count += 1 + self.received_grads = [g for g, v in grads_and_vars] + super().apply_gradients(grads_and_vars, *args, **kwargs) + + world_size = 4 + model = self._get_simple_model() + optimizer = AdamWithCallCounter() + model.build((None, 10)) + mock_grads = self._get_mock_gradients_and_vars(model, world_size) + + coord = CoordinatedOptimizer( + optimizer, + world_size, + shard_optimizer_states=False, + ) + coord.apply_gradients(mock_grads, []) + + self.assertEqual(optimizer.apply_gradients_call_count, 1) + self.assertAllClose( + optimizer.received_grads[0], + np.ones_like(optimizer.received_grads[0]) * 2.5, + ) + + def test_init_from_string(self): + optimizer = TensorParallelOptimizer("adam", world_size=4) + self.assertIsInstance(optimizer.base_optimizer, optimizers.Adam) + + def test_apply_gradients_delegation(self): + """Tests that apply_gradients correctly delegates.""" + world_size = 4 + base_opt = optimizers.Adam() + optimizer = TensorParallelOptimizer(base_opt, world_size) + model = self._get_simple_model() + mock_grads = self._get_mock_gradients_and_vars(model, world_size) + + coord_apply_tracker = {"called": False} + + def coord_apply_mock(*args, **kwargs): + coord_apply_tracker["called"] = True + + optimizer.coordinated_optimizer.apply_gradients = coord_apply_mock + + base_apply_tracker = {"called": False} + + def base_apply_mock(*args, **kwargs): + base_apply_tracker["called"] = True + + optimizer.base_optimizer.apply_gradients = base_apply_mock + + optimizer.apply_gradients(mock_grads, shard_models=[]) + self.assertTrue(coord_apply_tracker["called"]) + self.assertFalse(base_apply_tracker["called"]) + + coord_apply_tracker["called"] = False + unsharded_grads = mock_grads[0] + optimizer.apply_gradients(unsharded_grads) + self.assertTrue(base_apply_tracker["called"]) + self.assertFalse(coord_apply_tracker["called"]) + + def test_build_and_state_sharding(self): + """Tests that the build method correctly initializes sharded states.""" + optimizer = TensorParallelOptimizer(optimizers.Adam(), world_size=4) + model = self._get_simple_model() + model.build(input_shape=(None, 10)) + + self.assertEqual(optimizer.coordinated_optimizer.sharded_states, {}) + optimizer.build(model.trainable_variables) + self.assertTrue(optimizer.built) + + sharded_states = optimizer.coordinated_optimizer.sharded_states + self.assertIn("momentum", sharded_states) + self.assertIn("velocity", sharded_states) + self.assertIn("iterations", sharded_states) + + dense_1_kernel_path = model.get_layer("dense_1").kernel.path + self.assertIn(dense_1_kernel_path, sharded_states["momentum"]) + self.assertEqual( + len(sharded_states["momentum"][dense_1_kernel_path]), 4 + ) + + def test_serialization(self): + world_size = 4 + base_opt = optimizers.Adam(learning_rate=0.1) + optimizer = TensorParallelOptimizer( + base_opt, world_size, distributed_backend=None + ) + + config = optimizer.get_config() + recreated = TensorParallelOptimizer.from_config(config) + + self.assertEqual(recreated.world_size, world_size) + self.assertIsInstance(recreated.base_optimizer, optimizers.Adam) + self.assertIsNone(recreated.distributed_backend) + self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) + + def test_sharding_with_prefixed_variable_names(self): + """Tests that state is correctly mapped with prefixed variable names.""" + inputs = keras.Input(shape=(10,)) + x = 
keras.layers.Dense(4, name="dense")(inputs) + outputs = keras.layers.Dense(2, name="dense_output")(x) + model = keras.Model(inputs, outputs) + model.build(input_shape=(None, 10)) + + optimizer = TensorParallelOptimizer(optimizers.Adam(), world_size=2) + optimizer.build(model.trainable_variables) + + state_to_param = ( + optimizer.coordinated_optimizer._state_variable_to_parameter + ) + self.assertGreater(len(state_to_param), 0) + + dense_output_kernel = model.get_layer("dense_output").kernel + optimizer_name = optimizer.base_optimizer.name + kernel_path = dense_output_kernel.path.replace("/", "_") + momentum_path = f"{optimizer_name}/{kernel_path}_momentum" + + self.assertIs(state_to_param[momentum_path], dense_output_kernel) diff --git a/keras/src/distribution/tensor_parallel/sharding_keras.py b/keras/src/distribution/tensor_parallel/sharding_keras.py new file mode 100644 index 000000000000..012234cb77f4 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/sharding_keras.py @@ -0,0 +1,82 @@ +from typing import Any +from typing import Collection +from typing import Dict +from typing import List +from typing import Sequence + +from keras.src.distribution.tensor_parallel.config import ConfigKeras + + +class ShardedKeras: + """ + Manages sharded parameters for Keras models. + """ + + def __init__( + self, + model_shards, + replicated_param_names: Collection[str], + tensor_parallel_config: ConfigKeras, + devices: Sequence[str], + output_device_index: int, + ): + """ + Initialize the sharding manager. + + Args: + model_shards: List of model shards + replicated_param_names: Names of parameters that are replicated + tensor_parallel_config: Tensor parallel configuration + devices: List of device IDs + output_device_index: Index of the output device + """ + self.model_shards = model_shards + self.replicated_param_names = set(replicated_param_names) + self.tensor_parallel_config = tensor_parallel_config + self.devices = devices + self.output_device_index = output_device_index + + def get_shard_parameters(self, shard_index: int) -> Dict[str, Any]: + """ + Get parameters for a specific shard. + + Args: + shard_index: Index of the shard + + Returns: + Dictionary of parameter names to values + """ + if shard_index >= len(self.model_shards): + raise ValueError(f"Shard index {shard_index} out of range") + + shard = self.model_shards[shard_index] + params = {} + + for weight in shard.weights: + param_name = weight.path.replace("/", ".") + params[param_name] = weight + + return params + + def get_all_parameters(self) -> List[Dict[str, Any]]: + """ + Get parameters from all shards. + + Returns: + List of parameter dictionaries for each shard + """ + return [ + self.get_shard_parameters(i) for i in range(len(self.model_shards)) + ] + + def apply_sharding(self): + """ + Apply sharding to the model parameters. + """ + pass + + def unshard_parameters(self): + """ + Unshard parameters back to their original form. + """ + pass
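A hypothetical usage sketch of `ShardedKeras` (all names below are illustrative): given already-sharded model replicas, the manager exposes each shard's weights keyed by their dotted parameter paths.

```python
import keras
from keras.src.distribution.tensor_parallel.sharding_keras import ShardedKeras

def make_shard():
    # Stand-in for a real sharded replica.
    inputs = keras.Input(shape=(8,))
    outputs = keras.layers.Dense(4, name="proj")(inputs)
    return keras.Model(inputs, outputs, name="shard")

manager = ShardedKeras(
    model_shards=[make_shard(), make_shard()],
    replicated_param_names=[],
    tensor_parallel_config=None,   # a ConfigKeras instance in real use
    devices=["cpu:0", "cpu:1"],
    output_device_index=0,
)
# Keys are the weights' `.path` strings with "/" replaced by ".".
print(sorted(manager.get_shard_parameters(0)))
```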