220 changes: 220 additions & 0 deletions keras/src/distribution/tensor_parallel/autoconfig.py
@@ -0,0 +1,220 @@
from typing import Sequence

from keras.src.distribution.tensor_parallel.config import ConfigKeras
from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras


def analyze_dense_layer_directly(layer, module, prefix: str) -> str:
"""Analyzes a Dense layer to classify it for tensor parallelism sharding.

This function inspects the layer's weight shapes to determine if it's an
"up-projection" (expanding feature dimensions), a "down-projection"
(contracting feature dimensions), or a generic layer. This classification
helps in deciding whether to apply column-wise or row-wise parallelism.

Args:
layer: The keras.layers.Dense instance to analyze.
module: The parent Keras model containing the layer.
prefix: The hierarchical name prefix for the layer.

Returns:
A string indicating the layer's classification: 'up_projection',
'down_projection', or 'generic_dense'.
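
    Example (illustrative):
        A `Dense(256)` layer built on a 64-dimensional input has a kernel of
        shape (64, 256); since 256 > 64 * 1.5, it is classified as
        "up_projection". A kernel of shape (256, 64) would instead be
        classified as "down_projection".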
"""
if not isinstance(layer, layers.Dense):
return "generic_dense"

input_dim = None
output_dim = None

if hasattr(layer, "kernel"):
kernel_shape = layer.kernel.shape
if len(kernel_shape) == 2:
input_dim = kernel_shape[0]
output_dim = kernel_shape[1]
else:
if hasattr(layer, "units"):
output_dim = layer.units

if (
hasattr(layer, "input_shape")
and layer.input_shape
and len(layer.input_shape) > 1
):
input_dim = layer.input_shape[-1]

if not input_dim or not output_dim:
return "generic_dense"

expansion_threshold = 1.5
is_expansion = output_dim > input_dim * expansion_threshold
is_contraction = input_dim > output_dim * expansion_threshold

if is_expansion:
return "up_projection"
elif is_contraction:
return "down_projection"
else:
return "generic_dense"


def _traverse_and_shard_layer(
current_layer,
module,
world_size: int,
state_rules: dict,
output_rules: dict,
processed_layers: set,
prefix: str = "",
):
"""Traverses a layer and its sub-layers to apply sharding rules.

This function navigates through the model's layer hierarchy. For each
layer, it identifies its type and applies appropriate sharding logic,
populating the `state_rules` and `output_rules` dictionaries.

Args:
current_layer: The current keras.Layer object to be processed.
module: The top-level Keras Model, used for context analysis.
world_size: The total number of devices for sharding.
state_rules: The dictionary of state sharding rules to populate.
output_rules: The dictionary of output sharding rules to populate.
processed_layers: A set of layer IDs that have already been processed
to avoid redundant computation and infinite loops.
prefix: The hierarchical name prefix from parent layers, used to
construct the full unique name for the current layer.
"""
if id(current_layer) in processed_layers:
return
processed_layers.add(id(current_layer))

name = current_layer.name
full_name = f"{prefix}.{name}" if prefix else name

if isinstance(current_layer, layers.Dense):
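        # Classify the Dense layer first: down-projections become
        # row-parallel (kernel split along the input axis, outputs
        # all-reduced), everything else becomes column-parallel (kernel and
        # bias split along the output axis, no extra communication needed).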
mlp_type = analyze_dense_layer_directly(
current_layer, module, full_name
)

if mlp_type == "down_projection":
state_rules[f"^{full_name}.kernel$"] = SplitKeras(
world_size, 0, "row"
)
output_rules[f"^{full_name}$"] = {0: "allreduce"}

else:
state_rules[f"^{full_name}.kernel$"] = SplitKeras(
world_size, 1, "column"
)
if current_layer.use_bias:
state_rules[f"^{full_name}.bias$"] = SplitKeras(
world_size, 0, "column"
)
output_rules[f"^{full_name}$"] = {0: "no_comm"}
return

elif isinstance(current_layer, layers.EinsumDense):
is_row_parallel = False
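        # Heuristic: if the einsum output spec has fewer axes than the input
        # spec, the layer contracts dimensions and is treated as row-parallel.
        # For example (illustrative), an equation like "btnh,nhd->btd" has a
        # 4-axis input spec and a 3-axis output spec, so its kernel is split
        # along axis 0 and the outputs are all-reduced.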
if "->" in current_layer.equation:
equation_parts = current_layer.equation.split("->")
if len(equation_parts) == 2:
input_spec = equation_parts[0].split(",")[0].strip()
output_spec = equation_parts[1].strip()
if (
input_spec
and output_spec
and len(output_spec) < len(input_spec)
):
is_row_parallel = True

if is_row_parallel:
state_rules[f"^{full_name}.kernel$"] = SplitKeras(
world_size, 0, "row"
)
output_rules[f"^{full_name}$"] = {0: "allreduce"}
else:
state_rules[f"^{full_name}.kernel$"] = SplitKeras(
world_size, 1, "column"
)
if (
hasattr(current_layer, "bias")
and current_layer.bias is not None
):
state_rules[f"^{full_name}.bias$"] = SplitKeras(
world_size, 0, "column"
)
output_rules[f"^{full_name}$"] = {0: "no_comm"}
return

elif isinstance(current_layer, layers.Embedding):
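        # Embedding tables are split along the embedding (output) dimension,
        # so each device holds a slice of every token's vector and the lookup
        # itself requires no cross-device communication.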
weight_name = (
"embeddings" if hasattr(current_layer, "embeddings") else None
)
if weight_name:
            state_rules[rf"^{full_name}\.{weight_name}$"] = SplitKeras(
world_size, 1, "column"
)
output_rules[f"^{full_name}$"] = {0: "no_comm"}
return

elif isinstance(
current_layer,
(
layers.LayerNormalization,
layers.BatchNormalization,
layers.GroupNormalization,
),
):
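        # Normalization layers are cheap to replicate, so they are left
        # unsharded and produce no rules.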
return
else:
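        # Any other container (nested Model, Sequential, etc.) is traversed
        # recursively, extending the name prefix so that rules match the
        # fully qualified weight paths of its sub-layers.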
if hasattr(current_layer, "layers"):
for sub_layer in current_layer.layers:
_traverse_and_shard_layer(
sub_layer,
module,
world_size,
state_rules,
output_rules,
processed_layers,
full_name,
)


def get_default_config_keras(module, device_ids: Sequence[str]) -> ConfigKeras:
"""Generates a smart, recursive sharding configuration for a Keras model.

This function traverses the layers of a given Keras model and applies a
set of heuristics to automatically determine how each layer's weights
and outputs should be sharded for tensor parallelism. It uses a helper
function to perform the recursive traversal.

Args:
module: The Keras Model to generate a sharding configuration for.
device_ids: A sequence of device identifiers, used to determine the
world size (number of devices) for sharding.

Returns:
A ConfigKeras object containing the generated 'state_rules' (for model
parameters) and 'output_rules' (for layer outputs).
"""
world_size = len(device_ids)
state_rules = {}
output_rules = {}
processed_layers = set()

for layer in module.layers:
_traverse_and_shard_layer(
current_layer=layer,
module=module,
world_size=world_size,
state_rules=state_rules,
output_rules=output_rules,
processed_layers=processed_layers,
prefix="",
)

return ConfigKeras(state_rules=state_rules, output_rules=output_rules)
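

# Minimal usage sketch (illustrative; layer names and device ids below are
# placeholders, and two devices are assumed to be available):
#
#     from keras import Input, Model, layers
#
#     inputs = Input(shape=(64,))
#     x = layers.Dense(256, name="up")(inputs)
#     outputs = layers.Dense(64, name="down")(x)
#     model = Model(inputs, outputs)
#
#     config = get_default_config_keras(model, ["device:0", "device:1"])
#     # config.state_rules maps weight-name regexes to SplitKeras actions,
#     # e.g. "^up.kernel$" -> column split and "^down.kernel$" -> row split;
#     # config.output_rules marks "^down$" for an all-reduce.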
151 changes: 151 additions & 0 deletions keras/src/distribution/tensor_parallel/autoconfig_test.py
@@ -0,0 +1,151 @@
import os

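# The XLA flag below makes the host (CPU) platform report two devices so that
# XLA-based backends (e.g. JAX) can run the distribution tests on a single
# machine; it must be set before the Keras backend is imported.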
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

from keras import Input
from keras import Model
from keras import layers
from keras.src import testing
from keras.src.backend.distributed import backend_resolver
from keras.src.distribution.tensor_parallel.autoconfig import (
analyze_dense_layer_directly,
)
from keras.src.distribution.tensor_parallel.autoconfig import (
get_default_config_keras,
)
from keras.src.distribution.tensor_parallel.state_action_keras import SplitKeras


class TestAutoConfigKeras(testing.TestCase):
def setUp(self):
"""Set up the test case and common variables."""
super().setUp()
backend = backend_resolver.get_distributed_backend()
device_info = backend.get_device_info()
self.world_size = device_info["device_count"]
self.device_ids = [f"device:{i}" for i in range(self.world_size)]

self.assertGreater(
self.world_size, 1, "Distribution tests require more than 1 device."
)

def _assert_split_keras_equal(self, rule1, rule2):
"""
Helper to compare two SplitKeras objects by their attributes.
"""
self.assertIsInstance(rule1, SplitKeras)
self.assertIsInstance(rule2, SplitKeras)
self.assertDictEqual(vars(rule1), vars(rule2))

def _assert_rules_equal(self, actual_rules, expected_rules):
"""Helper to compare two dictionaries of sharding rules."""
self.assertSetEqual(
set(actual_rules.keys()), set(expected_rules.keys())
)
for key in expected_rules:
actual_val = actual_rules[key]
expected_val = expected_rules[key]
if isinstance(expected_val, SplitKeras):
self._assert_split_keras_equal(actual_val, expected_val)
else:
self.assertEqual(actual_val, expected_val)

def test_analyze_dense_layer(self):
"""Tests the direct analysis and classification of Dense layers."""
up_proj_layer = layers.Dense(32)
up_proj_layer.build(input_shape=(None, 16))
self.assertEqual(
analyze_dense_layer_directly(up_proj_layer, None, ""),
"up_projection",
)

down_proj_layer = layers.Dense(16)
down_proj_layer.build(input_shape=(None, 32))
self.assertEqual(
analyze_dense_layer_directly(down_proj_layer, None, ""),
"down_projection",
)

def test_simple_mlp_sharding(self):
"""Tests a simple MLP with up and down projection layers."""
inputs = Input(shape=(64,))
x = layers.Dense(256, name="up_projection_layer", use_bias=True)(inputs)
outputs = layers.Dense(
64, name="down_projection_layer", use_bias=False
)(x)
model = Model(inputs=inputs, outputs=outputs, name="simple_mlp")

config = get_default_config_keras(model, self.device_ids)

expected_state_rules = {
r"^up_projection_layer.kernel$": SplitKeras(
self.world_size, 1, "column"
),
r"^up_projection_layer.bias$": SplitKeras(
self.world_size, 0, "column"
),
r"^down_projection_layer.kernel$": SplitKeras(
self.world_size, 0, "row"
),
}
expected_output_rules = {
r"^up_projection_layer$": {0: "no_comm"},
r"^down_projection_layer$": {0: "allreduce"},
}

self._assert_rules_equal(config.state_rules, expected_state_rules)
self._assert_rules_equal(config.output_rules, expected_output_rules)

def test_embedding_sharding(self):
"""Tests an Embedding layer."""
inputs = Input(shape=(10,), dtype="int32")
outputs = layers.Embedding(
input_dim=1000, output_dim=128, name="token_embedding"
)(inputs)
model = Model(inputs=inputs, outputs=outputs, name="embed_model")

config = get_default_config_keras(model, self.device_ids)

expected_state_rules = {
r"^token_embedding\.embeddings$": SplitKeras(
self.world_size, 1, "column"
)
}
expected_output_rules = {r"^token_embedding$": {0: "no_comm"}}

self._assert_rules_equal(config.state_rules, expected_state_rules)
self._assert_rules_equal(config.output_rules, expected_output_rules)

def test_nested_model_sharding(self):
"""Tests that the traversal logic correctly handles nested models."""
inner_inputs = Input(shape=(32,))
inner_outputs = layers.Dense(128, name="inner_dense")(inner_inputs)
inner_model = Model(
inputs=inner_inputs, outputs=inner_outputs, name="inner_block"
)

outer_inputs = Input(shape=(32,))
x = inner_model(outer_inputs)
outer_outputs = layers.Dense(32, name="outer_dense")(x)
outer_model = Model(
inputs=outer_inputs, outputs=outer_outputs, name="outer_model"
)

config = get_default_config_keras(outer_model, self.device_ids)

expected_state_rules = {
r"^inner_block.inner_dense.kernel$": SplitKeras(
self.world_size, 1, "column"
),
r"^inner_block.inner_dense.bias$": SplitKeras(
self.world_size, 0, "column"
),
r"^outer_dense.kernel$": SplitKeras(self.world_size, 0, "row"),
}
expected_output_rules = {
r"^inner_block.inner_dense$": {0: "no_comm"},
r"^outer_dense$": {0: "allreduce"},
}

self._assert_rules_equal(config.state_rules, expected_state_rules)
self._assert_rules_equal(config.output_rules, expected_output_rules)