Changes from all commits
38 commits
a27367a
Added tensor parallel for keras (Part 1/3)
buildwithsuhana Sep 26, 2025
488cd8f
Removed unnecessary lines
buildwithsuhana Sep 26, 2025
71ddd1a
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
bc4e4e2
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
d4200b5
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
21f89a2
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
299bd45
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
da625e1
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
c233b8c
Fixing the failing test
buildwithsuhana Sep 26, 2025
7b8d733
Fixing the failing test
buildwithsuhana Sep 26, 2025
f825cd3
Fixing test
buildwithsuhana Sep 26, 2025
3725180
Adding tests for distributed_backends
buildwithsuhana Sep 29, 2025
a6c8a96
Modifications for failing tests
buildwithsuhana Sep 29, 2025
3fabfde
Modified for failing test
buildwithsuhana Sep 29, 2025
b133752
Modified for failing test
buildwithsuhana Sep 29, 2025
83c2e3f
Modified for failing test
buildwithsuhana Sep 29, 2025
3f3be6b
added debuggers
buildwithsuhana Sep 29, 2025
be325ab
removed debuggers
buildwithsuhana Sep 29, 2025
e1282ac
Merge branch 'keras-team:master' into Tensor_parallel_keras
buildwithsuhana Sep 29, 2025
fc11aaa
Removed the tensorflow, numpy and torch backends
buildwithsuhana Sep 30, 2025
ef6e2a0
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Sep 30, 2025
bea6ffa
Refactoring the code
buildwithsuhana Sep 30, 2025
4e00245
Refactoring the code
buildwithsuhana Sep 30, 2025
2f973b0
refactoring
buildwithsuhana Sep 30, 2025
bdb2b84
Adding necessary docstrings
buildwithsuhana Sep 30, 2025
d77fa71
Merge branch 'keras-team:master' into Tensor_parallel_keras
buildwithsuhana Oct 1, 2025
b9990b0
Removing redundancies
buildwithsuhana Oct 3, 2025
0aeee6f
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Oct 3, 2025
f784956
Modifying tests
buildwithsuhana Oct 3, 2025
8895a78
Reformatting
buildwithsuhana Oct 3, 2025
fe97f3b
Reformatting the code
buildwithsuhana Oct 3, 2025
77f01aa
Fixing failing tests
buildwithsuhana Oct 3, 2025
7080328
fixes
buildwithsuhana Oct 3, 2025
af711fd
Fixing tests
buildwithsuhana Oct 3, 2025
97dde17
formatting
buildwithsuhana Oct 3, 2025
f322a97
fixing test
buildwithsuhana Oct 3, 2025
5269ac9
fixing test
buildwithsuhana Oct 3, 2025
b9f36e9
Removing redundant lines
buildwithsuhana Oct 6, 2025
15 changes: 15 additions & 0 deletions keras/api/_tf_keras/keras/distribution/__init__.py
@@ -4,6 +4,21 @@
since your modifications would be overwritten.
"""

from keras.src.distribution.distributed_backend import (
apply_gradients as apply_gradients,
)
from keras.src.distribution.distributed_backend import (
create_optimizer as create_optimizer,
)
from keras.src.distribution.distributed_backend import (
get_communication_ops as get_communication_ops,
)
from keras.src.distribution.distributed_backend import (
get_device_info as get_device_info,
)
from keras.src.distribution.distributed_backend import (
is_multi_device_capable as is_multi_device_capable,
)
from keras.src.distribution.distribution_lib import DataParallel as DataParallel
from keras.src.distribution.distribution_lib import DeviceMesh as DeviceMesh
from keras.src.distribution.distribution_lib import LayoutMap as LayoutMap
15 changes: 15 additions & 0 deletions keras/api/distribution/__init__.py
@@ -4,6 +4,21 @@
since your modifications would be overwritten.
"""

from keras.src.distribution.distributed_backend import (
apply_gradients as apply_gradients,
)
from keras.src.distribution.distributed_backend import (
create_optimizer as create_optimizer,
)
from keras.src.distribution.distributed_backend import (
get_communication_ops as get_communication_ops,
)
from keras.src.distribution.distributed_backend import (
get_device_info as get_device_info,
)
from keras.src.distribution.distributed_backend import (
is_multi_device_capable as is_multi_device_capable,
)
from keras.src.distribution.distribution_lib import DataParallel as DataParallel
from keras.src.distribution.distribution_lib import DeviceMesh as DeviceMesh
from keras.src.distribution.distribution_lib import LayoutMap as LayoutMap
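Taken together, these re-exports surface the new helpers as `keras.distribution.*`. A minimal usage sketch, assuming the JAX backend is active and that `keras.src.distribution.distributed_backend` (not shown in this diff) forwards to the JAX implementation further below; the printed values are illustrative:

import keras

# Query the devices visible to the active backend.
info = keras.distribution.get_device_info()
print(info["backend"], info["device_count"])

# True only when more than one local device is available.
print(keras.distribution.is_multi_device_capable())

# Dictionary of collective ops: all_reduce, all_gather, broadcast, scatter.
ops = keras.distribution.get_communication_ops()
print(sorted(ops.keys()))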
5 changes: 5 additions & 0 deletions keras/src/backend/__init__.py
@@ -37,24 +37,29 @@
if backend() == "tensorflow":
from keras.src.backend.tensorflow import * # noqa: F403
from keras.src.backend.tensorflow.core import Variable as BackendVariable

distributed_backend = None
elif backend() == "jax":
from keras.src.backend.jax import * # noqa: F403
from keras.src.backend.jax.core import Variable as BackendVariable
elif backend() == "torch":
from keras.src.backend.torch import * # noqa: F403
from keras.src.backend.torch.core import Variable as BackendVariable

distributed_backend = None
distribution_lib = None
elif backend() == "numpy":
from keras.src.backend.numpy import * # noqa: F403
from keras.src.backend.numpy.core import Variable as BackendVariable

distribution_lib = None
distributed_backend = None
elif backend() == "openvino":
from keras.src.backend.openvino import * # noqa: F403
from keras.src.backend.openvino.core import Variable as BackendVariable

distribution_lib = None
distributed_backend = None
else:
raise ValueError(f"Unable to import backend : {backend()}")
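With these assignments, only the JAX branch leaves `distributed_backend` bound to a module; every other backend sets it to `None`. A hypothetical guard a caller could use (the helper name and error message are illustrative, not part of this PR):

from keras.src import backend

def require_distributed_backend():
    # Only the JAX backend ships keras.src.backend.jax.distributed_backend in
    # this PR; the other branches above set distributed_backend to None.
    if backend.distributed_backend is None:
        raise RuntimeError(
            f"distributed_backend is unavailable for the "
            f"'{backend.backend()}' backend; use the JAX backend instead."
        )
    return backend.distributed_backend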

1 change: 1 addition & 0 deletions keras/src/backend/jax/__init__.py
@@ -1,5 +1,6 @@
from keras.src.backend.config import is_nnx_enabled
from keras.src.backend.jax import core
from keras.src.backend.jax import distributed_backend
from keras.src.backend.jax import distribution_lib
from keras.src.backend.jax import image
from keras.src.backend.jax import linalg
248 changes: 248 additions & 0 deletions keras/src/backend/jax/distributed_backend.py
@@ -0,0 +1,248 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Literal

import jax
import jax.lax as lax
import jax.numpy as jnp

import keras


def compute_gradients(
_loss: jnp.ndarray, trainable_vars: List[jnp.ndarray]
Comment on lines +14 to +15 (Collaborator):
This signature doesn't work for JAX. You cannot take the gradient of a tensor. You can only transform a function so that you can take its gradient.

) -> List[jnp.ndarray]:
"""Computes gradients of the loss with respect to trainable variables.

Note: This is a placeholder implementation that returns zeros. A real
implementation would use `jax.grad`.
Comment on lines +19 to +20 (Collaborator):
So why are we doing this if it's not a real implementation?


Args:
_loss (jnp.ndarray): The loss value for which to compute gradients.
trainable_vars (List[jnp.ndarray]): A list of variables to compute
gradients with respect to.

Returns:
List[jnp.ndarray]: A list of gradients corresponding to the
trainable variables.
"""
return [jnp.zeros_like(var) for var in trainable_vars]
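As the reviewer notes, JAX differentiates functions rather than tensors, so a real implementation would wrap the loss computation in `jax.grad` or `jax.value_and_grad`. A minimal sketch of that pattern (the loss function, parameters, and data here are illustrative, not part of this PR):

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Illustrative least-squares loss over a linear model.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 3)), jnp.ones((8, 1))

# The function is transformed; gradients come back as a pytree with the
# same structure as `params`.
loss, grads = jax.value_and_grad(loss_fn)(params, x, y)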


def apply_gradients(
gradients: List[jnp.ndarray],
trainable_vars: List[jnp.ndarray],
learning_rate: float = 0.001,
) -> List[jnp.ndarray]:
"""Applies gradients and returns the updated variables."""
updated_vars = []
for grad, var in zip(gradients, trainable_vars):
if grad is not None:
new_var = var - (learning_rate * grad)
updated_vars.append(new_var)
else:
updated_vars.append(var)
return updated_vars
Comment on lines +34 to +47 (Collaborator):
This is an inline implementation of SGD. Why is this needed?
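For comparison, the update performed by this inline SGD is what a stock Keras optimizer already provides. A sketch of the usual route, assuming the standard Keras 3 optimizer API (the toy variable and gradient are illustrative):

import keras

optimizer = keras.optimizers.SGD(learning_rate=0.001)

# Toy variable and a matching gradient, just to show the call shape.
trainable_vars = [keras.Variable(keras.ops.ones((2, 2)))]
grads = [keras.ops.full((2, 2), 0.5)]

# The optimizer owns the update rule, so no hand-written SGD is needed.
optimizer.apply_gradients(zip(grads, trainable_vars))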



def create_optimizer(optimizer_class: str, **kwargs) -> Dict[str, Any]:
Comment (Collaborator):
Why do we need this? Does this mean that TensorParallel won't work with Keras optimizers?

"""Creates a configuration dictionary for an optimizer.

This function returns a dictionary containing the optimizer's configuration,
removing the need for a specific optimizer library like Optax.

Args:
optimizer_class (str): The name of the optimizer to create (e.g.,
`"adam"`, `"sgd"`).
**kwargs: Keyword arguments to be passed to the optimizer's
constructor (e.g., `learning_rate`).

Returns:
Dict[str, Any]: A dictionary representing the optimizer configuration.
"""
config = kwargs.copy()
config["name"] = optimizer_class.lower()
config.setdefault("learning_rate", 0.001)
return config
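Because only a plain dict is returned, something downstream still has to materialize an optimizer from it. One hedged way to bridge the config to a Keras optimizer (the `optimizer_from_config` helper is hypothetical, not part of this PR):

import keras

from keras.src.backend.jax.distributed_backend import create_optimizer

def optimizer_from_config(config):
    # Hypothetical adapter for the dict produced by create_optimizer above.
    config = dict(config)
    classes = {"adam": keras.optimizers.Adam, "sgd": keras.optimizers.SGD}
    return classes[config.pop("name")](**config)

opt = optimizer_from_config(create_optimizer("adam", learning_rate=3e-4))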


def get_device_info() -> Dict[str, Any]:
"""Retrieves information about the available JAX devices.

Returns:
Dict[str, Any]: A dictionary containing the backend name, a list of
available device strings, and the total device count.
"""
available_devices = jax.devices()
return {
"backend": "jax",
"devices": [str(d) for d in available_devices],
"device_count": len(available_devices),
}


def is_multi_device_capable() -> bool:
"""Checks if more than one JAX device is available.

Returns:
bool: `True` if JAX reports more than one local device, `False`
otherwise.
"""
return jax.local_device_count() > 1
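On a single-host CPU setup this returns `False` by default; for testing, XLA can be asked to expose several virtual CPU devices so the multi-device paths become reachable. A sketch, assuming the JAX backend and a CPU-only environment, with the flag set before JAX first initializes:

import os

# Presents 8 virtual CPU devices to JAX (overwrites any existing XLA_FLAGS).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from keras.src.backend.jax import distributed_backend

print(distributed_backend.get_device_info()["device_count"])  # 8
print(distributed_backend.is_multi_device_capable())  # True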


def get_communication_ops() -> Dict[str, Callable]:
"""Provides a dictionary of JAX collective communication operations.

These operations are designed to work within a `jax.pmap` context for
multi-device computation. If not in a `pmap` context, they generally
behave as no-ops or simulate the operation on the single local device.

Returns:
Dict[str, Callable]: A dictionary mapping operation names to their
JAX implementations.
"""

def _is_in_pmap(axis_name: str = "data") -> bool:
"""Checks if currently inside a pmap by probing the axis name."""
try:
lax.axis_index(axis_name)
return True
except NameError:
return False

def all_reduce(
x: jnp.ndarray,
op: Literal["sum", "mean"] = "sum",
axis_name: str = "data",
) -> jnp.ndarray:
"""Reduces a tensor across all devices in a `pmap`.

Args:
x (jnp.ndarray): The tensor to reduce.
op (Literal["sum", "mean"], optional): The reduction operation.
Defaults to "sum".
axis_name (str, optional): The name of the `pmap` axis.
Defaults to "data".

Returns:
jnp.ndarray: The reduced tensor. Returns the input tensor `x` if
not in a `pmap` context.
"""
if _is_in_pmap(axis_name):
reduce_ops = {
"sum": lax.psum,
"mean": lax.pmean,
}
reduce_fn = reduce_ops.get(op)

if reduce_fn is None:
raise ValueError(f"Unsupported all_reduce op: {op}")
return reduce_fn(x, axis_name=axis_name)
else:
world_size = jax.local_device_count()
if world_size <= 1:
return x
if op == "sum":
return keras.ops.multiply(x, float(world_size))
elif op == "mean":
return x
else:
raise ValueError(f"Unsupported all_reduce op: {op}")

def all_gather(
x: jnp.ndarray, axis: int = 0, axis_name: str = "data"
) -> jnp.ndarray:
"""Gathers tensors from all devices and concatenates them.

Args:
x (jnp.ndarray): The local tensor to gather.
axis (int, optional): The axis along which to concatenate the
gathered tensors. Defaults to 0.
axis_name (str, optional): The name of the `pmap` axis.
Defaults to "data".

Returns:
jnp.ndarray: The concatenated tensor from all devices.
"""
if _is_in_pmap(axis_name):
return lax.all_gather(x, axis_name=axis_name, axis=axis)
else:
world_size = jax.local_device_count()
if world_size <= 1:
return x
return keras.ops.concatenate([x] * world_size, axis=axis)

def broadcast(
x: jnp.ndarray, root: int = 0, axis_name: str = "data"
) -> jnp.ndarray:
"""Broadcasts a tensor from a root device to all other devices.

Args:
x (jnp.ndarray): The tensor to broadcast. On the root device, this
is the tensor to be sent.
root (int, optional): The rank of the device from which to
broadcast. Defaults to 0.
axis_name (str, optional): The name of the `pmap` axis.
Defaults to "data".

Returns:
jnp.ndarray: The tensor received from the root device.
"""
if _is_in_pmap(axis_name):
return lax.all_gather(x, axis_name=axis_name, axis=0)[root]
else:
return x

def scatter(
x: jnp.ndarray,
root: int = 0,
axis: int = 0,
axis_name: str = "data",
) -> jnp.ndarray:
"""Scatters a tensor from a root device to all devices.

Args:
x (jnp.ndarray): The tensor on the root device to be scattered.
root (int, optional): The rank of the device that holds the full
tensor. Defaults to 0.
axis (int, optional): The axis along which to split the tensor.
Defaults to 0.
axis_name (str, optional): The name of the `pmap` axis.
Defaults to "data".

Returns:
jnp.ndarray: The chunk of the tensor for the local device.
"""
if _is_in_pmap(axis_name):
full_tensor = lax.all_gather(x, axis_name=axis_name, axis=0)[root]
device_id = lax.axis_index(axis_name=axis_name)
num_devices = lax.psum(1, axis_name=axis_name)
chunk_size = full_tensor.shape[axis] // num_devices
start_index = device_id * chunk_size
return lax.dynamic_slice_in_dim(
operand=full_tensor,
start_index=start_index,
slice_size=chunk_size,
axis=axis,
)
else:
world_size = jax.local_device_count()
if world_size <= 1:
return x
if x.shape[axis] % world_size != 0:
raise ValueError(
f"Tensor with shape {x.shape} cannot be scattered along "
f"axis {axis} across {world_size} devices."
)
chunks = keras.ops.split(x, world_size, axis=axis)
return chunks[0]

return {
"all_reduce": all_reduce,
"all_gather": all_gather,
"broadcast": broadcast,
"scatter": scatter,
}
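A minimal sketch of exercising the returned ops inside `jax.pmap`, where the `"data"` axis is bound and `all_reduce` lowers to `lax.pmean` (assuming this module is importable as shown in the diff; with a single device the mapped axis simply has size 1):

import functools

import jax
import jax.numpy as jnp

from keras.src.backend.jax.distributed_backend import get_communication_ops

ops = get_communication_ops()
n = jax.local_device_count()

@functools.partial(jax.pmap, axis_name="data")
def step(x):
    # Inside pmap the "data" axis is bound, so this call becomes lax.pmean.
    return ops["all_reduce"](x, op="mean", axis_name="data")

# One scalar per device along the leading (mapped) axis.
print(step(jnp.arange(n, dtype=jnp.float32)))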