-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Add Autoconfig, Coordinated_Optimizer and Sharding keras implementations for Tensor Parallel Autosharding #21707
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
buildwithsuhana
wants to merge
11
commits into
keras-team:master
Choose a base branch
from
buildwithsuhana:Tensor_parallel_keras_2
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,408
−0
Draft
Changes from 6 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
dd3181e
adding autoconfig and coordinated_optimizer
buildwithsuhana bcae2f6
Reformatting
buildwithsuhana 439643b
Added sharding keras
buildwithsuhana 36edcb9
Merge branch 'keras-team:master' into Tensor_parallel_keras_2
buildwithsuhana b7862d9
Reformatting files
buildwithsuhana e8b51f7
Merge branch 'Tensor_parallel_keras_2' of https://github.com/buildwit…
buildwithsuhana 3383dec
Reformatting according to changes in distributed_backend
buildwithsuhana 5824c66
Reformatting according to changes in distributed_backend
buildwithsuhana 9cf5c7f
Refactoring the code
buildwithsuhana 996a154
refactoring
buildwithsuhana 31994da
refactoring
buildwithsuhana File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
from typing import Sequence | ||
|
||
from keras.src.distribution.tensor_parallel.config import ConfigKeras | ||
from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras | ||
|
||
|
||
def analyze_dense_layer_directly(layer, module, prefix: str) -> str: | ||
from keras.src import layers | ||
|
||
"""Analyzes a Dense layer to classify it for tensor parallelism sharding. | ||
|
||
This function inspects the layer's weight shapes to determine if it's an | ||
"up-projection" (expanding feature dimensions), a "down-projection" | ||
(contracting feature dimensions), or a generic layer. This classification | ||
helps in deciding whether to apply column-wise or row-wise parallelism. | ||
|
||
Args: | ||
layer: The keras.layers.Dense instance to analyze. | ||
module: The parent Keras model containing the layer. | ||
prefix: The hierarchical name prefix for the layer. | ||
|
||
Returns: | ||
A string indicating the layer's classification: 'up_projection', | ||
'down_projection', or 'generic_dense'. | ||
""" | ||
if not isinstance(layer, layers.Dense): | ||
return "generic_dense" | ||
|
||
input_dim = None | ||
output_dim = None | ||
|
||
if hasattr(layer, "kernel"): | ||
kernel_shape = layer.kernel.shape | ||
if len(kernel_shape) == 2: | ||
input_dim = kernel_shape[0] | ||
output_dim = kernel_shape[1] | ||
else: | ||
if hasattr(layer, "units"): | ||
output_dim = layer.units | ||
|
||
if ( | ||
hasattr(layer, "input_shape") | ||
and layer.input_shape | ||
and len(layer.input_shape) > 1 | ||
): | ||
input_dim = layer.input_shape[-1] | ||
|
||
if not input_dim or not output_dim: | ||
return "generic_dense" | ||
|
||
expansion_threshold = 1.5 | ||
is_expansion = output_dim > input_dim * expansion_threshold | ||
is_contraction = input_dim > output_dim * expansion_threshold | ||
|
||
if is_expansion: | ||
return "up_projection" | ||
elif is_contraction: | ||
return "down_projection" | ||
else: | ||
return "generic_dense" | ||
|
||
|
||
def _traverse_and_shard_layer( | ||
current_layer, | ||
module, | ||
world_size: int, | ||
state_rules: dict, | ||
output_rules: dict, | ||
processed_layers: set, | ||
prefix: str = "", | ||
): | ||
from keras.src import layers | ||
|
||
"""Traverses a layer and its sub-layers to apply sharding rules. | ||
|
||
This function navigates through the model's layer hierarchy. For each | ||
layer, it identifies its type and applies appropriate sharding logic, | ||
populating the `state_rules` and `output_rules` dictionaries. | ||
|
||
Args: | ||
current_layer: The current keras.Layer object to be processed. | ||
module: The top-level Keras Model, used for context analysis. | ||
world_size: The total number of devices for sharding. | ||
state_rules: The dictionary of state sharding rules to populate. | ||
output_rules: The dictionary of output sharding rules to populate. | ||
processed_layers: A set of layer IDs that have already been processed | ||
to avoid redundant computation and infinite loops. | ||
prefix: The hierarchical name prefix from parent layers, used to | ||
construct the full unique name for the current layer. | ||
""" | ||
if id(current_layer) in processed_layers: | ||
return | ||
processed_layers.add(id(current_layer)) | ||
|
||
name = current_layer.name | ||
full_name = f"{prefix}.{name}" if prefix else name | ||
|
||
if isinstance(current_layer, layers.Dense): | ||
mlp_type = analyze_dense_layer_directly( | ||
current_layer, module, full_name | ||
) | ||
|
||
if mlp_type == "down_projection": | ||
state_rules[f"^{full_name}.kernel$"] = SplitKeras( | ||
world_size, 0, "row" | ||
) | ||
output_rules[f"^{full_name}$"] = {0: "allreduce"} | ||
|
||
else: | ||
state_rules[f"^{full_name}.kernel$"] = SplitKeras( | ||
world_size, 1, "column" | ||
) | ||
if current_layer.use_bias: | ||
state_rules[f"^{full_name}.bias$"] = SplitKeras( | ||
world_size, 0, "column" | ||
) | ||
output_rules[f"^{full_name}$"] = {0: "no_comm"} | ||
return | ||
|
||
elif isinstance(current_layer, layers.EinsumDense): | ||
is_row_parallel = False | ||
if "->" in current_layer.equation: | ||
equation_parts = current_layer.equation.split("->") | ||
if len(equation_parts) == 2: | ||
input_spec = equation_parts[0].split(",")[0].strip() | ||
output_spec = equation_parts[1].strip() | ||
if ( | ||
input_spec | ||
and output_spec | ||
and len(output_spec) < len(input_spec) | ||
): | ||
is_row_parallel = True | ||
|
||
if is_row_parallel: | ||
state_rules[f"^{full_name}.kernel$"] = SplitKeras( | ||
world_size, 0, "row" | ||
) | ||
output_rules[f"^{full_name}$"] = {0: "allreduce"} | ||
else: | ||
state_rules[f"^{full_name}.kernel$"] = SplitKeras( | ||
world_size, 1, "column" | ||
) | ||
if ( | ||
hasattr(current_layer, "bias") | ||
and current_layer.bias is not None | ||
): | ||
state_rules[f"^{full_name}.bias$"] = SplitKeras( | ||
world_size, 0, "column" | ||
) | ||
output_rules[f"^{full_name}$"] = {0: "no_comm"} | ||
return | ||
|
||
elif isinstance(current_layer, layers.Embedding): | ||
weight_name = ( | ||
"embeddings" if hasattr(current_layer, "embeddings") else None | ||
) | ||
if weight_name: | ||
state_rules[f"^{full_name}\.{weight_name}$"] = SplitKeras( | ||
world_size, 1, "column" | ||
) | ||
output_rules[f"^{full_name}$"] = {0: "no_comm"} | ||
return | ||
|
||
elif isinstance( | ||
current_layer, | ||
( | ||
layers.LayerNormalization, | ||
layers.BatchNormalization, | ||
layers.GroupNormalization, | ||
), | ||
): | ||
return | ||
else: | ||
if hasattr(current_layer, "layers"): | ||
for sub_layer in current_layer.layers: | ||
_traverse_and_shard_layer( | ||
sub_layer, | ||
module, | ||
world_size, | ||
state_rules, | ||
output_rules, | ||
processed_layers, | ||
full_name, | ||
) | ||
|
||
|
||
def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras: | ||
"""Generates a smart, recursive sharding configuration for a Keras model. | ||
|
||
This function traverses the layers of a given Keras model and applies a | ||
set of heuristics to automatically determine how each layer's weights | ||
and outputs should be sharded for tensor parallelism. It uses a helper | ||
function to perform the recursive traversal. | ||
|
||
Args: | ||
module: The Keras Model to generate a sharding configuration for. | ||
device_ids: A sequence of device identifiers, used to determine the | ||
world size (number of devices) for sharding. | ||
|
||
Returns: | ||
A ConfigKeras object containing the generated 'state_rules' (for model | ||
parameters) and 'output_rules' (for layer outputs). | ||
""" | ||
world_size = len(device_ids) | ||
state_rules = {} | ||
output_rules = {} | ||
processed_layers = set() | ||
|
||
for layer in module.layers: | ||
_traverse_and_shard_layer( | ||
current_layer=layer, | ||
module=module, | ||
world_size=world_size, | ||
state_rules=state_rules, | ||
output_rules=output_rules, | ||
processed_layers=processed_layers, | ||
prefix="", | ||
) | ||
|
||
return ConfigKeras(state_rules=state_rules, output_rules=output_rules) |
151 changes: 151 additions & 0 deletions
151
keras/src/distribution/tensor_parallel/autoconfig_test.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
import os | ||
|
||
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2" | ||
|
||
from keras import Input | ||
from keras import Model | ||
from keras import layers | ||
from keras.src import testing | ||
from keras.src.backend.distributed import backend_resolver | ||
from keras.src.distribution.tensor_parallel.autoconfig import ( | ||
analyze_dense_layer_directly, | ||
) | ||
from keras.src.distribution.tensor_parallel.autoconfig import ( | ||
get_default_config_keras, | ||
) | ||
from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras | ||
|
||
|
||
class TestAutoConfigKeras(testing.TestCase): | ||
def setUp(self): | ||
"""Set up the test case and common variables.""" | ||
super().setUp() | ||
backend = backend_resolver.get_distributed_backend() | ||
device_info = backend.get_device_info() | ||
self.world_size = device_info["device_count"] | ||
self.device_ids = [f"device:{i}" for i in range(self.world_size)] | ||
|
||
self.assertGreater( | ||
self.world_size, 1, "Distribution tests require more than 1 device." | ||
) | ||
|
||
def _assert_split_keras_equal(self, rule1, rule2): | ||
""" | ||
Helper to compare two SplitKeras objects by their attributes. | ||
""" | ||
self.assertIsInstance(rule1, SplitKeras) | ||
self.assertIsInstance(rule2, SplitKeras) | ||
self.assertDictEqual(vars(rule1), vars(rule2)) | ||
|
||
def _assert_rules_equal(self, actual_rules, expected_rules): | ||
"""Helper to compare two dictionaries of sharding rules.""" | ||
self.assertSetEqual( | ||
set(actual_rules.keys()), set(expected_rules.keys()) | ||
) | ||
for key in expected_rules: | ||
actual_val = actual_rules[key] | ||
expected_val = expected_rules[key] | ||
if isinstance(expected_val, SplitKeras): | ||
self._assert_split_keras_equal(actual_val, expected_val) | ||
else: | ||
self.assertEqual(actual_val, expected_val) | ||
|
||
def test_analyze_dense_layer(self): | ||
"""Tests the direct analysis and classification of Dense layers.""" | ||
up_proj_layer = layers.Dense(32) | ||
up_proj_layer.build(input_shape=(None, 16)) | ||
self.assertEqual( | ||
analyze_dense_layer_directly(up_proj_layer, None, ""), | ||
"up_projection", | ||
) | ||
|
||
down_proj_layer = layers.Dense(16) | ||
down_proj_layer.build(input_shape=(None, 32)) | ||
self.assertEqual( | ||
analyze_dense_layer_directly(down_proj_layer, None, ""), | ||
"down_projection", | ||
) | ||
|
||
def test_simple_mlp_sharding(self): | ||
"""Tests a simple MLP with up and down projection layers.""" | ||
inputs = Input(shape=(64,)) | ||
x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs) | ||
outputs = layers.Dense( | ||
64, name="down_projection_layer", use_bias=False | ||
)(x) | ||
model = Model(inputs=inputs, outputs=outputs, name="simple_mlp") | ||
|
||
config = get_default_config_keras(model, self.device_ids) | ||
|
||
expected_state_rules = { | ||
r"^up_projection_layer.kernel$": SplitKeras( | ||
self.world_size, 1, "column" | ||
), | ||
r"^up_projection_layer.bias$": SplitKeras( | ||
self.world_size, 0, "column" | ||
), | ||
r"^down_projection_layer.kernel$": SplitKeras( | ||
self.world_size, 0, "row" | ||
), | ||
} | ||
expected_output_rules = { | ||
r"^up_projection_layer$": {0: "no_comm"}, | ||
r"^down_projection_layer$": {0: "allreduce"}, | ||
} | ||
|
||
self._assert_rules_equal(config.state_rules, expected_state_rules) | ||
self._assert_rules_equal(config.output_rules, expected_output_rules) | ||
|
||
def test_embedding_sharding(self): | ||
"""Tests an Embedding layer.""" | ||
inputs = Input(shape=(10,), dtype="int32") | ||
outputs = layers.Embedding( | ||
input_dim=1000, output_dim=128, name="token_embedding" | ||
)(inputs) | ||
model = Model(inputs=inputs, outputs=outputs, name="embed_model") | ||
|
||
config = get_default_config_keras(model, self.device_ids) | ||
|
||
expected_state_rules = { | ||
r"^token_embedding\.embeddings$": SplitKeras( | ||
self.world_size, 1, "column" | ||
) | ||
} | ||
expected_output_rules = {r"^token_embedding$": {0: "no_comm"}} | ||
|
||
self._assert_rules_equal(config.state_rules, expected_state_rules) | ||
self._assert_rules_equal(config.output_rules, expected_output_rules) | ||
|
||
def test_nested_model_sharding(self): | ||
"""Tests that the traversal logic correctly handles nested models.""" | ||
inner_inputs = Input(shape=(32,)) | ||
inner_outputs = layers.Dense(128, name="inner_dense")(inner_inputs) | ||
inner_model = Model( | ||
inputs=inner_inputs, outputs=inner_outputs, name="inner_block" | ||
) | ||
|
||
outer_inputs = Input(shape=(32,)) | ||
x = inner_model(outer_inputs) | ||
outer_outputs = layers.Dense(32, name="outer_dense")(x) | ||
outer_model = Model( | ||
inputs=outer_inputs, outputs=outer_outputs, name="outer_model" | ||
) | ||
|
||
config = get_default_config_keras(outer_model, self.device_ids) | ||
|
||
expected_state_rules = { | ||
r"^inner_block.inner_dense.kernel$": SplitKeras( | ||
self.world_size, 1, "column" | ||
), | ||
r"^inner_block.inner_dense.bias$": SplitKeras( | ||
self.world_size, 0, "column" | ||
), | ||
r"^outer_dense.kernel$": SplitKeras(self.world_size, 0, "row"), | ||
} | ||
expected_output_rules = { | ||
r"^inner_block.inner_dense$": {0: "no_comm"}, | ||
r"^outer_dense$": {0: "allreduce"}, | ||
} | ||
|
||
self._assert_rules_equal(config.state_rules, expected_state_rules) | ||
self._assert_rules_equal(config.output_rules, expected_output_rules) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.