From dd3181eceb093827b088fd017932d71ff85d06bf Mon Sep 17 00:00:00 2001 From: Suhana Date: Wed, 1 Oct 2025 15:59:38 +0530 Subject: [PATCH 1/9] adding autoconfig and coordinated_optimizer --- .../tensor_parallel/autoconfig.py | 222 ++++++ .../tensor_parallel/autoconfig_test.py | 146 ++++ .../tensor_parallel/coordinated_optimizer.py | 646 ++++++++++++++++++ .../coordinated_optimizer_test.py | 154 +++++ 4 files changed, 1168 insertions(+) create mode 100644 keras/src/distribution/tensor_parallel/autoconfig.py create mode 100644 keras/src/distribution/tensor_parallel/autoconfig_test.py create mode 100644 keras/src/distribution/tensor_parallel/coordinated_optimizer.py create mode 100644 keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py new file mode 100644 index 000000000000..b1c0bb9d5e19 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -0,0 +1,222 @@ +from typing import Sequence + +from keras.src import layers +from keras.src.distribution.tensor_parallel.config import ConfigKeras +from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras +from keras.src.models import Model + + +def analyze_dense_layer_directly( + layer: layers.Dense, module: Model, prefix: str +) -> str: + """Analyzes a Dense layer to classify it for tensor parallelism sharding. + + This function inspects the layer's weight shapes to determine if it's an + "up-projection" (expanding feature dimensions), a "down-projection" + (contracting feature dimensions), or a generic layer. This classification + helps in deciding whether to apply column-wise or row-wise parallelism. + + Args: + layer: The keras.layers.Dense instance to analyze. + module: The parent Keras model containing the layer. + prefix: The hierarchical name prefix for the layer. + + Returns: + A string indicating the layer's classification: 'up_projection', + 'down_projection', or 'generic_dense'. + """ + if not isinstance(layer, layers.Dense): + return "generic_dense" + + input_dim = None + output_dim = None + + if hasattr(layer, "kernel"): + kernel_shape = layer.kernel.shape + if len(kernel_shape) == 2: + input_dim = kernel_shape[0] + output_dim = kernel_shape[1] + else: + if hasattr(layer, "units"): + output_dim = layer.units + + if ( + hasattr(layer, "input_shape") + and layer.input_shape + and len(layer.input_shape) > 1 + ): + input_dim = layer.input_shape[-1] + + if not input_dim or not output_dim: + return "generic_dense" + + expansion_threshold = 1.5 + is_expansion = output_dim > input_dim * expansion_threshold + is_contraction = input_dim > output_dim * expansion_threshold + + if is_expansion: + return "up_projection" + elif is_contraction: + return "down_projection" + else: + return "generic_dense" + + +def _traverse_and_shard_layer( + current_layer: layers.Layer, + module: Model, + world_size: int, + state_rules: dict, + output_rules: dict, + processed_layers: set, + prefix: str = "", +): + """Traverses a layer and its sub-layers to apply sharding rules. + + This function navigates through the model's layer hierarchy. For each + layer, it identifies its type and applies appropriate sharding logic, + populating the `state_rules` and `output_rules` dictionaries. + + Args: + current_layer: The current keras.Layer object to be processed. + module: The top-level Keras Model, used for context analysis. + world_size: The total number of devices for sharding. 
+ state_rules: The dictionary of state sharding rules to populate. + output_rules: The dictionary of output sharding rules to populate. + processed_layers: A set of layer IDs that have already been processed + to avoid redundant computation and infinite loops. + prefix: The hierarchical name prefix from parent layers, used to + construct the full unique name for the current layer. + """ + if id(current_layer) in processed_layers: + return + processed_layers.add(id(current_layer)) + + name = current_layer.name + full_name = f"{prefix}.{name}" if prefix else name + + if isinstance(current_layer, layers.Dense): + mlp_type = analyze_dense_layer_directly( + current_layer, module, full_name + ) + + if mlp_type == "down_projection": + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 0, "row" + ) + output_rules[f"^{full_name}$"] = {0: "allreduce"} + + else: + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) + if current_layer.use_bias: + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) + output_rules[f"^{full_name}$"] = {0: "no_comm"} + return + + elif isinstance(current_layer, layers.EinsumDense): + is_row_parallel = False + if "->" in current_layer.equation: + equation_parts = current_layer.equation.split("->") + if len(equation_parts) == 2: + input_spec = equation_parts[0].split(",")[0].strip() + output_spec = equation_parts[1].strip() + if ( + input_spec + and output_spec + and len(output_spec) < len(input_spec) + ): + is_row_parallel = True + + if is_row_parallel: + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 0, "row" + ) + output_rules[f"^{full_name}$"] = {0: "allreduce"} + else: + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) + output_rules[f"^{full_name}$"] = {0: "no_comm"} + return + + elif isinstance(current_layer, layers.Embedding): + weight_name = ( + "embeddings" if hasattr(current_layer, "embeddings") else None + ) + if weight_name: + state_rules[f"^{full_name}\.{weight_name}$"] = SplitKeras( + world_size, 1, "column" + ) + output_rules[f"^{full_name}$"] = {0: "no_comm"} + return + + elif isinstance( + current_layer, + ( + layers.LayerNormalization, + layers.BatchNormalization, + layers.GroupNormalization, + ), + ): + return + else: + if hasattr(current_layer, "layers"): + for sub_layer in current_layer.layers: + _traverse_and_shard_layer( + sub_layer, + module, + world_size, + state_rules, + output_rules, + processed_layers, + full_name, + ) + + +def get_default_config_keras( + module: Model, device_ids: Sequence[str] +) -> ConfigKeras: + """Generates a smart, recursive sharding configuration for a Keras model. + + This function traverses the layers of a given Keras model and applies a + set of heuristics to automatically determine how each layer's weights + and outputs should be sharded for tensor parallelism. It uses a helper + function to perform the recursive traversal. + + Args: + module: The Keras Model to generate a sharding configuration for. + device_ids: A sequence of device identifiers, used to determine the + world size (number of devices) for sharding. + + Returns: + A ConfigKeras object containing the generated 'state_rules' (for model + parameters) and 'output_rules' (for layer outputs). 
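    Example (an illustrative sketch, not a guaranteed output; the layer names
    and the two-entry device list below are hypothetical):

        inputs = Input(shape=(64,))
        x = layers.Dense(256, name="up")(inputs)    # 64 -> 256, expansion
        outputs = layers.Dense(64, name="down")(x)  # 256 -> 64, contraction
        config = get_default_config_keras(
            Model(inputs, outputs), ["gpu:0", "gpu:1"]
        )
        # config.state_rules includes "^up.kernel$" -> SplitKeras(2, 1, "column")
        # and "^down.kernel$" -> SplitKeras(2, 0, "row"); config.output_rules
        # maps "^up$" -> {0: "no_comm"} and "^down$" -> {0: "allreduce"}.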
+ """ + world_size = len(device_ids) + state_rules = {} + output_rules = {} + processed_layers = set() + + for layer in module.layers: + _traverse_and_shard_layer( + current_layer=layer, + module=module, + world_size=world_size, + state_rules=state_rules, + output_rules=output_rules, + processed_layers=processed_layers, + prefix="", + ) + + return ConfigKeras(state_rules=state_rules, output_rules=output_rules) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py new file mode 100644 index 000000000000..9b955f00525b --- /dev/null +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -0,0 +1,146 @@ +import os + +if "WORLD_SIZE" not in os.environ: + os.environ["WORLD_SIZE"] = "4" + +from keras import Input +from keras import Model +from keras import layers +from keras.src import testing +from keras.src.distribution.tensor_parallel.autoconfig import ( + analyze_dense_layer_directly, +) +from keras.src.distribution.tensor_parallel.autoconfig import ( + get_default_config_keras, +) +from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras + + +class TestAutoConfigKeras(testing.TestCase): + def setUp(self): + """Set up the test case and common variables.""" + super().setUp() + self.world_size = int(os.environ["WORLD_SIZE"]) + self.device_ids = [f"device:{i}" for i in range(self.world_size)] + + def _assert_split_keras_equal(self, rule1, rule2): + """ + Helper to compare two SplitKeras objects by their attributes. + MODIFIED: Use vars() for robust comparison without knowing attr names. + """ + self.assertIsInstance(rule1, SplitKeras) + self.assertIsInstance(rule2, SplitKeras) + self.assertDictEqual(vars(rule1), vars(rule2)) + + def _assert_rules_equal(self, actual_rules, expected_rules): + """Helper to compare two dictionaries of sharding rules.""" + self.assertSetEqual( + set(actual_rules.keys()), set(expected_rules.keys()) + ) + for key in expected_rules: + actual_val = actual_rules[key] + expected_val = expected_rules[key] + if isinstance(expected_val, SplitKeras): + self._assert_split_keras_equal(actual_val, expected_val) + else: + self.assertEqual(actual_val, expected_val) + + def test_analyze_dense_layer(self): + """Tests the direct analysis and classification of Dense layers.""" + up_proj_layer = layers.Dense(32) + up_proj_layer.build(input_shape=(None, 16)) + self.assertEqual( + analyze_dense_layer_directly(up_proj_layer, None, ""), + "up_projection", + ) + + down_proj_layer = layers.Dense(16) + down_proj_layer.build(input_shape=(None, 32)) + self.assertEqual( + analyze_dense_layer_directly(down_proj_layer, None, ""), + "down_projection", + ) + + def test_simple_mlp_sharding(self): + """Tests a simple MLP with up and down projection layers.""" + inputs = Input(shape=(64,)) + x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) + outputs = layers.Dense( + 64, name="down_projection_layer", use_bias=False + )(x) + model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^up_projection_layer.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^up_projection_layer.bias$": SplitKeras( + self.world_size, 0, "column" + ), + r"^down_projection_layer.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + } + expected_output_rules = { + r"^up_projection_layer$": {0: "no_comm"}, + r"^down_projection_layer$": {0: "allreduce"}, + 
} + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_embedding_sharding(self): + """Tests an Embedding layer.""" + inputs = Input(shape=(10,), dtype="int32") + outputs = layers.Embedding( + input_dim=1000, output_dim=128, name="token_embedding" + )(inputs) + model = Model(inputs=inputs, outputs=outputs, name="embed_model") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^token_embedding\.embeddings$": SplitKeras( + self.world_size, 1, "column" + ) + } + expected_output_rules = {r"^token_embedding$": {0: "no_comm"}} + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_nested_model_sharding(self): + """Tests that the traversal logic correctly handles nested models.""" + inner_inputs = Input(shape=(32,)) + inner_outputs = layers.Dense(128, name="inner_dense")(inner_inputs) + inner_model = Model( + inputs=inner_inputs, outputs=inner_outputs, name="inner_block" + ) + + outer_inputs = Input(shape=(32,)) + x = inner_model(outer_inputs) + outer_outputs = layers.Dense(32, name="outer_dense")(x) + outer_model = Model( + inputs=outer_inputs, outputs=outer_outputs, name="outer_model" + ) + + config = get_default_config_keras(outer_model, self.device_ids) + + expected_state_rules = { + r"^inner_block.inner_dense.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^inner_block.inner_dense.bias$": SplitKeras( + self.world_size, 0, "column" + ), + r"^outer_dense.kernel$": SplitKeras(self.world_size, 0, "row"), + } + expected_output_rules = { + r"^inner_block.inner_dense$": {0: "no_comm"}, + r"^outer_dense$": {0: "allreduce"}, + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py new file mode 100644 index 000000000000..77e5c13629b6 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -0,0 +1,646 @@ +import re +from typing import Any +from typing import Dict +from typing import List + +import numpy as np + +import keras +from keras.src import ops +from keras.src import optimizers +from keras.src.backend.distributed import backend_resolver + + +class CoordinatedOptimizer: + """Manages an optimizer's state for distributed training. + + This class is an internal coordinator that handles the complexities of + sharding optimizer states across multiple devices (shards) and + synchronizing gradients according to tensor parallelism rules. It is not + intended to be used directly by the end-user but is a core component of + the `TensorParallelOptimizer`. + + Args: + base_optimizer: The Keras optimizer instance + (e.g., `keras.optimizers.Adam`) whose state will be managed. + world_size: The total number of devices/processes in the distributed + setup. + distributed_backend: The distributed communication backend to use. + Defaults to "auto". + rank: The rank of the current process. Defaults to 0. + shard_optimizer_states: If `True`, the optimizer's state variables + (e.g., momentum, velocity) will be partitioned across `world_size` + devices. Defaults to `True`. 
+ tensor_parallel_config: An optional configuration object that defines + rules for tensor parallelism, such as which gradients to + all-reduce. Defaults to `None`. + """ + + def __init__( + self, + base_optimizer: optimizers.Optimizer, + world_size: int, + distributed_backend: str = "auto", + rank: int = 0, + shard_optimizer_states: bool = True, + tensor_parallel_config=None, + ): + self.base_optimizer = base_optimizer + self.world_size = world_size + self.rank = rank + self.shard_optimizer_states = shard_optimizer_states + self.tensor_parallel_config = tensor_parallel_config + self.sharded_states = {} + self._state_variable_to_parameter = {} + self.distributed_backend = ( + backend_resolver.get_distributed_backend(distributed_backend) + if distributed_backend is not None + else None + ) + self._variables = None # Will be set when optimizer is built + + # In class CoordinatedOptimizer: + +# In class CoordinatedOptimizer: + +# In class CoordinatedOptimizer: + +# In class CoordinatedOptimizer: + +# In class CoordinatedOptimizer: +# In class CoordinatedOptimizer: + + def _get_optimizer_slot_names(self) -> set: + """ + Deduces the slot names ('m', 'v', etc.) by inspecting the variables + created by the base optimizer. This is the most robust method. + """ + slot_names = set() + # The optimizer's variables have paths like 'Adam/m/dense/kernel'. + # We can extract the second part as the slot name. + for var in self.base_optimizer.variables: + # Skip the iteration counter + if "iteration" in var.path.lower(): + continue + path_parts = var.path.split('/') + if len(path_parts) > 1: + slot_names.add(path_parts[1]) + return slot_names + +# In class CoordinatedOptimizer: + +# In class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + + def _initialize_sharded_states(self): + """ + Partitions the optimizer's state variables across shards by inspecting + the variables created by the base optimizer. This version correctly + parses variable paths like 'optimizer/param_name_slot_name'. 
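        For instance (a sketch of the parsing only; the concrete paths assume
        an Adam optimizer named "adam" tracking a `dense_1/kernel` parameter):

            parameter path : "dense_1/kernel" -> normalized "dense_1_kernel"
            state var path : "adam/dense_1_kernel_momentum"
            split on "/"   : ["adam", "dense_1_kernel_momentum"]
            suffix match   : starts with "dense_1_kernel" -> slot_name = "momentum"
            sharding dim   : taken from the matching tensor-parallel state rule,
                             defaulting to 0 when no rule matches.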
+ """ + if not self.shard_optimizer_states or not self.base_optimizer.built: + return + + self.sharded_states = {} + self._state_variable_to_parameter = {} + opt_name = self.base_optimizer.name + + normalized_params = [ + (p.path.replace('/', '_'), p) for p in self._variables + ] + + for state_var in self.base_optimizer.variables: + if state_var is self.base_optimizer.iterations: + continue + + path_parts = state_var.path.split('/') + if len(path_parts) != 2 or path_parts[0] != opt_name: + continue + + state_suffix = path_parts[1] + + found_param = None + slot_name = None + for norm_param_path, param in normalized_params: + if state_suffix.startswith(norm_param_path): + found_param = param + slot_suffix = state_suffix[len(norm_param_path):] + slot_name = slot_suffix.strip('_') + break + + # THE FIX IS HERE: Explicitly check for 'is not None' + if found_param is not None and slot_name is not None: + self._state_variable_to_parameter[state_var.path] = found_param + + sharding_dim = 0 + if self.tensor_parallel_config: + norm_param_name = found_param.path.replace("/", ".") + for (p, a) in self.tensor_parallel_config.state_rules.items(): + if re.search(p, norm_param_name) and hasattr(a, "dim"): + sharding_dim = a.dim + break + + partitioned_state = self._partition_state(state_var, dim=sharding_dim) + self.sharded_states.setdefault(slot_name, {})[found_param.path] = partitioned_state + + if self.base_optimizer.iterations is not None: + self.sharded_states["iterations"] = self._partition_state( + self.base_optimizer.iterations, dim=0 + ) + def _partition_state( + self, state_variable: any, dim: int + ) -> List[np.ndarray]: + """Splits a single state variable numpy array into chunks. + + If the variable cannot be split along the given dimension, it is + replicated across all shards. + + Args: + state_variable: The optimizer state variable. + dim: The dimension along which to partition the variable. + + Returns: + A list of NumPy arrays, where each array is a partition of the + original state variable for a specific shard. + """ + state_array = keras.ops.convert_to_numpy(state_variable) + if state_array.ndim > dim and state_array.shape[dim] >= self.world_size: + return np.array_split(state_array, self.world_size, axis=dim) + else: + return [np.copy(state_array) for _ in range(self.world_size)] + + def get_config(self) -> Dict[str, Any]: + return { + "base_optimizer": self.base_optimizer.get_config(), + "world_size": self.world_size, + "shard_optimizer_states": self.shard_optimizer_states, + } + + def apply_gradients( + self, gradients_and_vars: List[List[tuple]], shard_models: List + ): + """Coordinates gradient synchronization and application. + + This method first synchronizes gradients across all shards based on + tensor parallelism rules. Then, it applies the gradients using either + sharded optimizer states or replicated states. + + Args: + gradients_and_vars: A list of lists, where each inner list contains + (gradient, variable) tuples for a specific model shard. + shard_models: A list of the sharded model instances. + + Raises: + ValueError: If the number of gradient sets does not match the + world size. 
+ """ + if len(gradients_and_vars) != self.world_size: + error_msg = ( + f"Expected {self.world_size} gradient sets, " + f"got {len(gradients_and_vars)}" + ) + raise ValueError(error_msg) + + synchronized_gradients = self._synchronize_gradients(gradients_and_vars) + + if self.shard_optimizer_states and self.sharded_states: + self._apply_gradients_with_sharded_states( + synchronized_gradients, shard_models + ) + else: + self._apply_gradients_with_replicated_states( + synchronized_gradients, shard_models + ) + + def _apply_gradients_with_sharded_states( + self, synchronized_gradients: List[List[tuple]], shard_models: List + ): + """Applies gradients to each shard using its local optimizer state. + + For each shard, this method loads the corresponding partition of the + optimizer state into the base optimizer and then applies the shard's + gradients. + + Args: + synchronized_gradients: The gradients after synchronization. + shard_models: The list of sharded models. + """ + for shard_idx, shard_grads in enumerate(synchronized_gradients): + local_states = self._get_local_optimizer_states(shard_idx) + self._update_optimizer_internal_state( + self.base_optimizer, local_states + ) + self.base_optimizer.apply_gradients(shard_grads) + + def _apply_gradients_with_replicated_states( + self, synchronized_gradients: List[List[tuple]], shard_models: List + ): + """Averages gradients across all shards and applies them once. + + This method is used when optimizer state sharding is disabled. It + calculates the average of the gradients for each variable across all + shards and applies the averaged gradients using the single, replicated + optimizer state. + + Args: + synchronized_gradients: The gradients after synchronization. + shard_models: The list of sharded models. + """ + num_vars = len(synchronized_gradients[0]) + averaged_grads_and_vars = [] + + for i in range(num_vars): + variable = synchronized_gradients[0][i][1] + grads_for_var = [ + shard_grads[i][0] + for shard_grads in synchronized_gradients + if shard_grads[i][0] is not None + ] + + if not grads_for_var: + continue + + summed_grad = grads_for_var[0] + for grad in grads_for_var[1:]: + summed_grad += grad + averaged_grad = summed_grad / len(grads_for_var) + averaged_grads_and_vars.append((averaged_grad, variable)) + + if averaged_grads_and_vars: + self.base_optimizer.apply_gradients(averaged_grads_and_vars) + + def _get_local_optimizer_states(self, shard_idx: int) -> Dict[str, Any]: + """Constructs the state dictionary for a single shard. + + Args: + shard_idx: The index of the shard for which to retrieve the state. + + Returns: + A dictionary containing the optimizer state variables specific to + the given shard index. 
+ """ + local_states = {} + for state_name, state_value in self.sharded_states.items(): + if isinstance(state_value, dict): + local_states[state_name] = {} + for param_name, param_states in state_value.items(): + local_states[state_name][param_name] = param_states[ + shard_idx + ] + else: + local_states[state_name] = state_value[shard_idx] + return local_states + +# In coordinated_optimizer.py -> class CoordinatedOptimizer: + + def _update_optimizer_internal_state(self, local_states: dict): + """Assigns local sharded state values to the optimizer's variables.""" + if not self.base_optimizer.built: + return + + for var in self.base_optimizer.variables: + if var is self.base_optimizer.iterations: + if "iterations" in local_states: + var.assign(local_states["iterations"]) + continue + + # THE FIX IS HERE: Use the variable's path for the lookup. + param = self._state_variable_to_parameter.get(var.path, None) + + if param: + # This internal method is the most reliable way to get the + # slot name (e.g., "momentum") from the variable object. + slot_name = ( + self.base_optimizer._get_slot_name_from_variable(var) + ) + if ( + slot_name in local_states + and param.path in local_states[slot_name] + ): + local_param_state = local_states[slot_name][param.path] + if var.shape == local_param_state.shape: + var.assign(local_param_state) + + def _synchronize_gradients( + self, gradients_and_vars: List[List[tuple]] + ) -> List[List[tuple]]: + """Synchronizes gradients across shards based on tensor parallel rules. + + Specifically, it performs an all-reduce operation on gradients of + weights that are split along a "column" dimension in tensor parallelism. + Other gradients are passed through unchanged. + + Args: + gradients_and_vars: The list of (gradient, variable) lists from + all shards. + + Returns: + The list of (gradient, variable) lists after synchronization. + """ + if not self.tensor_parallel_config: + return gradients_and_vars + + rules = self.tensor_parallel_config.state_rules.items() + column_parallel_patterns = { + pattern + for pattern, action in rules + if hasattr(action, "sharding_type") + and action.sharding_type == "column" + } + + if not column_parallel_patterns: + return gradients_and_vars + + num_weights = len(gradients_and_vars[0]) + for i in range(num_weights): + variable = gradients_and_vars[0][i][1] + var_name = getattr(variable, "path", getattr(variable, "name", "")) + + if any( + re.search(pattern, var_name) + for pattern in column_parallel_patterns + ): + grads_to_reduce = [ + g_and_v[i][0] + for g_and_v in gradients_and_vars + if g_and_v[i][0] is not None + ] + if grads_to_reduce: + synced_grad = self._allreduce_gradients(grads_to_reduce)[0] + for shard_idx in range(self.world_size): + gradients_and_vars[shard_idx][i] = ( + synced_grad, + variable, + ) + return gradients_and_vars + + def _allreduce_gradients(self, gradients: List[Any]) -> List[Any]: + """Performs a mean all-reduce operation on a list of gradients. + + If a distributed backend is available, it uses it. Otherwise, it + falls back to a local mean calculation. + + Args: + gradients: A list of gradients (one from each shard) to be averaged. + + Returns: + A list where each element is the mean of the input gradients. 
+ """ + if not gradients: + return [] + + if ( + self.distributed_backend is not None + and self.distributed_backend.is_initialized + ): + numpy_grad = keras.ops.convert_to_numpy(gradients[0]) + synced_numpy = self.distributed_backend.allreduce( + numpy_grad, op="mean" + ) + synced_tensor = keras.ops.convert_to_tensor(synced_numpy) + return [synced_tensor for _ in range(self.world_size)] + + stacked_grads = keras.ops.stack( + [keras.ops.convert_to_tensor(g) for g in gradients], axis=0 + ) + mean_grad = keras.ops.mean(stacked_grads, axis=0) + return [mean_grad for _ in range(len(gradients))] + + def get_weights(self) -> List[np.ndarray]: + """Returns the weights of the base optimizer.""" + return self.base_optimizer.get_weights() + + def set_weights(self, weights: List[np.ndarray]): + """Sets the weights of the base optimizer.""" + self.base_optimizer.set_weights(weights) + + def enable_optimizer_state_sharding(self, variables: List): + """Enables and initializes optimizer state sharding. + + This method is called from `build()`, which is guarded from running + multiple times. We can assume this should always execute. + """ + # The check 'if not self.shard_optimizer_states:' was here and was + # incorrectly preventing this code from running. It has been removed. + self.shard_optimizer_states = True + self._variables = variables + self._initialize_sharded_states() + + def disable_optimizer_state_sharding(self): + """Disables sharding and clears any sharded states. + + This reverts the optimizer to using a single, replicated state. + """ + if self.shard_optimizer_states: + self.shard_optimizer_states = False + self.sharded_states = {} + + +class TensorParallelOptimizer(optimizers.Optimizer): + """A Keras Optimizer wrapper for tensor-parallel distributed training. + + This optimizer wraps a standard Keras optimizer (e.g., Adam, SGD) and + delegates the complex tasks of state management and gradient synchronization + to a `CoordinatedOptimizer` instance. It is designed to work with models + that have been sharded for tensor parallelism. + + When `apply_gradients` is called with a list of gradient lists (one for each + model shard), it uses the `CoordinatedOptimizer` to handle synchronization + and state sharding. Otherwise, it behaves like the base optimizer. + + Args: + base_optimizer: A Keras optimizer instance or a string identifier + (e.g., 'adam', 'sgd'). + world_size: The total number of devices/processes in the distributed + setup. + distributed_backend: The distributed communication backend to use. + Defaults to "auto". + tensor_parallel_config: An optional configuration object that defines + rules for tensor parallelism. Defaults to `None`. + + Example: + + ```python + import keras + + # Assume model variables and gradients from 4 shards exist. + # The structure is: List[List[Tuple[gradient, variable]]] + trainable_vars = [keras.Variable(1.0), keras.Variable(2.0)] + sharded_grads_and_vars = [ + [(keras.ops.ones_like(v), v) for v in trainable_vars] + for _ in range(4) # 4 shards + ] + + # 1. Wrap a standard Keras optimizer. + base_optimizer = keras.optimizers.Adam() + optimizer = TensorParallelOptimizer(base_optimizer, world_size=4) + optimizer.build(trainable_vars) + + # 2. Apply the sharded gradients. + # The optimizer will handle synchronization (e.g., all-reduce) internally. 
+ optimizer.apply_gradients(sharded_grads_and_vars) + ``` + """ + + def __init__( + self, + base_optimizer: optimizers.Optimizer, + world_size: int, + distributed_backend: str = "auto", + tensor_parallel_config=None, + ): + if isinstance(base_optimizer, str): + resolved_base_optimizer = optimizers.get(base_optimizer) + else: + resolved_base_optimizer = base_optimizer + + if isinstance( + resolved_base_optimizer.learning_rate, + keras.optimizers.schedules.LearningRateSchedule, + ): + lr_value = float( + ops.convert_to_numpy( + resolved_base_optimizer.learning_rate.initial_learning_rate + ) + ) + else: + lr_value = float( + ops.convert_to_numpy(resolved_base_optimizer.learning_rate) + ) + + super().__init__( + learning_rate=lr_value, + name=f"TensorParallel_{resolved_base_optimizer.name}", + ) + + self.base_optimizer = resolved_base_optimizer + self.world_size = world_size + self.distributed_backend = distributed_backend + self.coordinated_optimizer = CoordinatedOptimizer( + self.base_optimizer, + world_size, + distributed_backend=distributed_backend, + tensor_parallel_config=tensor_parallel_config, + ) + + def apply_gradients(self, grads_and_vars: List, **kwargs): + """Applies gradients to the model variables. + + If `grads_and_vars` is a list of lists, it's assumed to be from + sharded models, and the `CoordinatedOptimizer` is used. Otherwise, + it calls the `base_optimizer`'s `apply_gradients` directly. + + Args: + grads_and_vars: A list of (gradient, variable) tuples, or a list + of such lists if running in a sharded context. + **kwargs: Additional arguments. `shard_models` can be passed to + provide the list of model shards. + """ + if ( + isinstance(grads_and_vars, list) + and grads_and_vars + and isinstance(grads_and_vars[0], list) + ): + shard_models = kwargs.get("shard_models", []) + self.coordinated_optimizer.apply_gradients( + grads_and_vars, shard_models + ) + else: + self.base_optimizer.apply_gradients(grads_and_vars) + + def get_config(self) -> Dict[str, Any]: + from keras.src import saving + + config = super().get_config() + config.pop("learning_rate", None) + config.pop("name", None) + + config.update( + { + "base_optimizer": saving.serialize_keras_object( + self.base_optimizer + ), + "world_size": self.world_size, + "distributed_backend": self.distributed_backend, + } + ) + return config + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "TensorParallelOptimizer": + from keras.src import saving + + base_optimizer_config = config.pop("base_optimizer") + base_optimizer = saving.deserialize_keras_object(base_optimizer_config) + + init_kwargs = { + "world_size": config.get("world_size"), + "distributed_backend": config.get("distributed_backend", "auto"), + "tensor_parallel_config": config.get("tensor_parallel_config"), + } + + return cls(base_optimizer=base_optimizer, **init_kwargs) + + def build(self, variables: List): + """Builds the optimizer and initializes sharded states. + + This method is called the first time the optimizer is used. It builds + the base optimizer and then triggers the `CoordinatedOptimizer` to + initialize its sharded states. + + Args: + variables: A list of model variables to be optimized. + """ + if self.built: + return + + # First, build the base optimizer with the variables. + self.base_optimizer.build(variables) + print(f"Variables after build: {[v.path for v in self.base_optimizer.variables]}") + + # THE FINAL FIX: Force slot variable creation by applying zero gradients. 
+ # This is necessary because optimizers create slots lazily on the first + # call to apply_gradients. + if variables: # Only run if there are variables to optimize. + zero_grads = [ops.zeros_like(v) for v in variables] + self.base_optimizer.apply_gradients(zip(zero_grads, variables)) + + # The dry run increments the iteration counter, so we reset it. + if self.base_optimizer.iterations is not None: + self.base_optimizer.iterations.assign(0) + + # Now that all state variables (m, v, etc.) are guaranteed to exist, + # we can safely initialize sharding. + self.coordinated_optimizer.enable_optimizer_state_sharding(variables) + super().build(variables) + + def get_weights(self) -> List[np.ndarray]: + """Returns the weights of the base optimizer.""" + return self.coordinated_optimizer.get_weights() + + def set_weights(self, weights: List[np.ndarray]): + """Sets the weights of the base optimizer.""" + self.coordinated_optimizer.set_weights(weights) + + @property + def variables(self) -> List: + """Returns the list of variables from the base optimizer.""" + return self.base_optimizer.variables + + @property + def learning_rate(self) -> Any: + """Provides access to the learning rate of the base optimizer.""" + return self.base_optimizer.learning_rate diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py new file mode 100644 index 000000000000..59bfa8118b04 --- /dev/null +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -0,0 +1,154 @@ +import numpy as np +from coordinated_optimizer import CoordinatedOptimizer +from coordinated_optimizer import TensorParallelOptimizer + +import keras +from keras.src import optimizers +from keras.src import testing + + +class CoordinatedOptimizerTest(testing.TestCase): + def _get_simple_model(self): + """Creates a simple, uncompiled Keras model.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(20, name="dense_1")(inputs) + outputs = keras.layers.Dense(5, name="dense_2")(x) + return keras.Model(inputs, outputs) + + def _get_mock_gradients_and_vars(self, model, world_size): + """Generates mock gradients and variables for N shards.""" + model.build(input_shape=(None, 10)) + variables = model.trainable_variables + grads_and_vars_per_shard = [] + for i in range(world_size): + multiplier = float(i + 1) + gradients = [ + keras.ops.convert_to_tensor( + np.ones_like(v.numpy()) * multiplier, dtype="float32" + ) + for v in variables + ] + grads_and_vars_per_shard.append(list(zip(gradients, variables))) + return grads_and_vars_per_shard + + def test_initialization(self): + """Tests that the optimizer initializes with the correct defaults.""" + base_optimizer = optimizers.Adam() + coord = CoordinatedOptimizer( + base_optimizer, world_size=4, distributed_backend=None + ) + self.assertEqual(coord.base_optimizer, base_optimizer) + self.assertTrue(coord.shard_optimizer_states) + self.assertEqual(coord.sharded_states, {}) + + def test_apply_gradients_with_replicated_states(self): + """Tests that replicated gradients are averaged and applied once.""" + class AdamWithCallCounter(optimizers.Adam): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.apply_gradients_call_count = 0 + self.received_grads = [] + + def apply_gradients(self, grads_and_vars, *args, **kwargs): + self.apply_gradients_call_count += 1 + self.received_grads = [g for g, v in grads_and_vars] + + world_size = 4 + model = self._get_simple_model() + 
optimizer = AdamWithCallCounter() + model.build((None, 10)) + mock_grads = self._get_mock_gradients_and_vars(model, world_size) + + coord = CoordinatedOptimizer( + optimizer, + world_size, + shard_optimizer_states=False, + distributed_backend=None, + ) + coord.apply_gradients(mock_grads, []) + + self.assertEqual(optimizer.apply_gradients_call_count, 1) + self.assertAllClose( + optimizer.received_grads[0], + np.ones_like(optimizer.received_grads[0]) * 2.5, + ) + + def test_init_from_string(self): + optimizer = TensorParallelOptimizer( + "adam", world_size=4, distributed_backend=None + ) + self.assertIsInstance(optimizer.base_optimizer, optimizers.Adam) + + def test_apply_gradients_delegation(self): + """Tests that apply_gradients correctly delegates.""" + world_size = 4 + base_opt = optimizers.Adam() + optimizer = TensorParallelOptimizer( + base_opt, world_size, distributed_backend=None + ) + model = self._get_simple_model() + mock_grads = self._get_mock_gradients_and_vars(model, world_size) + + coord_apply_tracker = {"called": False} + optimizer.coordinated_optimizer.apply_gradients = ( + lambda *a, **kw: coord_apply_tracker.update({"called": True}) + ) + base_apply_tracker = {"called": False} + optimizer.base_optimizer.apply_gradients = ( + lambda *a, **kw: base_apply_tracker.update({"called": True}) + ) + + optimizer.apply_gradients(mock_grads, shard_models=[]) + self.assertTrue(coord_apply_tracker["called"]) + self.assertFalse(base_apply_tracker["called"]) + + coord_apply_tracker["called"] = False + unsharded_grads = mock_grads[0] + optimizer.apply_gradients(unsharded_grads) + self.assertTrue(base_apply_tracker["called"]) + self.assertFalse(coord_apply_tracker["called"]) + +# In coordinated_optimizer_test.py + +# In coordinated_optimizer_test.py + + def test_build_and_state_sharding(self): + """Tests that the build method correctly initializes sharded states.""" + optimizer = TensorParallelOptimizer( + optimizers.Adam(), world_size=4, distributed_backend=None + ) + model = self._get_simple_model() + + # Build the model so its trainable_variables list is populated. + model.build(input_shape=(None, 10)) + + self.assertEqual(optimizer.coordinated_optimizer.sharded_states, {}) + optimizer.build(model.trainable_variables) + self.assertTrue(optimizer.built) + + sharded_states = optimizer.coordinated_optimizer.sharded_states + + # THE FIX IS HERE: + # Keras Adam uses 'momentum' and 'velocity' as its slot names, not 'm' and 'v'. 
+ self.assertIn("momentum", sharded_states) + self.assertIn("velocity", sharded_states) + self.assertIn("iterations", sharded_states) + + dense_1_kernel_path = model.get_layer("dense_1").kernel.path + self.assertIn(dense_1_kernel_path, sharded_states["momentum"]) + self.assertEqual(len(sharded_states["momentum"][dense_1_kernel_path]), 4) + + def test_serialization(self): + world_size = 4 + base_opt = optimizers.Adam(learning_rate=0.1) + optimizer = TensorParallelOptimizer( + base_opt, world_size, distributed_backend=None + ) + + config = optimizer.get_config() + recreated = TensorParallelOptimizer.from_config(config) + + self.assertEqual(recreated.world_size, world_size) + self.assertIsInstance(recreated.base_optimizer, optimizers.Adam) + self.assertIsNone(recreated.coordinated_optimizer.distributed_backend) + self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) \ No newline at end of file From bcae2f69ee2d2ee58ce82cebf88d66b8fe4fee89 Mon Sep 17 00:00:00 2001 From: Suhana Date: Wed, 1 Oct 2025 21:35:28 +0530 Subject: [PATCH 2/9] Reformatting --- .../tensor_parallel/coordinated_optimizer.py | 51 +------------------ .../coordinated_optimizer_test.py | 8 +-- 2 files changed, 3 insertions(+), 56 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 77e5c13629b6..73ea557995e9 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -57,18 +57,7 @@ def __init__( if distributed_backend is not None else None ) - self._variables = None # Will be set when optimizer is built - - # In class CoordinatedOptimizer: - -# In class CoordinatedOptimizer: - -# In class CoordinatedOptimizer: - -# In class CoordinatedOptimizer: - -# In class CoordinatedOptimizer: -# In class CoordinatedOptimizer: + self._variables = None def _get_optimizer_slot_names(self) -> set: """ @@ -76,10 +65,7 @@ def _get_optimizer_slot_names(self) -> set: created by the base optimizer. This is the most robust method. """ slot_names = set() - # The optimizer's variables have paths like 'Adam/m/dense/kernel'. - # We can extract the second part as the slot name. 
for var in self.base_optimizer.variables: - # Skip the iteration counter if "iteration" in var.path.lower(): continue path_parts = var.path.split('/') @@ -87,24 +73,6 @@ def _get_optimizer_slot_names(self) -> set: slot_names.add(path_parts[1]) return slot_names -# In class CoordinatedOptimizer: - -# In class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - def _initialize_sharded_states(self): """ Partitions the optimizer's state variables across shards by inspecting @@ -141,7 +109,6 @@ def _initialize_sharded_states(self): slot_name = slot_suffix.strip('_') break - # THE FIX IS HERE: Explicitly check for 'is not None' if found_param is not None and slot_name is not None: self._state_variable_to_parameter[state_var.path] = found_param @@ -304,8 +271,6 @@ def _get_local_optimizer_states(self, shard_idx: int) -> Dict[str, Any]: local_states[state_name] = state_value[shard_idx] return local_states -# In coordinated_optimizer.py -> class CoordinatedOptimizer: - def _update_optimizer_internal_state(self, local_states: dict): """Assigns local sharded state values to the optimizer's variables.""" if not self.base_optimizer.built: @@ -317,12 +282,9 @@ def _update_optimizer_internal_state(self, local_states: dict): var.assign(local_states["iterations"]) continue - # THE FIX IS HERE: Use the variable's path for the lookup. param = self._state_variable_to_parameter.get(var.path, None) if param: - # This internal method is the most reliable way to get the - # slot name (e.g., "momentum") from the variable object. slot_name = ( self.base_optimizer._get_slot_name_from_variable(var) ) @@ -433,8 +395,6 @@ def enable_optimizer_state_sharding(self, variables: List): This method is called from `build()`, which is guarded from running multiple times. We can assume this should always execute. """ - # The check 'if not self.shard_optimizer_states:' was here and was - # incorrectly preventing this code from running. It has been removed. self.shard_optimizer_states = True self._variables = variables self._initialize_sharded_states() @@ -607,23 +567,16 @@ def build(self, variables: List): if self.built: return - # First, build the base optimizer with the variables. self.base_optimizer.build(variables) print(f"Variables after build: {[v.path for v in self.base_optimizer.variables]}") - # THE FINAL FIX: Force slot variable creation by applying zero gradients. - # This is necessary because optimizers create slots lazily on the first - # call to apply_gradients. - if variables: # Only run if there are variables to optimize. + if variables: zero_grads = [ops.zeros_like(v) for v in variables] self.base_optimizer.apply_gradients(zip(zero_grads, variables)) - # The dry run increments the iteration counter, so we reset it. if self.base_optimizer.iterations is not None: self.base_optimizer.iterations.assign(0) - # Now that all state variables (m, v, etc.) are guaranteed to exist, - # we can safely initialize sharding. 
self.coordinated_optimizer.enable_optimizer_state_sharding(variables) super().build(variables) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index 59bfa8118b04..ca69361fe383 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -108,9 +108,6 @@ def test_apply_gradients_delegation(self): self.assertTrue(base_apply_tracker["called"]) self.assertFalse(coord_apply_tracker["called"]) -# In coordinated_optimizer_test.py - -# In coordinated_optimizer_test.py def test_build_and_state_sharding(self): """Tests that the build method correctly initializes sharded states.""" @@ -119,7 +116,6 @@ def test_build_and_state_sharding(self): ) model = self._get_simple_model() - # Build the model so its trainable_variables list is populated. model.build(input_shape=(None, 10)) self.assertEqual(optimizer.coordinated_optimizer.sharded_states, {}) @@ -127,9 +123,7 @@ def test_build_and_state_sharding(self): self.assertTrue(optimizer.built) sharded_states = optimizer.coordinated_optimizer.sharded_states - - # THE FIX IS HERE: - # Keras Adam uses 'momentum' and 'velocity' as its slot names, not 'm' and 'v'. + self.assertIn("momentum", sharded_states) self.assertIn("velocity", sharded_states) self.assertIn("iterations", sharded_states) From 439643b33fd377041ac299a3ddb76baf15ca52b6 Mon Sep 17 00:00:00 2001 From: Suhana Date: Thu, 2 Oct 2025 15:45:00 +0530 Subject: [PATCH 3/9] Added sharding keras --- .../tensor_parallel/autoconfig.py | 22 ++- .../tensor_parallel/autoconfig_test.py | 15 +- .../tensor_parallel/coordinated_optimizer.py | 152 ++++++++---------- .../coordinated_optimizer_test.py | 47 ++++-- .../tensor_parallel/sharding_keras.py | 85 ++++++++++ 5 files changed, 207 insertions(+), 114 deletions(-) create mode 100644 keras/src/distribution/tensor_parallel/sharding_keras.py diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index b1c0bb9d5e19..cf5966eb4670 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -1,14 +1,12 @@ from typing import Sequence -from keras.src import layers from keras.src.distribution.tensor_parallel.config import ConfigKeras from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras -from keras.src.models import Model -def analyze_dense_layer_directly( - layer: layers.Dense, module: Model, prefix: str -) -> str: +def analyze_dense_layer_directly(layer, module, prefix: str) -> str: + from keras.src import layers + """Analyzes a Dense layer to classify it for tensor parallelism sharding. This function inspects the layer's weight shapes to determine if it's an @@ -63,14 +61,16 @@ def analyze_dense_layer_directly( def _traverse_and_shard_layer( - current_layer: layers.Layer, - module: Model, + current_layer, + module, world_size: int, state_rules: dict, output_rules: dict, processed_layers: set, prefix: str = "", ): + from keras.src import layers + """Traverses a layer and its sub-layers to apply sharding rules. This function navigates through the model's layer hierarchy. 
For each @@ -145,8 +145,8 @@ def _traverse_and_shard_layer( and current_layer.bias is not None ): state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, 0, "column" - ) + world_size, 0, "column" + ) output_rules[f"^{full_name}$"] = {0: "no_comm"} return @@ -184,9 +184,7 @@ def _traverse_and_shard_layer( ) -def get_default_config_keras( - module: Model, device_ids: Sequence[str] -) -> ConfigKeras: +def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: """Generates a smart, recursive sharding configuration for a Keras model. This function traverses the layers of a given Keras model and applies a diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index 9b955f00525b..ab9a1b4149c1 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -1,12 +1,12 @@ import os -if "WORLD_SIZE" not in os.environ: - os.environ["WORLD_SIZE"] = "4" +os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" from keras import Input from keras import Model from keras import layers from keras.src import testing +from keras.src.backend.distributed import backend_resolver from keras.src.distribution.tensor_parallel.autoconfig import ( analyze_dense_layer_directly, ) @@ -20,13 +20,18 @@ class TestAutoConfigKeras(testing.TestCase): def setUp(self): """Set up the test case and common variables.""" super().setUp() - self.world_size = int(os.environ["WORLD_SIZE"]) + backend = backend_resolver.get_distributed_backend() + device_info = backend.get_device_info() + self.world_size = device_info["device_count"] self.device_ids = [f"device:{i}" for i in range(self.world_size)] + self.assertGreater( + self.world_size, 1, "Distribution tests require more than 1 device." + ) + def _assert_split_keras_equal(self, rule1, rule2): """ Helper to compare two SplitKeras objects by their attributes. - MODIFIED: Use vars() for robust comparison without knowing attr names. 
""" self.assertIsInstance(rule1, SplitKeras) self.assertIsInstance(rule2, SplitKeras) @@ -143,4 +148,4 @@ def test_nested_model_sharding(self): } self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 73ea557995e9..726747676c0b 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -58,6 +58,7 @@ def __init__( else None ) self._variables = None + self._variable_to_slot_name = {} def _get_optimizer_slot_names(self) -> set: """ @@ -68,7 +69,7 @@ def _get_optimizer_slot_names(self) -> set: for var in self.base_optimizer.variables: if "iteration" in var.path.lower(): continue - path_parts = var.path.split('/') + path_parts = var.path.split("/") if len(path_parts) > 1: slot_names.add(path_parts[1]) return slot_names @@ -84,20 +85,21 @@ def _initialize_sharded_states(self): self.sharded_states = {} self._state_variable_to_parameter = {} + self._variable_to_slot_name = {} # Reset the map opt_name = self.base_optimizer.name normalized_params = [ - (p.path.replace('/', '_'), p) for p in self._variables + (p.path.replace("/", "_"), p) for p in self._variables ] for state_var in self.base_optimizer.variables: if state_var is self.base_optimizer.iterations: continue - path_parts = state_var.path.split('/') + path_parts = state_var.path.split("/") if len(path_parts) != 2 or path_parts[0] != opt_name: continue - + state_suffix = path_parts[1] found_param = None @@ -105,28 +107,35 @@ def _initialize_sharded_states(self): for norm_param_path, param in normalized_params: if state_suffix.startswith(norm_param_path): found_param = param - slot_suffix = state_suffix[len(norm_param_path):] - slot_name = slot_suffix.strip('_') + slot_suffix = state_suffix[len(norm_param_path) :] + slot_name = slot_suffix.strip("_") break if found_param is not None and slot_name is not None: self._state_variable_to_parameter[state_var.path] = found_param + # MODIFIED: Store the mapping from variable path to slot name + self._variable_to_slot_name[state_var.path] = slot_name sharding_dim = 0 if self.tensor_parallel_config: norm_param_name = found_param.path.replace("/", ".") - for (p, a) in self.tensor_parallel_config.state_rules.items(): + for p, a in self.tensor_parallel_config.state_rules.items(): if re.search(p, norm_param_name) and hasattr(a, "dim"): sharding_dim = a.dim break - - partitioned_state = self._partition_state(state_var, dim=sharding_dim) - self.sharded_states.setdefault(slot_name, {})[found_param.path] = partitioned_state + + partitioned_state = self._partition_state( + state_var, dim=sharding_dim + ) + self.sharded_states.setdefault(slot_name, {})[ + found_param.path + ] = partitioned_state if self.base_optimizer.iterations is not None: self.sharded_states["iterations"] = self._partition_state( self.base_optimizer.iterations, dim=0 ) + def _partition_state( self, state_variable: any, dim: int ) -> List[np.ndarray]: @@ -143,7 +152,7 @@ def _partition_state( A list of NumPy arrays, where each array is a partition of the original state variable for a specific shard. 
""" - state_array = keras.ops.convert_to_numpy(state_variable) + state_array = ops.convert_to_numpy(state_variable) if state_array.ndim > dim and state_array.shape[dim] >= self.world_size: return np.array_split(state_array, self.world_size, axis=dim) else: @@ -192,26 +201,6 @@ def apply_gradients( synchronized_gradients, shard_models ) - def _apply_gradients_with_sharded_states( - self, synchronized_gradients: List[List[tuple]], shard_models: List - ): - """Applies gradients to each shard using its local optimizer state. - - For each shard, this method loads the corresponding partition of the - optimizer state into the base optimizer and then applies the shard's - gradients. - - Args: - synchronized_gradients: The gradients after synchronization. - shard_models: The list of sharded models. - """ - for shard_idx, shard_grads in enumerate(synchronized_gradients): - local_states = self._get_local_optimizer_states(shard_idx) - self._update_optimizer_internal_state( - self.base_optimizer, local_states - ) - self.base_optimizer.apply_gradients(shard_grads) - def _apply_gradients_with_replicated_states( self, synchronized_gradients: List[List[tuple]], shard_models: List ): @@ -240,10 +229,12 @@ def _apply_gradients_with_replicated_states( if not grads_for_var: continue - summed_grad = grads_for_var[0] - for grad in grads_for_var[1:]: - summed_grad += grad - averaged_grad = summed_grad / len(grads_for_var) + if len(grads_for_var) > 1: + stacked_grads = ops.stack(grads_for_var, axis=0) + averaged_grad = ops.mean(stacked_grads, axis=0) + else: + averaged_grad = grads_for_var[0] + averaged_grads_and_vars.append((averaged_grad, variable)) if averaged_grads_and_vars: @@ -271,30 +262,29 @@ def _get_local_optimizer_states(self, shard_idx: int) -> Dict[str, Any]: local_states[state_name] = state_value[shard_idx] return local_states - def _update_optimizer_internal_state(self, local_states: dict): + def _update_optimizer_internal_state(self, optimizer, local_states: dict): """Assigns local sharded state values to the optimizer's variables.""" - if not self.base_optimizer.built: + if not optimizer.built: return - for var in self.base_optimizer.variables: - if var is self.base_optimizer.iterations: + for var in optimizer.variables: + if var is optimizer.iterations: if "iterations" in local_states: - var.assign(local_states["iterations"]) + ops.assign(var, local_states["iterations"]) continue param = self._state_variable_to_parameter.get(var.path, None) - - if param: - slot_name = ( - self.base_optimizer._get_slot_name_from_variable(var) - ) - if ( - slot_name in local_states - and param.path in local_states[slot_name] - ): - local_param_state = local_states[slot_name][param.path] - if var.shape == local_param_state.shape: - var.assign(local_param_state) + slot_name = self._variable_to_slot_name.get(var.path) + + if ( + param + and slot_name + and slot_name in local_states + and param.path in local_states[slot_name] + ): + local_param_state = local_states[slot_name][param.path] + if var.shape == local_param_state.shape: + ops.assign(var, local_param_state) def _synchronize_gradients( self, gradients_and_vars: List[List[tuple]] @@ -364,26 +354,25 @@ def _allreduce_gradients(self, gradients: List[Any]) -> List[Any]: if not gradients: return [] - if ( - self.distributed_backend is not None - and self.distributed_backend.is_initialized - ): - numpy_grad = keras.ops.convert_to_numpy(gradients[0]) - synced_numpy = self.distributed_backend.allreduce( + if self.distributed_backend is not None: + numpy_grad = 
ops.convert_to_numpy(gradients[0]) + synced_numpy = self.distributed_backend.all_reduce( numpy_grad, op="mean" ) - synced_tensor = keras.ops.convert_to_tensor(synced_numpy) + synced_tensor = ops.convert_to_tensor(synced_numpy) return [synced_tensor for _ in range(self.world_size)] stacked_grads = keras.ops.stack( - [keras.ops.convert_to_tensor(g) for g in gradients], axis=0 + [ops.convert_to_tensor(g) for g in gradients], axis=0 ) - mean_grad = keras.ops.mean(stacked_grads, axis=0) + mean_grad = ops.mean(stacked_grads, axis=0) return [mean_grad for _ in range(len(gradients))] def get_weights(self) -> List[np.ndarray]: """Returns the weights of the base optimizer.""" - return self.base_optimizer.get_weights() + return [ + ops.convert_to_numpy(var) for var in self.base_optimizer.variables + ] def set_weights(self, weights: List[np.ndarray]): """Sets the weights of the base optimizer.""" @@ -463,30 +452,22 @@ def __init__( tensor_parallel_config=None, ): if isinstance(base_optimizer, str): - resolved_base_optimizer = optimizers.get(base_optimizer) + base_optimizer_instance = optimizers.get(base_optimizer) else: - resolved_base_optimizer = base_optimizer + base_optimizer_instance = base_optimizer - if isinstance( - resolved_base_optimizer.learning_rate, - keras.optimizers.schedules.LearningRateSchedule, - ): - lr_value = float( - ops.convert_to_numpy( - resolved_base_optimizer.learning_rate.initial_learning_rate - ) - ) + learning_rate = base_optimizer_instance.learning_rate + if callable(learning_rate): + lr_value = float(ops.convert_to_numpy(learning_rate(0))) else: - lr_value = float( - ops.convert_to_numpy(resolved_base_optimizer.learning_rate) - ) + lr_value = float(ops.convert_to_numpy(learning_rate)) super().__init__( learning_rate=lr_value, - name=f"TensorParallel_{resolved_base_optimizer.name}", + name=f"TensorParallel_{base_optimizer_instance.name}", ) - self.base_optimizer = resolved_base_optimizer + self.base_optimizer = base_optimizer_instance self.world_size = world_size self.distributed_backend = distributed_backend self.coordinated_optimizer = CoordinatedOptimizer( @@ -568,15 +549,10 @@ def build(self, variables: List): return self.base_optimizer.build(variables) - print(f"Variables after build: {[v.path for v in self.base_optimizer.variables]}") - if variables: zero_grads = [ops.zeros_like(v) for v in variables] self.base_optimizer.apply_gradients(zip(zero_grads, variables)) - if self.base_optimizer.iterations is not None: - self.base_optimizer.iterations.assign(0) - self.coordinated_optimizer.enable_optimizer_state_sharding(variables) super().build(variables) @@ -597,3 +573,13 @@ def variables(self) -> List: def learning_rate(self) -> Any: """Provides access to the learning rate of the base optimizer.""" return self.base_optimizer.learning_rate + + @property + def iterations(self): + """ + Returns the training iteration count, compensating for the initial + dummy step in the build method. 
+ """ + if self.base_optimizer.iterations is None: + return None + return self.base_optimizer.iterations - 1 \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index ca69361fe383..46579d4147aa 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -1,12 +1,22 @@ import numpy as np -from coordinated_optimizer import CoordinatedOptimizer -from coordinated_optimizer import TensorParallelOptimizer +import pytest import keras +from keras import ops from keras.src import optimizers from keras.src import testing - - +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + CoordinatedOptimizer, +) +from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + TensorParallelOptimizer, +) + + +@pytest.mark.skipif( + keras.backend.backend() == "openvino", + reason="CoordinatedOptimizer is not yet supported on the OpenVINO backend.", +) class CoordinatedOptimizerTest(testing.TestCase): def _get_simple_model(self): """Creates a simple, uncompiled Keras model.""" @@ -23,7 +33,7 @@ def _get_mock_gradients_and_vars(self, model, world_size): for i in range(world_size): multiplier = float(i + 1) gradients = [ - keras.ops.convert_to_tensor( + ops.convert_to_tensor( np.ones_like(v.numpy()) * multiplier, dtype="float32" ) for v in variables @@ -43,6 +53,7 @@ def test_initialization(self): def test_apply_gradients_with_replicated_states(self): """Tests that replicated gradients are averaged and applied once.""" + class AdamWithCallCounter(optimizers.Adam): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -52,6 +63,8 @@ def __init__(self, *args, **kwargs): def apply_gradients(self, grads_and_vars, *args, **kwargs): self.apply_gradients_call_count += 1 self.received_grads = [g for g, v in grads_and_vars] + # Call the superclass method to ensure variables are updated + super().apply_gradients(grads_and_vars, *args, **kwargs) world_size = 4 model = self._get_simple_model() @@ -68,6 +81,7 @@ def apply_gradients(self, grads_and_vars, *args, **kwargs): coord.apply_gradients(mock_grads, []) self.assertEqual(optimizer.apply_gradients_call_count, 1) + # The average of multipliers 1, 2, 3, 4 is (1+2+3+4)/4 = 10/4 = 2.5 self.assertAllClose( optimizer.received_grads[0], np.ones_like(optimizer.received_grads[0]) * 2.5, @@ -90,13 +104,18 @@ def test_apply_gradients_delegation(self): mock_grads = self._get_mock_gradients_and_vars(model, world_size) coord_apply_tracker = {"called": False} - optimizer.coordinated_optimizer.apply_gradients = ( - lambda *a, **kw: coord_apply_tracker.update({"called": True}) - ) + + def coord_apply_mock(*args, **kwargs): + coord_apply_tracker["called"] = True + + optimizer.coordinated_optimizer.apply_gradients = coord_apply_mock + base_apply_tracker = {"called": False} - optimizer.base_optimizer.apply_gradients = ( - lambda *a, **kw: base_apply_tracker.update({"called": True}) - ) + + def base_apply_mock(*args, **kwargs): + base_apply_tracker["called"] = True + + optimizer.base_optimizer.apply_gradients = base_apply_mock optimizer.apply_gradients(mock_grads, shard_models=[]) self.assertTrue(coord_apply_tracker["called"]) @@ -108,7 +127,6 @@ def test_apply_gradients_delegation(self): self.assertTrue(base_apply_tracker["called"]) self.assertFalse(coord_apply_tracker["called"]) - def 
test_build_and_state_sharding(self): """Tests that the build method correctly initializes sharded states.""" optimizer = TensorParallelOptimizer( @@ -123,14 +141,15 @@ def test_build_and_state_sharding(self): self.assertTrue(optimizer.built) sharded_states = optimizer.coordinated_optimizer.sharded_states - self.assertIn("momentum", sharded_states) self.assertIn("velocity", sharded_states) self.assertIn("iterations", sharded_states) dense_1_kernel_path = model.get_layer("dense_1").kernel.path self.assertIn(dense_1_kernel_path, sharded_states["momentum"]) - self.assertEqual(len(sharded_states["momentum"][dense_1_kernel_path]), 4) + self.assertEqual( + len(sharded_states["momentum"][dense_1_kernel_path]), 4 + ) def test_serialization(self): world_size = 4 diff --git a/keras/src/distribution/tensor_parallel/sharding_keras.py b/keras/src/distribution/tensor_parallel/sharding_keras.py new file mode 100644 index 000000000000..d6d08a524f5b --- /dev/null +++ b/keras/src/distribution/tensor_parallel/sharding_keras.py @@ -0,0 +1,85 @@ +from typing import Any +from typing import Collection +from typing import Dict +from typing import List +from typing import Sequence + +from keras.src.distribution.tensor_parallel.config import ConfigKeras + + +class ShardedKeras: + """ + Manages sharded parameters for Keras models. + """ + + def __init__( + self, + model_shards, + replicated_param_names: Collection[str], + tensor_parallel_config: ConfigKeras, + devices: Sequence[str], + output_device_index: int, + ): + """ + Initialize the sharding manager. + + Args: + model_shards: List of model shards + replicated_param_names: Names of parameters that are replicated + tensor_parallel_config: Tensor parallel configuration + devices: List of device IDs + output_device_index: Index of the output device + """ + self.model_shards = model_shards + self.replicated_param_names = set(replicated_param_names) + self.tensor_parallel_config = tensor_parallel_config + self.devices = devices + self.output_device_index = output_device_index + + def get_shard_parameters(self, shard_index: int) -> Dict[str, Any]: + """ + Get parameters for a specific shard. + + Args: + shard_index: Index of the shard + + Returns: + Dictionary of parameter names to values + """ + if shard_index >= len(self.model_shards): + raise ValueError(f"Shard index {shard_index} out of range") + + shard = self.model_shards[shard_index] + params = {} + + for layer in shard.layers: + name = layer.name + if hasattr(layer, "weights") and layer.weights: + for i, weight in enumerate(layer.weights): + param_name = f"{name}.weight_{i}" + params[param_name] = weight + + return params + + def get_all_parameters(self) -> List[Dict[str, Any]]: + """ + Get parameters from all shards. + + Returns: + List of parameter dictionaries for each shard + """ + return [ + self.get_shard_parameters(i) for i in range(len(self.model_shards)) + ] + + def apply_sharding(self): + """ + Apply sharding to the model parameters. + """ + pass + + def unshard_parameters(self): + """ + Unshard parameters back to their original form. 
+ """ + pass \ No newline at end of file From b7862d9c3a592f636c435d2e7b2160349713d132 Mon Sep 17 00:00:00 2001 From: Suhana Date: Thu, 2 Oct 2025 15:46:52 +0530 Subject: [PATCH 4/9] Reformatting files --- keras/src/distribution/tensor_parallel/autoconfig.py | 2 +- keras/src/distribution/tensor_parallel/autoconfig_test.py | 2 +- keras/src/distribution/tensor_parallel/coordinated_optimizer.py | 2 +- .../distribution/tensor_parallel/coordinated_optimizer_test.py | 2 +- keras/src/distribution/tensor_parallel/sharding_keras.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index cf5966eb4670..6e90b10a0bc4 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -217,4 +217,4 @@ def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: prefix="", ) - return ConfigKeras(state_rules=state_rules, output_rules=output_rules) \ No newline at end of file + return ConfigKeras(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index ab9a1b4149c1..ee7519c607ef 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -148,4 +148,4 @@ def test_nested_model_sharding(self): } self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) \ No newline at end of file + self._assert_rules_equal(config.output_rules, expected_output_rules) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 726747676c0b..18661f860c6e 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -582,4 +582,4 @@ def iterations(self): """ if self.base_optimizer.iterations is None: return None - return self.base_optimizer.iterations - 1 \ No newline at end of file + return self.base_optimizer.iterations - 1 diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index 46579d4147aa..c4249d147d73 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -164,4 +164,4 @@ def test_serialization(self): self.assertEqual(recreated.world_size, world_size) self.assertIsInstance(recreated.base_optimizer, optimizers.Adam) self.assertIsNone(recreated.coordinated_optimizer.distributed_backend) - self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) \ No newline at end of file + self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) diff --git a/keras/src/distribution/tensor_parallel/sharding_keras.py b/keras/src/distribution/tensor_parallel/sharding_keras.py index d6d08a524f5b..ace810adb024 100644 --- a/keras/src/distribution/tensor_parallel/sharding_keras.py +++ b/keras/src/distribution/tensor_parallel/sharding_keras.py @@ -82,4 +82,4 @@ def unshard_parameters(self): """ Unshard parameters back to their original form. 
""" - pass \ No newline at end of file + pass From 3383dec7736e8d34bde2263d324d256c9344a070 Mon Sep 17 00:00:00 2001 From: Suhana Date: Fri, 3 Oct 2025 11:55:32 +0530 Subject: [PATCH 5/9] Reformatting according to changes in distributed_backend --- .../tensor_parallel/autoconfig.py | 19 ++- .../tensor_parallel/autoconfig_test.py | 33 +++-- .../tensor_parallel/coordinated_optimizer.py | 140 ++++++++++-------- .../coordinated_optimizer_test.py | 63 ++++---- .../tensor_parallel/sharding_keras.py | 9 +- 5 files changed, 146 insertions(+), 118 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index 6e90b10a0bc4..9fa6db430c35 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -206,15 +206,14 @@ def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: output_rules = {} processed_layers = set() - for layer in module.layers: - _traverse_and_shard_layer( - current_layer=layer, - module=module, - world_size=world_size, - state_rules=state_rules, - output_rules=output_rules, - processed_layers=processed_layers, - prefix="", - ) + _traverse_and_shard_layer( + current_layer=module, + module=module, + world_size=world_size, + state_rules=state_rules, + output_rules=output_rules, + processed_layers=processed_layers, + prefix="", + ) return ConfigKeras(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index ee7519c607ef..96467da847e0 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -6,7 +6,7 @@ from keras import Model from keras import layers from keras.src import testing -from keras.src.backend.distributed import backend_resolver +from keras.src.distribution import distributed_backend from keras.src.distribution.tensor_parallel.autoconfig import ( analyze_dense_layer_directly, ) @@ -20,8 +20,7 @@ class TestAutoConfigKeras(testing.TestCase): def setUp(self): """Set up the test case and common variables.""" super().setUp() - backend = backend_resolver.get_distributed_backend() - device_info = backend.get_device_info() + device_info = distributed_backend.get_device_info() self.world_size = device_info["device_count"] self.device_ids = [f"device:{i}" for i in range(self.world_size)] @@ -78,19 +77,19 @@ def test_simple_mlp_sharding(self): config = get_default_config_keras(model, self.device_ids) expected_state_rules = { - r"^up_projection_layer.kernel$": SplitKeras( + r"^simple_mlp.up_projection_layer.kernel$": SplitKeras( self.world_size, 1, "column" ), - r"^up_projection_layer.bias$": SplitKeras( + r"^simple_mlp.up_projection_layer.bias$": SplitKeras( self.world_size, 0, "column" ), - r"^down_projection_layer.kernel$": SplitKeras( + r"^simple_mlp.down_projection_layer.kernel$": SplitKeras( self.world_size, 0, "row" ), } expected_output_rules = { - r"^up_projection_layer$": {0: "no_comm"}, - r"^down_projection_layer$": {0: "allreduce"}, + r"^simple_mlp.up_projection_layer$": {0: "no_comm"}, + r"^simple_mlp.down_projection_layer$": {0: "allreduce"}, } self._assert_rules_equal(config.state_rules, expected_state_rules) @@ -107,11 +106,13 @@ def test_embedding_sharding(self): config = get_default_config_keras(model, self.device_ids) expected_state_rules = { - r"^token_embedding\.embeddings$": SplitKeras( + 
r"^embed_model.token_embedding\.embeddings$": SplitKeras( self.world_size, 1, "column" ) } - expected_output_rules = {r"^token_embedding$": {0: "no_comm"}} + expected_output_rules = { + r"^embed_model.token_embedding$": {0: "no_comm"} + } self._assert_rules_equal(config.state_rules, expected_state_rules) self._assert_rules_equal(config.output_rules, expected_output_rules) @@ -134,17 +135,19 @@ def test_nested_model_sharding(self): config = get_default_config_keras(outer_model, self.device_ids) expected_state_rules = { - r"^inner_block.inner_dense.kernel$": SplitKeras( + r"^outer_model.inner_block.inner_dense.kernel$": SplitKeras( self.world_size, 1, "column" ), - r"^inner_block.inner_dense.bias$": SplitKeras( + r"^outer_model.inner_block.inner_dense.bias$": SplitKeras( self.world_size, 0, "column" ), - r"^outer_dense.kernel$": SplitKeras(self.world_size, 0, "row"), + r"^outer_model.outer_dense.kernel$": SplitKeras( + self.world_size, 0, "row" + ), } expected_output_rules = { - r"^inner_block.inner_dense$": {0: "no_comm"}, - r"^outer_dense$": {0: "allreduce"}, + r"^outer_model.inner_block.inner_dense$": {0: "no_comm"}, + r"^outer_model.outer_dense$": {0: "allreduce"}, } self._assert_rules_equal(config.state_rules, expected_state_rules) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 18661f860c6e..260d719d3985 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -1,14 +1,12 @@ import re from typing import Any -from typing import Dict -from typing import List import numpy as np import keras from keras.src import ops from keras.src import optimizers -from keras.src.backend.distributed import backend_resolver +from keras.src.distribution import distributed_backend class CoordinatedOptimizer: @@ -47,50 +45,31 @@ def __init__( ): self.base_optimizer = base_optimizer self.world_size = world_size - self.rank = rank self.shard_optimizer_states = shard_optimizer_states self.tensor_parallel_config = tensor_parallel_config self.sharded_states = {} self._state_variable_to_parameter = {} - self.distributed_backend = ( - backend_resolver.get_distributed_backend(distributed_backend) - if distributed_backend is not None - else None - ) self._variables = None self._variable_to_slot_name = {} - def _get_optimizer_slot_names(self) -> set: - """ - Deduces the slot names ('m', 'v', etc.) by inspecting the variables - created by the base optimizer. This is the most robust method. - """ - slot_names = set() - for var in self.base_optimizer.variables: - if "iteration" in var.path.lower(): - continue - path_parts = var.path.split("/") - if len(path_parts) > 1: - slot_names.add(path_parts[1]) - return slot_names - def _initialize_sharded_states(self): """ Partitions the optimizer's state variables across shards by inspecting - the variables created by the base optimizer. This version correctly - parses variable paths like 'optimizer/param_name_slot_name'. + the variables created by the base optimizer. 
""" if not self.shard_optimizer_states or not self.base_optimizer.built: return self.sharded_states = {} self._state_variable_to_parameter = {} - self._variable_to_slot_name = {} # Reset the map + self._variable_to_slot_name = {} opt_name = self.base_optimizer.name - normalized_params = [ - (p.path.replace("/", "_"), p) for p in self._variables - ] + normalized_params = sorted( + [(p.path.replace("/", "_"), p) for p in self._variables], + key=lambda x: len(x[0]), + reverse=True, + ) for state_var in self.base_optimizer.variables: if state_var is self.base_optimizer.iterations: @@ -113,7 +92,6 @@ def _initialize_sharded_states(self): if found_param is not None and slot_name is not None: self._state_variable_to_parameter[state_var.path] = found_param - # MODIFIED: Store the mapping from variable path to slot name self._variable_to_slot_name[state_var.path] = slot_name sharding_dim = 0 @@ -138,7 +116,7 @@ def _initialize_sharded_states(self): def _partition_state( self, state_variable: any, dim: int - ) -> List[np.ndarray]: + ) -> list[np.ndarray]: """Splits a single state variable numpy array into chunks. If the variable cannot be split along the given dimension, it is @@ -158,7 +136,7 @@ def _partition_state( else: return [np.copy(state_array) for _ in range(self.world_size)] - def get_config(self) -> Dict[str, Any]: + def get_config(self) -> dict[str, Any]: return { "base_optimizer": self.base_optimizer.get_config(), "world_size": self.world_size, @@ -166,7 +144,7 @@ def get_config(self) -> Dict[str, Any]: } def apply_gradients( - self, gradients_and_vars: List[List[tuple]], shard_models: List + self, gradients_and_vars: list[list[tuple]], shard_models: list ): """Coordinates gradient synchronization and application. @@ -202,7 +180,7 @@ def apply_gradients( ) def _apply_gradients_with_replicated_states( - self, synchronized_gradients: List[List[tuple]], shard_models: List + self, synchronized_gradients: list[list[tuple]], shard_models: list ): """Averages gradients across all shards and applies them once. @@ -240,16 +218,23 @@ def _apply_gradients_with_replicated_states( if averaged_grads_and_vars: self.base_optimizer.apply_gradients(averaged_grads_and_vars) - def _get_local_optimizer_states(self, shard_idx: int) -> Dict[str, Any]: - """Constructs the state dictionary for a single shard. + def _apply_gradients_with_sharded_states( + self, synchronized_gradients: list[list[tuple]], shard_models: list + ): + """Applies gradients to each shard using its local optimizer state.""" + for shard_idx in range(self.world_size): + local_states = self._get_local_optimizer_states(shard_idx) + shard_optimizer = shard_models[shard_idx].optimizer + + self._update_optimizer_internal_state(shard_optimizer, local_states) - Args: - shard_idx: The index of the shard for which to retrieve the state. + shard_grads_and_vars = synchronized_gradients[shard_idx] + shard_optimizer.apply_gradients(shard_grads_and_vars) - Returns: - A dictionary containing the optimizer state variables specific to - the given shard index. 
- """ + self._update_global_sharded_states(shard_optimizer, shard_idx) + + def _get_local_optimizer_states(self, shard_idx: int) -> dict[str, Any]: + """Constructs the state dictionary for a single shard.""" local_states = {} for state_name, state_value in self.sharded_states.items(): if isinstance(state_value, dict): @@ -286,9 +271,34 @@ def _update_optimizer_internal_state(self, optimizer, local_states: dict): if var.shape == local_param_state.shape: ops.assign(var, local_param_state) + def _update_global_sharded_states(self, optimizer, shard_idx: int): + """Updates the main sharded_states dictionary after a gradient step.""" + if not optimizer.built: + return + + for var in optimizer.variables: + if var is optimizer.iterations: + self.sharded_states["iterations"][shard_idx] = ( + ops.convert_to_numpy(var) + ) + continue + + param = self._state_variable_to_parameter.get(var.path, None) + slot_name = self._variable_to_slot_name.get(var.path) + + if ( + param + and slot_name + and slot_name in self.sharded_states + and param.path in self.sharded_states[slot_name] + ): + self.sharded_states[slot_name][param.path][shard_idx] = ( + ops.convert_to_numpy(var) + ) + def _synchronize_gradients( - self, gradients_and_vars: List[List[tuple]] - ) -> List[List[tuple]]: + self, gradients_and_vars: list[list[tuple]] + ) -> list[list[tuple]]: """Synchronizes gradients across shards based on tensor parallel rules. Specifically, it performs an all-reduce operation on gradients of @@ -339,7 +349,7 @@ def _synchronize_gradients( ) return gradients_and_vars - def _allreduce_gradients(self, gradients: List[Any]) -> List[Any]: + def _allreduce_gradients(self, gradients: list[Any]) -> list[Any]: """Performs a mean all-reduce operation on a list of gradients. If a distributed backend is available, it uses it. Otherwise, it @@ -354,11 +364,12 @@ def _allreduce_gradients(self, gradients: List[Any]) -> List[Any]: if not gradients: return [] - if self.distributed_backend is not None: + if distributed_backend.is_multi_device_capable(): + all_reduce_fn = distributed_backend.get_communication_ops()[ + "all_reduce" + ] numpy_grad = ops.convert_to_numpy(gradients[0]) - synced_numpy = self.distributed_backend.all_reduce( - numpy_grad, op="mean" - ) + synced_numpy = all_reduce_fn(numpy_grad, op="mean") synced_tensor = ops.convert_to_tensor(synced_numpy) return [synced_tensor for _ in range(self.world_size)] @@ -368,17 +379,17 @@ def _allreduce_gradients(self, gradients: List[Any]) -> List[Any]: mean_grad = ops.mean(stacked_grads, axis=0) return [mean_grad for _ in range(len(gradients))] - def get_weights(self) -> List[np.ndarray]: + def get_weights(self) -> list[np.ndarray]: """Returns the weights of the base optimizer.""" return [ ops.convert_to_numpy(var) for var in self.base_optimizer.variables ] - def set_weights(self, weights: List[np.ndarray]): + def set_weights(self, weights: list[np.ndarray]): """Sets the weights of the base optimizer.""" self.base_optimizer.set_weights(weights) - def enable_optimizer_state_sharding(self, variables: List): + def enable_optimizer_state_sharding(self, variables: list): """Enables and initializes optimizer state sharding. This method is called from `build()`, which is guarded from running @@ -426,7 +437,7 @@ class TensorParallelOptimizer(optimizers.Optimizer): import keras # Assume model variables and gradients from 4 shards exist. 
- # The structure is: List[List[Tuple[gradient, variable]]] + # The structure is: list[list[tuple[gradient, variable]]] trainable_vars = [keras.Variable(1.0), keras.Variable(2.0)] sharded_grads_and_vars = [ [(keras.ops.ones_like(v), v) for v in trainable_vars] @@ -477,7 +488,7 @@ def __init__( tensor_parallel_config=tensor_parallel_config, ) - def apply_gradients(self, grads_and_vars: List, **kwargs): + def apply_gradients(self, grads_and_vars: list, **kwargs): """Applies gradients to the model variables. If `grads_and_vars` is a list of lists, it's assumed to be from @@ -490,11 +501,12 @@ def apply_gradients(self, grads_and_vars: List, **kwargs): **kwargs: Additional arguments. `shard_models` can be passed to provide the list of model shards. """ - if ( + is_sharded_grads = ( isinstance(grads_and_vars, list) and grads_and_vars and isinstance(grads_and_vars[0], list) - ): + ) + if is_sharded_grads: shard_models = kwargs.get("shard_models", []) self.coordinated_optimizer.apply_gradients( grads_and_vars, shard_models @@ -502,7 +514,7 @@ def apply_gradients(self, grads_and_vars: List, **kwargs): else: self.base_optimizer.apply_gradients(grads_and_vars) - def get_config(self) -> Dict[str, Any]: + def get_config(self) -> dict[str, Any]: from keras.src import saving config = super().get_config() @@ -521,7 +533,7 @@ def get_config(self) -> Dict[str, Any]: return config @classmethod - def from_config(cls, config: Dict[str, Any]) -> "TensorParallelOptimizer": + def from_config(cls, config: dict[str, Any]) -> "TensorParallelOptimizer": from keras.src import saving base_optimizer_config = config.pop("base_optimizer") @@ -535,7 +547,7 @@ def from_config(cls, config: Dict[str, Any]) -> "TensorParallelOptimizer": return cls(base_optimizer=base_optimizer, **init_kwargs) - def build(self, variables: List): + def build(self, variables: list): """Builds the optimizer and initializes sharded states. This method is called the first time the optimizer is used. 
It builds @@ -556,16 +568,16 @@ def build(self, variables: List): self.coordinated_optimizer.enable_optimizer_state_sharding(variables) super().build(variables) - def get_weights(self) -> List[np.ndarray]: + def get_weights(self) -> list[np.ndarray]: """Returns the weights of the base optimizer.""" return self.coordinated_optimizer.get_weights() - def set_weights(self, weights: List[np.ndarray]): + def set_weights(self, weights: list[np.ndarray]): """Sets the weights of the base optimizer.""" self.coordinated_optimizer.set_weights(weights) @property - def variables(self) -> List: + def variables(self) -> list: """Returns the list of variables from the base optimizer.""" return self.base_optimizer.variables @@ -574,6 +586,10 @@ def learning_rate(self) -> Any: """Provides access to the learning rate of the base optimizer.""" return self.base_optimizer.learning_rate + @learning_rate.setter + def learning_rate(self, value): + self.base_optimizer.learning_rate = value + @property def iterations(self): """ diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index c4249d147d73..39cce46de72c 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -5,17 +5,19 @@ from keras import ops from keras.src import optimizers from keras.src import testing -from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( - CoordinatedOptimizer, -) -from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( - TensorParallelOptimizer, -) + +if keras.backend.backend() == "jax": + from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + CoordinatedOptimizer, + ) + from keras.src.distribution.tensor_parallel.coordinated_optimizer import ( + TensorParallelOptimizer, + ) @pytest.mark.skipif( - keras.backend.backend() == "openvino", - reason="CoordinatedOptimizer is not yet supported on the OpenVINO backend.", + keras.backend.backend() != "jax", + reason="This test is JAX-specific.", ) class CoordinatedOptimizerTest(testing.TestCase): def _get_simple_model(self): @@ -44,9 +46,7 @@ def _get_mock_gradients_and_vars(self, model, world_size): def test_initialization(self): """Tests that the optimizer initializes with the correct defaults.""" base_optimizer = optimizers.Adam() - coord = CoordinatedOptimizer( - base_optimizer, world_size=4, distributed_backend=None - ) + coord = CoordinatedOptimizer(base_optimizer, world_size=4) self.assertEqual(coord.base_optimizer, base_optimizer) self.assertTrue(coord.shard_optimizer_states) self.assertEqual(coord.sharded_states, {}) @@ -63,7 +63,6 @@ def __init__(self, *args, **kwargs): def apply_gradients(self, grads_and_vars, *args, **kwargs): self.apply_gradients_call_count += 1 self.received_grads = [g for g, v in grads_and_vars] - # Call the superclass method to ensure variables are updated super().apply_gradients(grads_and_vars, *args, **kwargs) world_size = 4 @@ -76,30 +75,24 @@ def apply_gradients(self, grads_and_vars, *args, **kwargs): optimizer, world_size, shard_optimizer_states=False, - distributed_backend=None, ) coord.apply_gradients(mock_grads, []) self.assertEqual(optimizer.apply_gradients_call_count, 1) - # The average of multipliers 1, 2, 3, 4 is (1+2+3+4)/4 = 10/4 = 2.5 self.assertAllClose( optimizer.received_grads[0], np.ones_like(optimizer.received_grads[0]) * 2.5, ) def test_init_from_string(self): - optimizer = 
TensorParallelOptimizer( - "adam", world_size=4, distributed_backend=None - ) + optimizer = TensorParallelOptimizer("adam", world_size=4) self.assertIsInstance(optimizer.base_optimizer, optimizers.Adam) def test_apply_gradients_delegation(self): """Tests that apply_gradients correctly delegates.""" world_size = 4 base_opt = optimizers.Adam() - optimizer = TensorParallelOptimizer( - base_opt, world_size, distributed_backend=None - ) + optimizer = TensorParallelOptimizer(base_opt, world_size) model = self._get_simple_model() mock_grads = self._get_mock_gradients_and_vars(model, world_size) @@ -129,11 +122,8 @@ def base_apply_mock(*args, **kwargs): def test_build_and_state_sharding(self): """Tests that the build method correctly initializes sharded states.""" - optimizer = TensorParallelOptimizer( - optimizers.Adam(), world_size=4, distributed_backend=None - ) + optimizer = TensorParallelOptimizer(optimizers.Adam(), world_size=4) model = self._get_simple_model() - model.build(input_shape=(None, 10)) self.assertEqual(optimizer.coordinated_optimizer.sharded_states, {}) @@ -163,5 +153,28 @@ def test_serialization(self): self.assertEqual(recreated.world_size, world_size) self.assertIsInstance(recreated.base_optimizer, optimizers.Adam) - self.assertIsNone(recreated.coordinated_optimizer.distributed_backend) + self.assertIsNone(recreated.distributed_backend) self.assertAllClose(recreated.base_optimizer.learning_rate, 0.1) + + def test_sharding_with_prefixed_variable_names(self): + """Tests that state is correctly mapped with prefixed variable names.""" + inputs = keras.Input(shape=(10,)) + x = keras.layers.Dense(4, name="dense")(inputs) + outputs = keras.layers.Dense(2, name="dense_output")(x) + model = keras.Model(inputs, outputs) + model.build(input_shape=(None, 10)) + + optimizer = TensorParallelOptimizer(optimizers.Adam(), world_size=2) + optimizer.build(model.trainable_variables) + + state_to_param = ( + optimizer.coordinated_optimizer._state_variable_to_parameter + ) + self.assertGreater(len(state_to_param), 0) + + dense_output_kernel = model.get_layer("dense_output").kernel + optimizer_name = optimizer.base_optimizer.name + kernel_path = dense_output_kernel.path.replace("/", "_") + momentum_path = f"{optimizer_name}/{kernel_path}_momentum" + + self.assertIs(state_to_param[momentum_path], dense_output_kernel) diff --git a/keras/src/distribution/tensor_parallel/sharding_keras.py b/keras/src/distribution/tensor_parallel/sharding_keras.py index ace810adb024..012234cb77f4 100644 --- a/keras/src/distribution/tensor_parallel/sharding_keras.py +++ b/keras/src/distribution/tensor_parallel/sharding_keras.py @@ -52,12 +52,9 @@ def get_shard_parameters(self, shard_index: int) -> Dict[str, Any]: shard = self.model_shards[shard_index] params = {} - for layer in shard.layers: - name = layer.name - if hasattr(layer, "weights") and layer.weights: - for i, weight in enumerate(layer.weights): - param_name = f"{name}.weight_{i}" - params[param_name] = weight + for weight in shard.weights: + param_name = weight.path.replace("/", ".") + params[param_name] = weight return params From 5824c66627617e25b087144060da34658562cb36 Mon Sep 17 00:00:00 2001 From: Suhana Date: Fri, 3 Oct 2025 11:57:46 +0530 Subject: [PATCH 6/9] Reformatting according to changes in distributed_backend --- keras/src/distribution/tensor_parallel/coordinated_optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py 
b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 260d719d3985..94b6730b2f20 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -39,7 +39,6 @@ def __init__( base_optimizer: optimizers.Optimizer, world_size: int, distributed_backend: str = "auto", - rank: int = 0, shard_optimizer_states: bool = True, tensor_parallel_config=None, ): From 9cf5c7fe31ff7d7e8affeb4f8b061686b8c745ba Mon Sep 17 00:00:00 2001 From: Suhana Date: Mon, 6 Oct 2025 19:52:51 +0530 Subject: [PATCH 7/9] Refactoring the code --- .../tensor_parallel/autoconfig.py | 205 ++++++++++++------ .../tensor_parallel/autoconfig_test.py | 123 +++++++++-- .../tensor_parallel/coordinated_optimizer.py | 14 ++ 3 files changed, 260 insertions(+), 82 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index 9fa6db430c35..0100aeaf7a5e 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -1,12 +1,13 @@ +from typing import Any +from typing import Dict from typing import Sequence +from typing import Set from keras.src.distribution.tensor_parallel.config import ConfigKeras from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras def analyze_dense_layer_directly(layer, module, prefix: str) -> str: - from keras.src import layers - """Analyzes a Dense layer to classify it for tensor parallelism sharding. This function inspects the layer's weight shapes to determine if it's an @@ -23,20 +24,24 @@ def analyze_dense_layer_directly(layer, module, prefix: str) -> str: A string indicating the layer's classification: 'up_projection', 'down_projection', or 'generic_dense'. """ + from keras.src import layers + if not isinstance(layer, layers.Dense): return "generic_dense" input_dim = None output_dim = None - if hasattr(layer, "kernel"): + if hasattr(layer, "kernel") and layer.kernel is not None: kernel_shape = layer.kernel.shape if len(kernel_shape) == 2: - input_dim = kernel_shape[0] - output_dim = kernel_shape[1] - else: + input_dim, output_dim = kernel_shape + + if input_dim is None or output_dim is None: if hasattr(layer, "units"): output_dim = layer.units + else: + return "generic_dense" if ( hasattr(layer, "input_shape") @@ -44,6 +49,8 @@ def analyze_dense_layer_directly(layer, module, prefix: str) -> str: and len(layer.input_shape) > 1 ): input_dim = layer.input_shape[-1] + else: + return "generic_dense" if not input_dim or not output_dim: return "generic_dense" @@ -60,34 +67,33 @@ def analyze_dense_layer_directly(layer, module, prefix: str) -> str: return "generic_dense" -def _traverse_and_shard_layer( +def _find_and_shard_layers( current_layer, + prefix: str, module, world_size: int, - state_rules: dict, - output_rules: dict, - processed_layers: set, - prefix: str = "", + state_rules: Dict[str, Any], + output_rules: Dict[str, Any], + processed_layers: Set[int], ): - from keras.src import layers + """Recursively traverses a Keras model to generate sharding rules. - """Traverses a layer and its sub-layers to apply sharding rules. - - This function navigates through the model's layer hierarchy. For each - layer, it identifies its type and applies appropriate sharding logic, - populating the `state_rules` and `output_rules` dictionaries. + This is an internal helper function that navigates through all layers of a + model, including nested ones. 
For each supported layer, it determines the + appropriate sharding strategy and populates the `state_rules` and + `output_rules` dictionaries. These dictionaries are modified in place. Args: - current_layer: The current keras.Layer object to be processed. - module: The top-level Keras Model, used for context analysis. - world_size: The total number of devices for sharding. - state_rules: The dictionary of state sharding rules to populate. - output_rules: The dictionary of output sharding rules to populate. - processed_layers: A set of layer IDs that have already been processed - to avoid redundant computation and infinite loops. - prefix: The hierarchical name prefix from parent layers, used to - construct the full unique name for the current layer. + current_layer: The Keras layer to be processed in the current step. + prefix: The hierarchical name prefix for the `current_layer`. + module: The top-level Keras model being analyzed. + world_size: The total number of devices to shard the model across. + state_rules: A dictionary with sharding rules for weights. + output_rules: A dictionary with communication rules for outputs. + processed_layers: A set of layer IDs to prevent infinite loops. """ + from keras.src import layers + if id(current_layer) in processed_layers: return processed_layers.add(id(current_layer)) @@ -100,10 +106,24 @@ def _traverse_and_shard_layer( current_layer, module, full_name ) - if mlp_type == "down_projection": + if mlp_type == "up_projection": + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) + if current_layer.use_bias: + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) + output_rules[f"^{full_name}$"] = {0: "gather"} + + elif mlp_type == "down_projection": state_rules[f"^{full_name}.kernel$"] = SplitKeras( world_size, 0, "row" ) + if current_layer.use_bias: + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, -1, "replicated" + ) output_rules[f"^{full_name}$"] = {0: "allreduce"} else: @@ -114,27 +134,21 @@ def _traverse_and_shard_layer( state_rules[f"^{full_name}.bias$"] = SplitKeras( world_size, 0, "column" ) - output_rules[f"^{full_name}$"] = {0: "no_comm"} + output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, layers.EinsumDense): - is_row_parallel = False - if "->" in current_layer.equation: - equation_parts = current_layer.equation.split("->") - if len(equation_parts) == 2: - input_spec = equation_parts[0].split(",")[0].strip() - output_spec = equation_parts[1].strip() - if ( - input_spec - and output_spec - and len(output_spec) < len(input_spec) - ): - is_row_parallel = True - - if is_row_parallel: + if "attention_output" in full_name or "out_proj" in full_name: state_rules[f"^{full_name}.kernel$"] = SplitKeras( world_size, 0, "row" ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, -1, "replicated" + ) output_rules[f"^{full_name}$"] = {0: "allreduce"} else: state_rules[f"^{full_name}.kernel$"] = SplitKeras( @@ -147,18 +161,45 @@ def _traverse_and_shard_layer( state_rules[f"^{full_name}.bias$"] = SplitKeras( world_size, 0, "column" ) - output_rules[f"^{full_name}$"] = {0: "no_comm"} + output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, layers.Embedding): - weight_name = ( - "embeddings" if hasattr(current_layer, "embeddings") else None + state_rules[f"^{full_name}.embeddings$"] = SplitKeras( + world_size, 0, 
"vocab_parallel" ) - if weight_name: - state_rules[f"^{full_name}\.{weight_name}$"] = SplitKeras( - world_size, 1, "column" + output_rules[f"^{full_name}$"] = {0: "allreduce"} + return + + elif isinstance(current_layer, layers.MultiHeadAttention): + for proj in ["query", "key", "value"]: + proj_dense_name = f"_{proj}_dense" + if hasattr(current_layer, proj_dense_name): + state_rules[f"^{full_name}\.{proj_dense_name}\.kernel$"] = ( + SplitKeras(world_size, 1, "column") + ) + if getattr(current_layer, proj_dense_name).use_bias: + state_rules[f"^{full_name}\.{proj_dense_name}\.bias$"] = ( + SplitKeras(world_size, 0, "column") + ) + + output_dense_name = "_output_dense" + if hasattr(current_layer, output_dense_name): + state_rules[f"^{full_name}\.{output_dense_name}\.kernel$"] = ( + SplitKeras(world_size, 0, "row") ) - output_rules[f"^{full_name}$"] = {0: "no_comm"} + if getattr(current_layer, output_dense_name).use_bias: + state_rules[f"^{full_name}\.{output_dense_name}\.bias$"] = ( + SplitKeras(world_size, -1, "replicated") + ) + + output_rules[f"^{full_name}$"] = {0: "allreduce"} + return + + elif isinstance(current_layer, layers.Dropout): + if "rng_rules" not in state_rules: + state_rules["rng_rules"] = {} + state_rules["rng_rules"][full_name] = {"type": "parallel"} return elif isinstance( @@ -170,50 +211,80 @@ def _traverse_and_shard_layer( ), ): return - else: - if hasattr(current_layer, "layers"): - for sub_layer in current_layer.layers: - _traverse_and_shard_layer( - sub_layer, + + if hasattr(current_layer, "layers") and current_layer.layers: + for sub_layer in current_layer.layers: + _find_and_shard_layers( + sub_layer, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, + ) + + for attr_name in dir(current_layer): + if attr_name.startswith("_"): + continue + try: + attr = getattr(current_layer, attr_name) + if isinstance(attr, layers.Layer) and attr is not current_layer: + _find_and_shard_layers( + attr, + full_name, module, world_size, state_rules, output_rules, processed_layers, - full_name, ) + elif isinstance(attr, (list, tuple)): + for item in attr: + if isinstance(item, layers.Layer): + _find_and_shard_layers( + item, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, + ) + except Exception: + continue def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: - """Generates a smart, recursive sharding configuration for a Keras model. + """Generates a default sharding configuration for a Keras model. - This function traverses the layers of a given Keras model and applies a - set of heuristics to automatically determine how each layer's weights - and outputs should be sharded for tensor parallelism. It uses a helper - function to perform the recursive traversal. + This function serves as the main entry point for automatically creating a + tensor parallel sharding configuration. It traverses the model and applies + standard sharding patterns for common layer types like Dense, Embedding, and + MultiHeadAttention. Args: - module: The Keras Model to generate a sharding configuration for. - device_ids: A sequence of device identifiers, used to determine the - world size (number of devices) for sharding. + module: The Keras model or layer to be configured for sharding. + device_ids: A sequence of device IDs (e.g., `['gpu:0', 'gpu:1']`) + to shard across. The number of devices determines the `world_size`. 
Returns: - A ConfigKeras object containing the generated 'state_rules' (for model - parameters) and 'output_rules' (for layer outputs). + A `ConfigKeras` object containing the generated `state_rules` for + sharding weights and `output_rules` for handling communications. """ world_size = len(device_ids) state_rules = {} output_rules = {} processed_layers = set() - _traverse_and_shard_layer( + _find_and_shard_layers( current_layer=module, + prefix="", module=module, world_size=world_size, state_rules=state_rules, output_rules=output_rules, processed_layers=processed_layers, - prefix="", ) return ConfigKeras(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index 96467da847e0..6845f6000982 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -1,10 +1,13 @@ import os +import pytest + os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" from keras import Input from keras import Model from keras import layers +from keras.src import backend from keras.src import testing from keras.src.distribution import distributed_backend from keras.src.distribution.tensor_parallel.autoconfig import ( @@ -16,22 +19,24 @@ from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras +@pytest.mark.skipif( + backend.backend() != "jax", + reason="Tensor Parallelism autoconfig tests are only for the JAX backend.", +) class TestAutoConfigKeras(testing.TestCase): def setUp(self): """Set up the test case and common variables.""" super().setUp() device_info = distributed_backend.get_device_info() self.world_size = device_info["device_count"] - self.device_ids = [f"device:{i}" for i in range(self.world_size)] + self.device_ids = [f"cpu:{i}" for i in range(self.world_size)] self.assertGreater( self.world_size, 1, "Distribution tests require more than 1 device." ) def _assert_split_keras_equal(self, rule1, rule2): - """ - Helper to compare two SplitKeras objects by their attributes. 
- """ + """Helper to compare two SplitKeras objects by their attributes.""" self.assertIsInstance(rule1, SplitKeras) self.assertIsInstance(rule2, SplitKeras) self.assertDictEqual(vars(rule1), vars(rule2)) @@ -65,13 +70,20 @@ def test_analyze_dense_layer(self): "down_projection", ) + generic_layer = layers.Dense(20) + generic_layer.build(input_shape=(None, 16)) + self.assertEqual( + analyze_dense_layer_directly(generic_layer, None, ""), + "generic_dense", + ) + def test_simple_mlp_sharding(self): """Tests a simple MLP with up and down projection layers.""" inputs = Input(shape=(64,)) x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) - outputs = layers.Dense( - 64, name="down_projection_layer", use_bias=False - )(x) + outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)( + x + ) model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") config = get_default_config_keras(model, self.device_ids) @@ -86,17 +98,43 @@ def test_simple_mlp_sharding(self): r"^simple_mlp.down_projection_layer.kernel$": SplitKeras( self.world_size, 0, "row" ), + r"^simple_mlp.down_projection_layer.bias$": SplitKeras( + self.world_size, -1, "replicated" + ), } expected_output_rules = { - r"^simple_mlp.up_projection_layer$": {0: "no_comm"}, + r"^simple_mlp.up_projection_layer$": {0: "gather"}, r"^simple_mlp.down_projection_layer$": {0: "allreduce"}, } self._assert_rules_equal(config.state_rules, expected_state_rules) self._assert_rules_equal(config.output_rules, expected_output_rules) + def test_generic_dense_sharding(self): + """Tests a generic Dense layer that isn't an up/down projection.""" + inputs = Input(shape=(64,)) + outputs = layers.Dense(80, name="generic_layer", use_bias=True)(inputs) + model = Model(inputs=inputs, outputs=outputs, name="generic_model") + + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^generic_model.generic_layer.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^generic_model.generic_layer.bias$": SplitKeras( + self.world_size, 0, "column" + ), + } + expected_output_rules = { + r"^generic_model.generic_layer$": {0: "gather -1"} + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + def test_embedding_sharding(self): - """Tests an Embedding layer.""" + """Tests an Embedding layer for vocabulary parallelism.""" inputs = Input(shape=(10,), dtype="int32") outputs = layers.Embedding( input_dim=1000, output_dim=128, name="token_embedding" @@ -106,28 +144,79 @@ def test_embedding_sharding(self): config = get_default_config_keras(model, self.device_ids) expected_state_rules = { - r"^embed_model.token_embedding\.embeddings$": SplitKeras( - self.world_size, 1, "column" + r"^embed_model.token_embedding.embeddings$": SplitKeras( + self.world_size, 0, "vocab_parallel" ) } expected_output_rules = { - r"^embed_model.token_embedding$": {0: "no_comm"} + r"^embed_model.token_embedding$": {0: "allreduce"} } self._assert_rules_equal(config.state_rules, expected_state_rules) self._assert_rules_equal(config.output_rules, expected_output_rules) + def test_einsum_dense_sharding(self): + """Tests the special handling for EinsumDense layers.""" + inputs = Input(shape=(64,)) + x = layers.EinsumDense( + "bh,hd->bd", output_shape=128, name="query_proj" + )(inputs) + outputs = layers.EinsumDense( + "bd,dh->bh", output_shape=64, name="attention_output" + )(x) + model = Model(inputs=inputs, outputs=outputs, name="einsum_model") 
+ + config = get_default_config_keras(model, self.device_ids) + + expected_state_rules = { + r"^einsum_model.query_proj.kernel$": SplitKeras( + self.world_size, 1, "column" + ), + r"^einsum_model.attention_output.kernel$": SplitKeras( + self.world_size, 0, "row" + ), + } + expected_output_rules = { + r"^einsum_model.query_proj$": {0: "gather -1"}, + r"^einsum_model.attention_output$": {0: "allreduce"}, + } + + self._assert_rules_equal(config.state_rules, expected_state_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) + + def test_normalization_layers_ignored(self): + """Tests that normalization layers are correctly ignored.""" + inputs = Input(shape=(64,)) + x = layers.Dense(64, name="dense1", use_bias=True)(inputs) + x = layers.LayerNormalization(name="layernorm")(x) + outputs = layers.Dense(64, name="dense2", use_bias=True)(x) + model = Model(inputs=inputs, outputs=outputs, name="norm_model") + + config = get_default_config_keras(model, self.device_ids) + + for key in config.state_rules: + self.assertNotIn("layernorm", key) + for key in config.output_rules: + self.assertNotIn("layernorm", key) + + self.assertIn(r"^norm_model.dense1.kernel$", config.state_rules) + self.assertIn(r"^norm_model.dense2.kernel$", config.state_rules) + self.assertEqual(len(config.state_rules), 4) + self.assertEqual(len(config.output_rules), 2) + def test_nested_model_sharding(self): """Tests that the traversal logic correctly handles nested models.""" inner_inputs = Input(shape=(32,)) - inner_outputs = layers.Dense(128, name="inner_dense")(inner_inputs) + inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)( + inner_inputs + ) inner_model = Model( inputs=inner_inputs, outputs=inner_outputs, name="inner_block" ) outer_inputs = Input(shape=(32,)) x = inner_model(outer_inputs) - outer_outputs = layers.Dense(32, name="outer_dense")(x) + outer_outputs = layers.Dense(32, name="outer_dense", use_bias=True)(x) outer_model = Model( inputs=outer_inputs, outputs=outer_outputs, name="outer_model" ) @@ -144,11 +233,15 @@ def test_nested_model_sharding(self): r"^outer_model.outer_dense.kernel$": SplitKeras( self.world_size, 0, "row" ), + r"^outer_model.outer_dense.bias$": SplitKeras( + self.world_size, -1, "replicated" + ), } expected_output_rules = { - r"^outer_model.inner_block.inner_dense$": {0: "no_comm"}, + r"^outer_model.inner_block.inner_dense$": {0: "gather"}, r"^outer_model.outer_dense$": {0: "allreduce"}, } + self.maxDiff = None self._assert_rules_equal(config.state_rules, expected_state_rules) self._assert_rules_equal(config.output_rules, expected_output_rules) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 94b6730b2f20..ca7f8e5d5fcc 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -39,6 +39,7 @@ def __init__( base_optimizer: optimizers.Optimizer, world_size: int, distributed_backend: str = "auto", + rank: int = 0, shard_optimizer_states: bool = True, tensor_parallel_config=None, ): @@ -531,6 +532,19 @@ def get_config(self) -> dict[str, Any]: ) return config + def update_step(self, gradient, variable, *args, **kwargs): + if hasattr(self.base_optimizer, "update_step"): + try: + return self.base_optimizer.update_step( + gradient, variable, *args, **kwargs + ) + except TypeError: + return self.base_optimizer.update_step(gradient, variable) + try: + return 
super().update_step(gradient, variable, *args, **kwargs) + except TypeError: + return super().update_step(gradient, variable) + @classmethod def from_config(cls, config: dict[str, Any]) -> "TensorParallelOptimizer": from keras.src import saving From 996a154df5c14da463d94267e3c73d4c4971ecb6 Mon Sep 17 00:00:00 2001 From: Suhana Date: Mon, 6 Oct 2025 20:49:36 +0530 Subject: [PATCH 8/9] refactoring --- .../tensor_parallel/autoconfig.py | 288 +++++++----------- .../tensor_parallel/autoconfig_test.py | 51 ++-- .../tensor_parallel/coordinated_optimizer.py | 10 +- .../coordinated_optimizer_test.py | 6 +- 4 files changed, 140 insertions(+), 215 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index 0100aeaf7a5e..9b3a80726b75 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -1,70 +1,67 @@ -from typing import Any -from typing import Dict -from typing import Sequence -from typing import Set +from typing import Sequence, Dict, Any, Set from keras.src.distribution.tensor_parallel.config import ConfigKeras from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras def analyze_dense_layer_directly(layer, module, prefix: str) -> str: - """Analyzes a Dense layer to classify it for tensor parallelism sharding. + """Analyzes a Keras Dense layer to classify its sharding strategy. - This function inspects the layer's weight shapes to determine if it's an - "up-projection" (expanding feature dimensions), a "down-projection" - (contracting feature dimensions), or a generic layer. This classification - helps in deciding whether to apply column-wise or row-wise parallelism. + This function inspects the input and output dimensions of a Dense layer + to determine if it functions as an expansion layer ("up-projection"), a + contraction layer ("down-projection"), or neither ("generic_dense"). This + classification is a heuristic commonly used to apply tensor parallelism + in Transformer-based models, such as in an MLP block where an up-projection + is followed by a down-projection. Args: - layer: The keras.layers.Dense instance to analyze. - module: The parent Keras model containing the layer. - prefix: The hierarchical name prefix for the layer. + layer: The Keras `layers.Dense` instance to analyze. + module: The parent module containing the layer (currently unused). + prefix (str): The name prefix for the layer in the model hierarchy + (currently unused). Returns: - A string indicating the layer's classification: 'up_projection', + str: A string classifying the layer as 'up_projection', 'down_projection', or 'generic_dense'. 
""" from keras.src import layers if not isinstance(layer, layers.Dense): - return "generic_dense" + return 'generic_dense' input_dim = None output_dim = None - if hasattr(layer, "kernel") and layer.kernel is not None: + if hasattr(layer, 'kernel') and layer.kernel is not None: kernel_shape = layer.kernel.shape if len(kernel_shape) == 2: - input_dim, output_dim = kernel_shape + input_dim = kernel_shape[0] + output_dim = kernel_shape[1] if input_dim is None or output_dim is None: - if hasattr(layer, "units"): + if hasattr(layer, 'units'): output_dim = layer.units else: - return "generic_dense" + return 'generic_dense' - if ( - hasattr(layer, "input_shape") - and layer.input_shape - and len(layer.input_shape) > 1 - ): + if hasattr(layer, 'input_shape') and layer.input_shape and len(layer.input_shape) > 1: input_dim = layer.input_shape[-1] else: - return "generic_dense" + return 'generic_dense' if not input_dim or not output_dim: - return "generic_dense" + return 'generic_dense' expansion_threshold = 1.5 is_expansion = output_dim > input_dim * expansion_threshold is_contraction = input_dim > output_dim * expansion_threshold if is_expansion: - return "up_projection" + return 'up_projection' elif is_contraction: - return "down_projection" + return 'down_projection' else: - return "generic_dense" + return 'generic_dense' def _find_and_shard_layers( @@ -76,21 +73,35 @@ def _find_and_shard_layers( output_rules: Dict[str, Any], processed_layers: Set[int], ): - """Recursively traverses a Keras model to generate sharding rules. - - This is an internal helper function that navigates through all layers of a - model, including nested ones. For each supported layer, it determines the - appropriate sharding strategy and populates the `state_rules` and - `output_rules` dictionaries. These dictionaries are modified in place. + """Recursively traverses the model graph to apply sharding rules. + + This function walks through all nested layers of a given Keras model or + layer. For each encountered layer, it determines an appropriate tensor + parallelism strategy and populates the `state_rules` and `output_rules` + dictionaries with the corresponding sharding actions. It uses a set of + processed layer IDs to avoid redundant processing of shared layers. + + The sharding logic is as follows: + - `Dense` layers are sharded based on their classification (up/down proj). + - Up-projections are split along the column axis (output features). + - Down-projections are split along the row axis (input features). + - `EinsumDense` layers in attention blocks are sharded similarly. + - `Embedding` layers are sharded column-wise for vocabulary parallelism. + - Normalization layers are ignored (replicated on all devices). Args: - current_layer: The Keras layer to be processed in the current step. - prefix: The hierarchical name prefix for the `current_layer`. - module: The top-level Keras model being analyzed. - world_size: The total number of devices to shard the model across. - state_rules: A dictionary with sharding rules for weights. - output_rules: A dictionary with communication rules for outputs. - processed_layers: A set of layer IDs to prevent infinite loops. + current_layer: The Keras layer currently being processed. + prefix (str): The hierarchical name prefix for the `current_layer`. + module: The top-level Keras model or layer being configured. + world_size (int): The total number of devices for sharding. + state_rules (Dict[str, Any]): A dictionary to be populated with rules for + sharding layer weights (state). 
Keys are regex patterns matching + weight names, values are `SplitKeras` actions. + output_rules (Dict[str, Any]): A dictionary to be populated with rules + for handling layer outputs. Keys are regex patterns matching layer + names, values describe the communication op (e.g., 'allreduce'). + processed_layers (Set[int]): A set of `id()`s of layers that have + already been processed to prevent cycles and redundant work. """ from keras.src import layers @@ -102,175 +113,107 @@ def _find_and_shard_layers( full_name = f"{prefix}.{name}" if prefix else name if isinstance(current_layer, layers.Dense): - mlp_type = analyze_dense_layer_directly( - current_layer, module, full_name - ) + mlp_type = analyze_dense_layer_directly(current_layer, module, full_name) - if mlp_type == "up_projection": - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 1, "column" - ) + if mlp_type == 'up_projection': + state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 1, "column") if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, 0, "column" - ) + state_rules[f"^{full_name}.bias$"] = SplitKeras(world_size, 0, "column") output_rules[f"^{full_name}$"] = {0: "gather"} - elif mlp_type == "down_projection": - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 0, "row" - ) - if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, -1, "replicated" - ) + elif mlp_type == 'down_projection': + state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 0, "row") output_rules[f"^{full_name}$"] = {0: "allreduce"} else: - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 1, "column" - ) + state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 1, "column") if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, 0, "column" - ) + state_rules[f"^{full_name}.bias$"] = SplitKeras(world_size, 0, "column") output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, layers.EinsumDense): - if "attention_output" in full_name or "out_proj" in full_name: - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 0, "row" - ) - if ( - hasattr(current_layer, "bias") - and current_layer.bias is not None - ): - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, -1, "replicated" - ) + if "attention_output" in full_name: + state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 0, "row") + if hasattr(current_layer, 'bias') and current_layer.bias is not None: + pass output_rules[f"^{full_name}$"] = {0: "allreduce"} else: - state_rules[f"^{full_name}.kernel$"] = SplitKeras( - world_size, 1, "column" - ) - if ( - hasattr(current_layer, "bias") - and current_layer.bias is not None - ): - state_rules[f"^{full_name}.bias$"] = SplitKeras( - world_size, 0, "column" - ) + state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 1, "column") + if hasattr(current_layer, 'bias') and current_layer.bias is not None: + state_rules[f"^{full_name}.bias$"] = SplitKeras(world_size, 0, "column") output_rules[f"^{full_name}$"] = {0: "gather -1"} return - elif isinstance(current_layer, layers.Embedding): - state_rules[f"^{full_name}.embeddings$"] = SplitKeras( - world_size, 0, "vocab_parallel" - ) - output_rules[f"^{full_name}$"] = {0: "allreduce"} - return - - elif isinstance(current_layer, layers.MultiHeadAttention): - for proj in ["query", "key", "value"]: - proj_dense_name = f"_{proj}_dense" - if hasattr(current_layer, proj_dense_name): - 
state_rules[f"^{full_name}\.{proj_dense_name}\.kernel$"] = ( - SplitKeras(world_size, 1, "column") - ) - if getattr(current_layer, proj_dense_name).use_bias: - state_rules[f"^{full_name}\.{proj_dense_name}\.bias$"] = ( - SplitKeras(world_size, 0, "column") - ) - - output_dense_name = "_output_dense" - if hasattr(current_layer, output_dense_name): - state_rules[f"^{full_name}\.{output_dense_name}\.kernel$"] = ( - SplitKeras(world_size, 0, "row") - ) - if getattr(current_layer, output_dense_name).use_bias: - state_rules[f"^{full_name}\.{output_dense_name}\.bias$"] = ( - SplitKeras(world_size, -1, "replicated") - ) - - output_rules[f"^{full_name}$"] = {0: "allreduce"} - return - - elif isinstance(current_layer, layers.Dropout): - if "rng_rules" not in state_rules: - state_rules["rng_rules"] = {} - state_rules["rng_rules"][full_name] = {"type": "parallel"} - return - - elif isinstance( - current_layer, - ( - layers.LayerNormalization, - layers.BatchNormalization, - layers.GroupNormalization, - ), - ): + elif isinstance(current_layer, (layers.Embedding,)): + if hasattr(current_layer, 'token_embedding') or hasattr(current_layer, 'position_embedding'): + pass + else: + weight_name = None + if hasattr(current_layer, 'embeddings'): + weight_name = 'embeddings' + elif hasattr(current_layer, 'position_embeddings'): + weight_name = 'position_embeddings' + + if weight_name: + state_rules[f"^{full_name}\\..*{weight_name}$"] = SplitKeras(world_size, 1, "column") + output_rules[f"^{full_name}$"] = {0: "no_comm"} + return + + elif isinstance(current_layer, (layers.LayerNormalization, layers.BatchNormalization, layers.GroupNormalization)): return - if hasattr(current_layer, "layers") and current_layer.layers: + if hasattr(current_layer, 'layers') and current_layer.layers: for sub_layer in current_layer.layers: _find_and_shard_layers( - sub_layer, - full_name, - module, - world_size, - state_rules, - output_rules, - processed_layers, + sub_layer, full_name, module, world_size, + state_rules, output_rules, processed_layers ) for attr_name in dir(current_layer): - if attr_name.startswith("_"): + if attr_name.startswith('__') and attr_name.endswith('__'): continue - try: + if hasattr(current_layer, attr_name): attr = getattr(current_layer, attr_name) + if isinstance(attr, layers.Layer) and attr is not current_layer: _find_and_shard_layers( - attr, - full_name, - module, - world_size, - state_rules, - output_rules, - processed_layers, + attr, full_name, module, world_size, + state_rules, output_rules, processed_layers ) elif isinstance(attr, (list, tuple)): for item in attr: if isinstance(item, layers.Layer): _find_and_shard_layers( - item, - full_name, - module, - world_size, - state_rules, - output_rules, - processed_layers, + item, full_name, module, world_size, + state_rules, output_rules, processed_layers ) - except Exception: - continue - def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: - """Generates a default sharding configuration for a Keras model. + """Generates a default tensor parallelism sharding configuration for a model. + + This function serves as the entry point for automatically creating a sharding + plan for a given Keras model or layer. It initializes the rule dictionaries + and starts the recursive layer traversal to populate them based on a default + set of heuristics for common architectures like Transformers. - This function serves as the main entry point for automatically creating a - tensor parallel sharding configuration. 
It traverses the model and applies - standard sharding patterns for common layer types like Dense, Embedding, and - MultiHeadAttention. + Example: + ```python + model = MyTransformerModel() + device_ids = ["gpu:0", "gpu:1"] + sharding_config = get_default_config_keras(model, device_ids) + # sharding_config can now be used to distribute the model + ``` Args: - module: The Keras model or layer to be configured for sharding. - device_ids: A sequence of device IDs (e.g., `['gpu:0', 'gpu:1']`) - to shard across. The number of devices determines the `world_size`. + module: The Keras `Model` or `Layer` to generate a config for. + device_ids (Sequence[str]): A sequence of device IDs (e.g., + ["gpu:0", "gpu:1"]) across which the model will be sharded. Returns: - A `ConfigKeras` object containing the generated `state_rules` for - sharding weights and `output_rules` for handling communications. + ConfigKeras: A configuration object containing the generated sharding + rules for model weights (`state_rules`) and layer outputs + (`output_rules`). """ world_size = len(device_ids) state_rules = {} @@ -284,7 +227,10 @@ def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: world_size=world_size, state_rules=state_rules, output_rules=output_rules, - processed_layers=processed_layers, + processed_layers=processed_layers ) - return ConfigKeras(state_rules=state_rules, output_rules=output_rules) + return ConfigKeras( + state_rules=state_rules, + output_rules=output_rules + ) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index 6845f6000982..8e549e11c74d 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -1,5 +1,4 @@ import os - import pytest os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" @@ -10,10 +9,9 @@ from keras.src import backend from keras.src import testing from keras.src.distribution import distributed_backend + from keras.src.distribution.tensor_parallel.autoconfig import ( analyze_dense_layer_directly, -) -from keras.src.distribution.tensor_parallel.autoconfig import ( get_default_config_keras, ) from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras @@ -21,7 +19,7 @@ @pytest.mark.skipif( backend.backend() != "jax", - reason="Tensor Parallelism autoconfig tests are only for the JAX backend.", + reason="Tensor Parallelism autoconfig tests are only for the JAX backend." 
) class TestAutoConfigKeras(testing.TestCase): def setUp(self): @@ -43,9 +41,7 @@ def _assert_split_keras_equal(self, rule1, rule2): def _assert_rules_equal(self, actual_rules, expected_rules): """Helper to compare two dictionaries of sharding rules.""" - self.assertSetEqual( - set(actual_rules.keys()), set(expected_rules.keys()) - ) + self.assertSetEqual(set(actual_rules.keys()), set(expected_rules.keys())) for key in expected_rules: actual_val = actual_rules[key] expected_val = expected_rules[key] @@ -59,31 +55,26 @@ def test_analyze_dense_layer(self): up_proj_layer = layers.Dense(32) up_proj_layer.build(input_shape=(None, 16)) self.assertEqual( - analyze_dense_layer_directly(up_proj_layer, None, ""), - "up_projection", + analyze_dense_layer_directly(up_proj_layer, None, ""), "up_projection" ) down_proj_layer = layers.Dense(16) down_proj_layer.build(input_shape=(None, 32)) self.assertEqual( - analyze_dense_layer_directly(down_proj_layer, None, ""), - "down_projection", + analyze_dense_layer_directly(down_proj_layer, None, ""), "down_projection" ) generic_layer = layers.Dense(20) generic_layer.build(input_shape=(None, 16)) self.assertEqual( - analyze_dense_layer_directly(generic_layer, None, ""), - "generic_dense", + analyze_dense_layer_directly(generic_layer, None, ""), "generic_dense" ) def test_simple_mlp_sharding(self): """Tests a simple MLP with up and down projection layers.""" inputs = Input(shape=(64,)) x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) - outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)( - x - ) + outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)(x) model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") config = get_default_config_keras(model, self.device_ids) @@ -98,9 +89,6 @@ def test_simple_mlp_sharding(self): r"^simple_mlp.down_projection_layer.kernel$": SplitKeras( self.world_size, 0, "row" ), - r"^simple_mlp.down_projection_layer.bias$": SplitKeras( - self.world_size, -1, "replicated" - ), } expected_output_rules = { r"^simple_mlp.up_projection_layer$": {0: "gather"}, @@ -144,13 +132,11 @@ def test_embedding_sharding(self): config = get_default_config_keras(model, self.device_ids) expected_state_rules = { - r"^embed_model.token_embedding.embeddings$": SplitKeras( - self.world_size, 0, "vocab_parallel" + r"^embed_model.token_embedding\..*embeddings$": SplitKeras( + self.world_size, 1, "column" ) } - expected_output_rules = { - r"^embed_model.token_embedding$": {0: "allreduce"} - } + expected_output_rules = {r"^embed_model.token_embedding$": {0: "no_comm"}} self._assert_rules_equal(config.state_rules, expected_state_rules) self._assert_rules_equal(config.output_rules, expected_output_rules) @@ -190,7 +176,9 @@ def test_normalization_layers_ignored(self): x = layers.Dense(64, name="dense1", use_bias=True)(inputs) x = layers.LayerNormalization(name="layernorm")(x) outputs = layers.Dense(64, name="dense2", use_bias=True)(x) - model = Model(inputs=inputs, outputs=outputs, name="norm_model") + model = Model( + inputs=inputs, outputs=outputs, name="norm_model" + ) config = get_default_config_keras(model, self.device_ids) @@ -207,9 +195,7 @@ def test_normalization_layers_ignored(self): def test_nested_model_sharding(self): """Tests that the traversal logic correctly handles nested models.""" inner_inputs = Input(shape=(32,)) - inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)( - inner_inputs - ) + inner_outputs = layers.Dense(128, name="inner_dense", 
use_bias=True)(inner_inputs) inner_model = Model( inputs=inner_inputs, outputs=inner_outputs, name="inner_block" ) @@ -222,7 +208,7 @@ def test_nested_model_sharding(self): ) config = get_default_config_keras(outer_model, self.device_ids) - + expected_state_rules = { r"^outer_model.inner_block.inner_dense.kernel$": SplitKeras( self.world_size, 1, "column" @@ -233,15 +219,12 @@ def test_nested_model_sharding(self): r"^outer_model.outer_dense.kernel$": SplitKeras( self.world_size, 0, "row" ), - r"^outer_model.outer_dense.bias$": SplitKeras( - self.world_size, -1, "replicated" - ), } expected_output_rules = { r"^outer_model.inner_block.inner_dense$": {0: "gather"}, r"^outer_model.outer_dense$": {0: "allreduce"}, } - + self.maxDiff = None self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) + self._assert_rules_equal(config.output_rules, expected_output_rules) \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index ca7f8e5d5fcc..99fa58592076 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -531,13 +531,11 @@ def get_config(self) -> dict[str, Any]: } ) return config - + def update_step(self, gradient, variable, *args, **kwargs): - if hasattr(self.base_optimizer, "update_step"): + if hasattr(self.base_optimizer, 'update_step'): try: - return self.base_optimizer.update_step( - gradient, variable, *args, **kwargs - ) + return self.base_optimizer.update_step(gradient, variable, *args, **kwargs) except TypeError: return self.base_optimizer.update_step(gradient, variable) try: @@ -611,4 +609,4 @@ def iterations(self): """ if self.base_optimizer.iterations is None: return None - return self.base_optimizer.iterations - 1 + return self.base_optimizer.iterations - 1 \ No newline at end of file diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index 39cce46de72c..ba9438a5d9ab 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -173,8 +173,6 @@ def test_sharding_with_prefixed_variable_names(self): self.assertGreater(len(state_to_param), 0) dense_output_kernel = model.get_layer("dense_output").kernel - optimizer_name = optimizer.base_optimizer.name - kernel_path = dense_output_kernel.path.replace("/", "_") - momentum_path = f"{optimizer_name}/{kernel_path}_momentum" + momentum_path = f"{optimizer.base_optimizer.name}/{dense_output_kernel.path.replace('/', '_')}_momentum" - self.assertIs(state_to_param[momentum_path], dense_output_kernel) + self.assertIs(state_to_param[momentum_path], dense_output_kernel) \ No newline at end of file From 31994dab60d055e96dbebe2ae40297cca3334e00 Mon Sep 17 00:00:00 2001 From: Suhana Date: Mon, 6 Oct 2025 20:54:30 +0530 Subject: [PATCH 9/9] refactoring --- .../tensor_parallel/autoconfig.py | 151 ++++++++++++------ .../tensor_parallel/autoconfig_test.py | 41 +++-- .../tensor_parallel/coordinated_optimizer.py | 10 +- .../coordinated_optimizer_test.py | 6 +- 4 files changed, 139 insertions(+), 69 deletions(-) diff --git a/keras/src/distribution/tensor_parallel/autoconfig.py b/keras/src/distribution/tensor_parallel/autoconfig.py index 9b3a80726b75..32d6734860cc 100644 --- 
a/keras/src/distribution/tensor_parallel/autoconfig.py +++ b/keras/src/distribution/tensor_parallel/autoconfig.py @@ -1,4 +1,7 @@ -from typing import Sequence, Dict, Any, Set +from typing import Any +from typing import Dict +from typing import Sequence +from typing import Set from keras.src.distribution.tensor_parallel.config import ConfigKeras from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras @@ -27,41 +30,45 @@ def analyze_dense_layer_directly(layer, module, prefix: str) -> str: from keras.src import layers if not isinstance(layer, layers.Dense): - return 'generic_dense' + return "generic_dense" input_dim = None output_dim = None - if hasattr(layer, 'kernel') and layer.kernel is not None: + if hasattr(layer, "kernel") and layer.kernel is not None: kernel_shape = layer.kernel.shape if len(kernel_shape) == 2: input_dim = kernel_shape[0] output_dim = kernel_shape[1] if input_dim is None or output_dim is None: - if hasattr(layer, 'units'): + if hasattr(layer, "units"): output_dim = layer.units else: - return 'generic_dense' + return "generic_dense" - if hasattr(layer, 'input_shape') and layer.input_shape and len(layer.input_shape) > 1: + if ( + hasattr(layer, "input_shape") + and layer.input_shape + and len(layer.input_shape) > 1 + ): input_dim = layer.input_shape[-1] else: - return 'generic_dense' + return "generic_dense" if not input_dim or not output_dim: - return 'generic_dense' + return "generic_dense" expansion_threshold = 1.5 is_expansion = output_dim > input_dim * expansion_threshold is_contraction = input_dim > output_dim * expansion_threshold if is_expansion: - return 'up_projection' + return "up_projection" elif is_contraction: - return 'down_projection' + return "down_projection" else: - return 'generic_dense' + return "generic_dense" def _find_and_shard_layers( @@ -94,10 +101,10 @@ def _find_and_shard_layers( prefix (str): The hierarchical name prefix for the `current_layer`. module: The top-level Keras model or layer being configured. world_size (int): The total number of devices for sharding. - state_rules (Dict[str, Any]): A dictionary to be populated with rules for + state_rules (Dict[str, Any]): A dictionary with rules for sharding layer weights (state). Keys are regex patterns matching weight names, values are `SplitKeras` actions. - output_rules (Dict[str, Any]): A dictionary to be populated with rules + output_rules (Dict[str, Any]): A dictionary with rules for handling layer outputs. Keys are regex patterns matching layer names, values describe the communication op (e.g., 'allreduce'). 
processed_layers (Set[int]): A set of `id()`s of layers that have @@ -113,86 +120,137 @@ def _find_and_shard_layers( full_name = f"{prefix}.{name}" if prefix else name if isinstance(current_layer, layers.Dense): - mlp_type = analyze_dense_layer_directly(current_layer, module, full_name) + mlp_type = analyze_dense_layer_directly( + current_layer, module, full_name + ) - if mlp_type == 'up_projection': - state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 1, "column") + if mlp_type == "up_projection": + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = SplitKeras(world_size, 0, "column") + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) output_rules[f"^{full_name}$"] = {0: "gather"} - elif mlp_type == 'down_projection': - state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 0, "row") + elif mlp_type == "down_projection": + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 0, "row" + ) output_rules[f"^{full_name}$"] = {0: "allreduce"} else: - state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 1, "column") + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) if current_layer.use_bias: - state_rules[f"^{full_name}.bias$"] = SplitKeras(world_size, 0, "column") + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, layers.EinsumDense): if "attention_output" in full_name: - state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 0, "row") - if hasattr(current_layer, 'bias') and current_layer.bias is not None: + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 0, "row" + ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): pass output_rules[f"^{full_name}$"] = {0: "allreduce"} else: - state_rules[f"^{full_name}.kernel$"] = SplitKeras(world_size, 1, "column") - if hasattr(current_layer, 'bias') and current_layer.bias is not None: - state_rules[f"^{full_name}.bias$"] = SplitKeras(world_size, 0, "column") + state_rules[f"^{full_name}.kernel$"] = SplitKeras( + world_size, 1, "column" + ) + if ( + hasattr(current_layer, "bias") + and current_layer.bias is not None + ): + state_rules[f"^{full_name}.bias$"] = SplitKeras( + world_size, 0, "column" + ) output_rules[f"^{full_name}$"] = {0: "gather -1"} return elif isinstance(current_layer, (layers.Embedding,)): - if hasattr(current_layer, 'token_embedding') or hasattr(current_layer, 'position_embedding'): + if hasattr(current_layer, "token_embedding") or hasattr( + current_layer, "position_embedding" + ): pass else: weight_name = None - if hasattr(current_layer, 'embeddings'): - weight_name = 'embeddings' - elif hasattr(current_layer, 'position_embeddings'): - weight_name = 'position_embeddings' + if hasattr(current_layer, "embeddings"): + weight_name = "embeddings" + elif hasattr(current_layer, "position_embeddings"): + weight_name = "position_embeddings" if weight_name: - state_rules[f"^{full_name}\\..*{weight_name}$"] = SplitKeras(world_size, 1, "column") + state_rules[f"^{full_name}\\..*{weight_name}$"] = SplitKeras( + world_size, 1, "column" + ) output_rules[f"^{full_name}$"] = {0: "no_comm"} return - elif isinstance(current_layer, (layers.LayerNormalization, layers.BatchNormalization, layers.GroupNormalization)): + elif isinstance( + current_layer, + ( + layers.LayerNormalization, + 
layers.BatchNormalization, + layers.GroupNormalization, + ), + ): return - if hasattr(current_layer, 'layers') and current_layer.layers: + if hasattr(current_layer, "layers") and current_layer.layers: for sub_layer in current_layer.layers: _find_and_shard_layers( - sub_layer, full_name, module, world_size, - state_rules, output_rules, processed_layers + sub_layer, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, ) for attr_name in dir(current_layer): - if attr_name.startswith('__') and attr_name.endswith('__'): + if attr_name.startswith("__") and attr_name.endswith("__"): continue if hasattr(current_layer, attr_name): attr = getattr(current_layer, attr_name) if isinstance(attr, layers.Layer) and attr is not current_layer: _find_and_shard_layers( - attr, full_name, module, world_size, - state_rules, output_rules, processed_layers + attr, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, ) elif isinstance(attr, (list, tuple)): for item in attr: if isinstance(item, layers.Layer): _find_and_shard_layers( - item, full_name, module, world_size, - state_rules, output_rules, processed_layers + item, + full_name, + module, + world_size, + state_rules, + output_rules, + processed_layers, ) + def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: - """Generates a default tensor parallelism sharding configuration for a model. + """Generates a default tensor parallel sharding configuration for a model. - This function serves as the entry point for automatically creating a sharding + This function serves as entry point for automatically creating a sharding plan for a given Keras model or layer. It initializes the rule dictionaries and starts the recursive layer traversal to populate them based on a default set of heuristics for common architectures like Transformers. @@ -227,10 +285,7 @@ def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: world_size=world_size, state_rules=state_rules, output_rules=output_rules, - processed_layers=processed_layers + processed_layers=processed_layers, ) - return ConfigKeras( - state_rules=state_rules, - output_rules=output_rules - ) \ No newline at end of file + return ConfigKeras(state_rules=state_rules, output_rules=output_rules) diff --git a/keras/src/distribution/tensor_parallel/autoconfig_test.py b/keras/src/distribution/tensor_parallel/autoconfig_test.py index 8e549e11c74d..228a2b184569 100644 --- a/keras/src/distribution/tensor_parallel/autoconfig_test.py +++ b/keras/src/distribution/tensor_parallel/autoconfig_test.py @@ -1,4 +1,5 @@ import os + import pytest os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" @@ -9,9 +10,10 @@ from keras.src import backend from keras.src import testing from keras.src.distribution import distributed_backend - from keras.src.distribution.tensor_parallel.autoconfig import ( analyze_dense_layer_directly, +) +from keras.src.distribution.tensor_parallel.autoconfig import ( get_default_config_keras, ) from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras @@ -19,7 +21,7 @@ @pytest.mark.skipif( backend.backend() != "jax", - reason="Tensor Parallelism autoconfig tests are only for the JAX backend." 
+ reason="Tensor Parallelism autoconfig tests are only for the JAX backend.", ) class TestAutoConfigKeras(testing.TestCase): def setUp(self): @@ -41,7 +43,9 @@ def _assert_split_keras_equal(self, rule1, rule2): def _assert_rules_equal(self, actual_rules, expected_rules): """Helper to compare two dictionaries of sharding rules.""" - self.assertSetEqual(set(actual_rules.keys()), set(expected_rules.keys())) + self.assertSetEqual( + set(actual_rules.keys()), set(expected_rules.keys()) + ) for key in expected_rules: actual_val = actual_rules[key] expected_val = expected_rules[key] @@ -55,26 +59,31 @@ def test_analyze_dense_layer(self): up_proj_layer = layers.Dense(32) up_proj_layer.build(input_shape=(None, 16)) self.assertEqual( - analyze_dense_layer_directly(up_proj_layer, None, ""), "up_projection" + analyze_dense_layer_directly(up_proj_layer, None, ""), + "up_projection", ) down_proj_layer = layers.Dense(16) down_proj_layer.build(input_shape=(None, 32)) self.assertEqual( - analyze_dense_layer_directly(down_proj_layer, None, ""), "down_projection" + analyze_dense_layer_directly(down_proj_layer, None, ""), + "down_projection", ) generic_layer = layers.Dense(20) generic_layer.build(input_shape=(None, 16)) self.assertEqual( - analyze_dense_layer_directly(generic_layer, None, ""), "generic_dense" + analyze_dense_layer_directly(generic_layer, None, ""), + "generic_dense", ) def test_simple_mlp_sharding(self): """Tests a simple MLP with up and down projection layers.""" inputs = Input(shape=(64,)) x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) - outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)(x) + outputs = layers.Dense(64, name="down_projection_layer", use_bias=True)( + x + ) model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") config = get_default_config_keras(model, self.device_ids) @@ -136,7 +145,9 @@ def test_embedding_sharding(self): self.world_size, 1, "column" ) } - expected_output_rules = {r"^embed_model.token_embedding$": {0: "no_comm"}} + expected_output_rules = { + r"^embed_model.token_embedding$": {0: "no_comm"} + } self._assert_rules_equal(config.state_rules, expected_state_rules) self._assert_rules_equal(config.output_rules, expected_output_rules) @@ -176,9 +187,7 @@ def test_normalization_layers_ignored(self): x = layers.Dense(64, name="dense1", use_bias=True)(inputs) x = layers.LayerNormalization(name="layernorm")(x) outputs = layers.Dense(64, name="dense2", use_bias=True)(x) - model = Model( - inputs=inputs, outputs=outputs, name="norm_model" - ) + model = Model(inputs=inputs, outputs=outputs, name="norm_model") config = get_default_config_keras(model, self.device_ids) @@ -195,7 +204,9 @@ def test_normalization_layers_ignored(self): def test_nested_model_sharding(self): """Tests that the traversal logic correctly handles nested models.""" inner_inputs = Input(shape=(32,)) - inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)(inner_inputs) + inner_outputs = layers.Dense(128, name="inner_dense", use_bias=True)( + inner_inputs + ) inner_model = Model( inputs=inner_inputs, outputs=inner_outputs, name="inner_block" ) @@ -208,7 +219,7 @@ def test_nested_model_sharding(self): ) config = get_default_config_keras(outer_model, self.device_ids) - + expected_state_rules = { r"^outer_model.inner_block.inner_dense.kernel$": SplitKeras( self.world_size, 1, "column" @@ -224,7 +235,7 @@ def test_nested_model_sharding(self): r"^outer_model.inner_block.inner_dense$": {0: "gather"}, r"^outer_model.outer_dense$": {0: 
"allreduce"}, } - + self.maxDiff = None self._assert_rules_equal(config.state_rules, expected_state_rules) - self._assert_rules_equal(config.output_rules, expected_output_rules) \ No newline at end of file + self._assert_rules_equal(config.output_rules, expected_output_rules) diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py index 99fa58592076..ca7f8e5d5fcc 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer.py @@ -531,11 +531,13 @@ def get_config(self) -> dict[str, Any]: } ) return config - + def update_step(self, gradient, variable, *args, **kwargs): - if hasattr(self.base_optimizer, 'update_step'): + if hasattr(self.base_optimizer, "update_step"): try: - return self.base_optimizer.update_step(gradient, variable, *args, **kwargs) + return self.base_optimizer.update_step( + gradient, variable, *args, **kwargs + ) except TypeError: return self.base_optimizer.update_step(gradient, variable) try: @@ -609,4 +611,4 @@ def iterations(self): """ if self.base_optimizer.iterations is None: return None - return self.base_optimizer.iterations - 1 \ No newline at end of file + return self.base_optimizer.iterations - 1 diff --git a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py index ba9438a5d9ab..39cce46de72c 100644 --- a/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py +++ b/keras/src/distribution/tensor_parallel/coordinated_optimizer_test.py @@ -173,6 +173,8 @@ def test_sharding_with_prefixed_variable_names(self): self.assertGreater(len(state_to_param), 0) dense_output_kernel = model.get_layer("dense_output").kernel - momentum_path = f"{optimizer.base_optimizer.name}/{dense_output_kernel.path.replace('/', '_')}_momentum" + optimizer_name = optimizer.base_optimizer.name + kernel_path = dense_output_kernel.path.replace("/", "_") + momentum_path = f"{optimizer_name}/{kernel_path}_momentum" - self.assertIs(state_to_param[momentum_path], dense_output_kernel) \ No newline at end of file + self.assertIs(state_to_param[momentum_path], dense_output_kernel)