Conversation

buildwithsuhana
Contributor

This pull request introduces a comprehensive framework for automatic tensor parallelism in Keras, enabling users to train models that are too large to fit on a single accelerator. The core of this feature is a new distribution strategy, AutoTPDistribution, which provides a simple, high-level API to shard an existing Keras model across multiple devices.

Description
This framework is designed to automate the complex process of model sharding and inter-device communication, making large-scale model training more accessible. The implementation is broken down into several key components:

  1. High-Level User API: AutoTPDistribution
    A new distribution strategy, AutoTPDistribution, is introduced in keras/src/distribution/distribution_lib.py. This class serves as the primary user entry point. The workflow is straightforward:
  • A user defines their hardware topology using a DeviceMesh.
  • They instantiate the AutoTPDistribution strategy.
  • They pass their standard Keras model to the distribution.shard() method.
  • This returns a new, sharded model instance ready for distributed training (see the usage sketch below).
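
To make this workflow concrete, here is a minimal usage sketch. The steps mirror the list above; the import path and constructor arguments of AutoTPDistribution are assumptions based on this PR's description, not a finalized API, and a 4-device mesh is assumed.

    import keras
    from keras.distribution import DeviceMesh, list_devices

    # Assumed import location: the class is added in
    # keras/src/distribution/distribution_lib.py in this PR.
    from keras.src.distribution.distribution_lib import AutoTPDistribution

    # 1. Describe the hardware topology: 4 devices along a single "model" axis.
    devices = list_devices()
    mesh = DeviceMesh(shape=(4,), axis_names=("model",), devices=devices[:4])

    # 2. Instantiate the strategy (the keyword name `device_mesh` is assumed).
    distribution = AutoTPDistribution(device_mesh=mesh)

    # 3. Shard a standard Keras model; this returns a TensorParallelKeras model.
    model = keras.Sequential(
        [keras.layers.Dense(1024, activation="relu"), keras.layers.Dense(10)]
    )
    sharded_model = distribution.shard(model)

    # 4. Compile and train as usual; compile() wraps the optimizer in a
    #    TensorParallelOptimizer for synchronized gradient updates.
    sharded_model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
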
  2. Core Sharding Engine: TensorParallelKeras
    A new TensorParallelKeras model class (keras/src/distribution/tensor_parallel/tensor_parallel.py) acts as the core engine. When distribution.shard() is called, this class wraps the original model and performs the following actions:

Auto-Configures Hardware: It discovers and assigns available devices (TPU, GPU, or CPU).

Shards Parameters: It analyzes the model's layers and applies column-parallel or row-parallel sharding strategies to the weights and biases of relevant layers (e.g., Dense).
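
As a hedged illustration of these two strategies (plain NumPy, independent of the actual sharding code in this PR), a Dense kernel of shape (input_dim, units) can be split along either axis:

    import numpy as np

    num_shards = 2
    kernel = np.random.normal(size=(512, 1024))  # Dense kernel: (input_dim, units)

    # Column-parallel: split the output (units) axis; each shard owns a slice
    # of the layer's outputs, and the bias is split the same way.
    column_shards = np.split(kernel, num_shards, axis=1)  # two (512, 512) blocks

    # Row-parallel: split the input axis; each shard consumes a slice of the
    # input features and produces a partial, full-width output.
    row_shards = np.split(kernel, num_shards, axis=0)  # two (256, 1024) blocks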

Builds a Unified Graph: It creates a single, assembled Keras Functional model that internally manages the parallel computation. This clever design encapsulates the communication logic (e.g., AllGather, ReduceScatter) within the model's call graph, simplifying the execution and enabling JIT compilation. Partial outputs from each device shard are correctly combined (e.g., concatenation for column-parallel, summation for row-parallel).
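
The combination step can be checked with toy matrices (a sketch with assumed shapes; the real model performs this with collective operations inside the assembled Functional graph):

    import numpy as np

    x = np.random.normal(size=(8, 512))     # batch of activations
    w = np.random.normal(size=(512, 1024))  # full Dense kernel

    # Column-parallel: each shard computes a slice of the units; the full
    # output is the concatenation of the partial outputs (AllGather-like).
    w0, w1 = np.split(w, 2, axis=1)
    y_column = np.concatenate([x @ w0, x @ w1], axis=-1)

    # Row-parallel: each shard sees a slice of the input features; the full
    # output is the sum of the partial outputs (reduce/AllReduce-like).
    x0, x1 = np.split(x, 2, axis=-1)
    r0, r1 = np.split(w, 2, axis=0)
    y_row = x0 @ r0 + x1 @ r1

    # Both recover the unsharded result x @ w (up to floating-point error).
    assert np.allclose(y_column, x @ w) and np.allclose(y_row, x @ w)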

Coordinates Gradients: It overrides the compile method to wrap the user's optimizer in a TensorParallelOptimizer, which handles the synchronized computation and application of gradients across all shards.
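
The synchronization step for a replicated parameter can be pictured as a toy all-reduce over per-shard gradients (a NumPy sketch; the actual TensorParallelOptimizer operates on Keras variables and backend collectives):

    import numpy as np

    # Gradients computed independently on each shard for the same logical
    # parameter (toy values).
    grads_per_shard = [np.array([0.10, 0.20]), np.array([0.30, 0.40])]

    # All-reduce (sum) so every shard applies an identical update during the
    # optimizer step.
    synced_grad = np.sum(grads_per_shard, axis=0)  # array([0.4, 0.6])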

buildwithsuhana marked this pull request as draft on October 8, 2025 at 07:47

Summary of Changes

Hello @buildwithsuhana, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This PR introduces automatic tensor parallelism to Keras, allowing models to be sharded across multiple devices for training larger models. It provides a high-level AutoTPDistribution API that uses a new TensorParallelKeras class to handle device auto-configuration, parameter sharding (column-parallel, row-parallel), and the construction of a unified functional model that manages inter-device communication and gradient coordination. New JAX backend utilities and extensive tests for sharding logic and communication have also been added.

Highlights

  • Intent: This pull request introduces a comprehensive framework for automatic tensor parallelism in Keras. The primary goal is to enable users to train models that are too large to fit on a single accelerator by automating model sharding and inter-device communication. It provides a high-level API to shard existing Keras models across multiple devices.
  • Key Changes: The core of this feature is the new AutoTPDistribution strategy, which serves as the user's entry point to shard a Keras model. This strategy leverages a new TensorParallelKeras model class that acts as the core sharding engine. This engine automatically configures hardware, applies column-parallel or row-parallel sharding to relevant layer weights (e.g., Dense layers), and builds a unified Keras Functional model that encapsulates parallel computation and communication logic (like AllGather, ReduceScatter). It also coordinates gradients by wrapping the user's optimizer in a TensorParallelOptimizer.
  • Detailed File Changes:
  1. keras/src/backend/jax/distribution_lib.py: Adds new JAX-specific utilities for device information, auto-configuration of tensor parallelism, and device validation. It also includes helper functions like get_best_devices and get_device_memory_info.
  2. keras/src/distribution/distribution_lib.py: Introduces the AutoTPDistribution class, which is the public API for automatic tensor parallelism. This class handles the creation of a DeviceMesh (if not provided) and orchestrates the sharding process by instantiating and configuring the TensorParallelKeras model.
  3. keras/src/distribution/tensor_parallel/tensor_parallel.py (NEW FILE): This new file defines the TensorParallelKeras class, which wraps an original Keras model. It contains the logic for discovering devices, adjusting device lists, applying sharding rules to model parameters, and building a unified functional model (assembled_model) that manages the parallel execution and communication. It also overrides the compile method to integrate with a TensorParallelOptimizer for gradient coordination.
  4. Test Files (keras/src/backend/jax/distribution_lib_test.py, keras/src/distribution/distribution_lib_test.py, keras/src/distribution/tensor_parallel/tensor_parallel_test.py): Significant additions to test the new functionality. This includes environment setup for the JAX backend, tests for AutoTPDistribution's sharding correctness (including uneven splits), and comprehensive tests for TensorParallelKeras covering initialization, non-distributed behavior, forward-pass output shapes, and the correct invocation of communication primitives during gradient slicing and backward passes.
  • Reviewer Activity: No specific reviewer activity was provided in the context.


gemini-code-assist bot left a comment


Code Review

This is a great PR that introduces a powerful and much-needed feature for automatic tensor parallelism in Keras. The high-level API via AutoTPDistribution and the core engine TensorParallelKeras are well-designed, making large-scale model training more accessible. The use of a unified Functional model to encapsulate the parallel logic is particularly clever and should simplify execution and JIT compilation.

I've identified a few areas for improvement, mainly concerning API clarity, robustness, and consistency with Keras design principles. My comments focus on improving docstrings, handling edge cases more gracefully, and ensuring the code is as clear and maintainable as possible. I've also pointed out a few potential bugs and inconsistencies.

Overall, this is a fantastic contribution. Addressing these points will help ensure the new API is robust, intuitive, and easy for users to adopt.

Comment on lines +49 to +61
"""Return all the available devices based on the device type.
Note: in a distributed setting, global devices are returned.
Args:
device_type: string, one of `"cpu"`, `"gpu"` or `"tpu"`.
Defaults to `"gpu"` or `"tpu"` if available when
`device_type` is not provided. Otherwise
will return the `"cpu"` devices.
Return:
List of devices that are available for distribute computation.
"""

critical

The docstring for get_best_devices is incorrect and misleading. It describes a device_type argument, but the function's signature is get_best_devices(count). This will cause significant confusion for users trying to use this function.

This is a critical documentation issue that violates the Keras API design guidelines on documentation clarity (lines 135-137, 144-145).

Please update the docstring to accurately reflect the count parameter and the function's behavior.

    """Return the first `count` available devices.

    Note: in a distributed setting, global devices are returned. This function
    does not guarantee that the returned devices are the 'best' in terms of
    performance, only that they are available.

    Args:
        count: The number of devices to return.

    Return:
        List of device identifier strings.
    """


selected_devices = all_devices[:world_size]

recommended_backend = "jax"

high

The backend parameter is accepted by the function but its value is never used. recommended_backend is always hardcoded to "jax". This is a bug, as the function does not behave as its signature suggests.

To fix this, you should use the provided backend parameter, falling back to "jax" if it's not provided.

Suggested change
recommended_backend = "jax"
recommended_backend = backend or "jax"

Comment on lines +475 to +476
else:
    return "cpu"

high

The else case here silently converts any non-string/non-int device_spec to "cpu". This can hide bugs. For example, if a jax.Device object is passed by mistake, it will be incorrectly treated as a CPU device.

It would be more robust to raise a TypeError for unsupported types to make such errors explicit.

        else:
            raise TypeError(
                f"Expected `device_spec` to be a str or int, but got "
                f"{type(device_spec)}"
            )

Comment on lines +671 to +680
def distribute_dataset(self, dataset):
    if distribution_lib.num_processes() <= 1 or not self.auto_shard_dataset:
        return dataset
    from keras.src.utils.module_utils import tensorflow as tf

    if not tf.available or not isinstance(dataset, tf.data.Dataset):
        raise ValueError(
            "Only `tf.data.Dataset` is supported for auto-sharding."
        )
    return dataset.with_options(tf.data.Options())

high

The implementation of distribute_dataset for AutoTPDistribution seems incorrect or incomplete. Calling dataset.with_options(tf.data.Options()) with empty options is a no-op and does not actually shard the dataset across processes.

This is problematic in a multi-process setting, as each process would get the full dataset, leading to redundant computation and incorrect results.

The implementation should handle dataset sharding, similar to how DataParallel.distribute_dataset does. It needs to consider the number of processes and the data parallelism dimension of the device mesh to correctly shard the data.
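
A hedged sketch of one possible fix, assuming `tf.data.Dataset.shard` and a `process_id()` helper alongside the existing `num_processes()` (the real fix may also need to account for the data-parallel dimension of the mesh and per-replica batch sizes, as noted above):

    def distribute_dataset(self, dataset):
        if distribution_lib.num_processes() <= 1 or not self.auto_shard_dataset:
            return dataset
        from keras.src.utils.module_utils import tensorflow as tf

        if not tf.available or not isinstance(dataset, tf.data.Dataset):
            raise ValueError(
                "Only `tf.data.Dataset` is supported for auto-sharding."
            )
        # Give each process a disjoint slice of the data instead of a no-op.
        return dataset.shard(
            num_shards=distribution_lib.num_processes(),
            index=distribution_lib.process_id(),
        )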

A `TensorParallelKeras` model instance ready for distributed
training.
"""
print(f"INFO: Sharding model `{model.name}` for Tensor Parallelism...")

medium

This method uses print() for logging informational messages (here and on line 649). It's better to use the logging module for consistency with other parts of the codebase (like keras/src/backend/jax/distribution_lib.py) and to allow users more control over log verbosity.

Please replace print() with logging.info(). You will need to ensure logging is imported and a logger is configured at the module level.

        import logging
        logging.info(f"Sharding model `{model.name}` for Tensor Parallelism...")

Comment on lines +143 to +156
params_per_shard = []
for i, shard in enumerate(self.model_shards):
    total_params = 0
    for p in shard.weights:
        if hasattr(p, "num_elements"):
            total_params += p.num_elements()
        elif hasattr(p, "numel"):
            total_params += p.numel()
        elif hasattr(p.shape, "num_elements"):
            total_params += p.shape.num_elements()
        else:
            total_params += np.prod(p.shape)

    params_per_shard.append(int(total_params))

medium

The params_per_shard list is calculated but never used. This is dead code and should be removed to improve maintainability.

Comment on lines +159 to +161
from keras.src.distribution import distributed_backend

self.distributed_backend = distributed_backend

medium

Here, self.distributed_backend (which was a string from the constructor argument) is being reassigned to the imported distributed_backend module. The original string value is stored in self.distributed_backend_name. This is confusing because the type of self.distributed_backend changes, and its name no longer reflects that it's a module.

To improve clarity and avoid type confusion, consider renaming the attribute that holds the module, for example:

from keras.src.distribution import distributed_backend as dist_backend_module

self.distributed_backend_module = dist_backend_module

Comment on lines +169 to +176
def variables(self):
    """Returns a unique list of all variables from all model shards."""
    unique_vars = {}
    for shard in self.model_shards:
        for var in shard.variables:
            if id(var) not in unique_vars:
                unique_vars[id(var)] = var
    return list(unique_vars.values())

medium

This property re-calculates the list of unique variables every time it's accessed. This is inefficient, especially if accessed frequently in a loop. The same applies to the other properties like trainable_variables, weights, etc.

Since the model shards are created once in __init__, you can cache the result of this computation to improve performance.

Suggested change
    @property
    def variables(self):
        """Returns a unique list of all variables from all model shards."""
        if hasattr(self, "_variables"):
            return self._variables
        unique_vars = {}
        for shard in self.model_shards:
            for var in shard.variables:
                if id(var) not in unique_vars:
                    unique_vars[id(var)] = var
        self._variables = list(unique_vars.values())
        return self._variables

Comment on lines +60 to +62
device_type, device_index = device_id.split(":")
device_info["type"] = device_type.upper()
device_info["index"] = int(device_index)

medium

The use of device_id.split(":") is not robust and will raise a ValueError if device_id does not contain a colon (e.g., for a device like "cpu"). This could lead to unexpected crashes if such a device ID is passed.

To make this function more robust, you should handle cases where the device ID might not have an index. A simple way is to check for the presence of ":" before splitting.

Suggested change
device_type, device_index = device_id.split(":")
device_info["type"] = device_type.upper()
device_info["index"] = int(device_index)
parts = device_id.split(":")
device_info["type"] = parts[0].upper()
device_info["index"] = int(parts[1]) if len(parts) > 1 else 0


# 3. ASSERT
# For a dimension of 17 split across 3 devices, the expected
# sharded shapes are (6, 5, 5).

medium

The comment here states that the expected sharded shapes are (6, 5, 5). However, for a dimension of 17 split across 3 devices, the split is (6, 6, 5). The test assertions correctly check for this (6, 6, 5) split. The comment is incorrect and misleading.

Please update the comment to match the actual implementation and assertion.

Suggested change
# sharded shapes are (6, 5, 5).
# sharded shapes are (6, 6, 5).
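
For reference, the uneven split can be reproduced with numpy.array_split, which also assigns the remainder to the first shards (the PR's own splitting helper may differ, but the expected sizes match):

    import numpy as np

    # A dimension of 17 split across 3 devices yields sizes (6, 6, 5).
    shard_sizes = [s.shape[0] for s in np.array_split(np.zeros(17), 3)]
    print(shard_sizes)  # [6, 6, 5]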
