Commits (38)
a27367a
Added tensor parallel for keras (Part 1/3)
buildwithsuhana Sep 26, 2025
488cd8f
Removed unnecessary lines
buildwithsuhana Sep 26, 2025
71ddd1a
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
bc4e4e2
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
d4200b5
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
21f89a2
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
299bd45
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
da625e1
Fixes suggested by Gemini
buildwithsuhana Sep 26, 2025
c233b8c
Fixing the failing test
buildwithsuhana Sep 26, 2025
7b8d733
Fixing the failing test
buildwithsuhana Sep 26, 2025
f825cd3
Fixing test
buildwithsuhana Sep 26, 2025
3725180
Adding tests for distributed_backends
buildwithsuhana Sep 29, 2025
a6c8a96
Modifications for failing tests
buildwithsuhana Sep 29, 2025
3fabfde
Modified for failing test
buildwithsuhana Sep 29, 2025
b133752
Modified for failing test
buildwithsuhana Sep 29, 2025
83c2e3f
Modified for failing test
buildwithsuhana Sep 29, 2025
3f3be6b
added debuggers
buildwithsuhana Sep 29, 2025
be325ab
removed debuggers
buildwithsuhana Sep 29, 2025
e1282ac
Merge branch 'keras-team:master' into Tensor_parallel_keras
buildwithsuhana Sep 29, 2025
fc11aaa
Removed the tensorflow, numpy and torch backends
buildwithsuhana Sep 30, 2025
ef6e2a0
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Sep 30, 2025
bea6ffa
Refactoring the code
buildwithsuhana Sep 30, 2025
4e00245
Refactoring the code
buildwithsuhana Sep 30, 2025
2f973b0
refactoring
buildwithsuhana Sep 30, 2025
bdb2b84
Adding necessary docstrings
buildwithsuhana Sep 30, 2025
d77fa71
Merge branch 'keras-team:master' into Tensor_parallel_keras
buildwithsuhana Oct 1, 2025
b9990b0
Removing redundancies
buildwithsuhana Oct 3, 2025
0aeee6f
Merge branch 'Tensor_parallel_keras' of https://github.com/buildwiths…
buildwithsuhana Oct 3, 2025
f784956
Modifying tests
buildwithsuhana Oct 3, 2025
8895a78
Reformatting
buildwithsuhana Oct 3, 2025
fe97f3b
Reformatting the code
buildwithsuhana Oct 3, 2025
77f01aa
Fixing failing tests
buildwithsuhana Oct 3, 2025
7080328
fixes
buildwithsuhana Oct 3, 2025
af711fd
Fixing tests
buildwithsuhana Oct 3, 2025
97dde17
formatting
buildwithsuhana Oct 3, 2025
f322a97
fixing test
buildwithsuhana Oct 3, 2025
5269ac9
fixing test
buildwithsuhana Oct 3, 2025
b9f36e9
Removing redundant lines
buildwithsuhana Oct 6, 2025
4 changes: 4 additions & 0 deletions keras/src/backend/distributed/__init__.py
@@ -0,0 +1,4 @@
from .base import BaseDistributedBackend
from .factory import get_distributed_backend

__all__ = ["get_distributed_backend", "BaseDistributedBackend"]
57 changes: 57 additions & 0 deletions keras/src/backend/distributed/base.py
@@ -0,0 +1,57 @@
from abc import ABC
from abc import abstractmethod
from typing import Any
from typing import List


class BaseDistributedBackend(ABC):
"""
Abstract Base Class for a distributed backend.
"""

@abstractmethod
def get_tensor_lib(self):
"""Get the appropriate tensor library for the backend."""
raise NotImplementedError

@abstractmethod
def convert_to_backend_tensor(self, tensor: Any) -> Any:
"""Convert a tensor to the appropriate backend format."""
raise NotImplementedError

@abstractmethod
def compute_gradients(
self, loss: Any, trainable_vars: List[Any]
) -> List[Any]:
"""Compute gradients using the backend's automatic differentiation."""
raise NotImplementedError

@abstractmethod
def apply_gradients(
self,
gradients: List[Any],
trainable_vars: List[Any],
learning_rate: float = 0.001,
) -> None:
"""Apply gradients to trainable variables."""
raise NotImplementedError

@abstractmethod
def create_optimizer(self, optimizer_class: str, **kwargs):
"""Create an optimizer for the backend."""
raise NotImplementedError

@abstractmethod
def get_device_info(self) -> dict:
"""Get information about available devices."""
raise NotImplementedError

@abstractmethod
def is_multi_device_capable(self) -> bool:
"""Check if the backend supports multi-device operations."""
raise NotImplementedError

@abstractmethod
def get_communication_ops(self) -> dict:
"""Get collective communication operations for the backend."""
raise NotImplementedError
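
For context, a minimal sketch of what a concrete subclass of `BaseDistributedBackend` has to provide (illustrative only; `SingleDeviceBackend` and its trivial method bodies are hypothetical and not part of this PR):

import numpy as np

from keras.src.backend.distributed.base import BaseDistributedBackend


class SingleDeviceBackend(BaseDistributedBackend):
    """Hypothetical single-device backend, used only to illustrate the ABC."""

    def get_tensor_lib(self):
        return np

    def convert_to_backend_tensor(self, tensor):
        return np.asarray(tensor)

    def compute_gradients(self, loss, trainable_vars):
        # No autodiff available here; return zero gradients as a placeholder.
        return [np.zeros_like(v) for v in trainable_vars]

    def apply_gradients(self, gradients, trainable_vars, learning_rate=0.001):
        for grad, var in zip(gradients, trainable_vars):
            if grad is not None:
                var -= learning_rate * grad  # plain in-place SGD step

    def create_optimizer(self, optimizer_class, **kwargs):
        return None  # no optimizer objects in this sketch

    def get_device_info(self):
        return {"backend": "numpy", "devices": ["cpu"], "device_count": 1}

    def is_multi_device_capable(self):
        return False

    def get_communication_ops(self):
        # On a single device every collective is a no-op / identity.
        def identity(x, **kwargs):
            return x

        return {
            "all_reduce": identity,
            "all_gather": identity,
            "broadcast": identity,
            "scatter": identity,
        }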
79 changes: 79 additions & 0 deletions keras/src/backend/distributed/factory.py
@@ -0,0 +1,79 @@
import logging

from keras.src.backend.distributed.base import BaseDistributedBackend

logger = logging.getLogger(__name__)


def get_distributed_backend(
backend_name: str = "auto",
) -> BaseDistributedBackend:
"""
Factory to get the best available or a specific distributed backend.
"""
if backend_name == "auto":
try:
from keras.src.backend.jax.distributed_backend import (
JaxDistributedBackend,
)

logger.info("Auto-detected JAX for distributed backend.")
return JaxDistributedBackend()
except ImportError:
try:
from keras.src.backend.tensorflow.distributed_backend import (
TensorflowDistributedBackend,
)

logger.info("Auto-detected TensorFlow for distributed backend.")
return TensorflowDistributedBackend()
except ImportError:
try:
from keras.src.backend.torch.distributed_backend import (
TorchDistributedBackend,
)

logger.info(
"Auto-detected PyTorch for distributed backend."
)
return TorchDistributedBackend()
except ImportError:
error_msg = (
"Could not automatically detect a distributed backend "
"(JAX, TensorFlow, or PyTorch). Please install them "
"or explicitly specify a backend."
)
logger.error(error_msg)
raise ImportError(error_msg)

elif backend_name == "jax":
from keras.src.backend.jax.distributed_backend import (
JaxDistributedBackend,
)

return JaxDistributedBackend()
elif backend_name == "tensorflow":
from keras.src.backend.tensorflow.distributed_backend import (
TensorflowDistributedBackend,
)

return TensorflowDistributedBackend()
elif backend_name == "torch":
from keras.src.backend.torch.distributed_backend import (
TorchDistributedBackend,
)

return TorchDistributedBackend()
elif backend_name == "numpy":
from keras.src.backend.numpy.distributed_backend import (
NumpyDistributedBackend,
)

logger.warning(
"Using explicitly requested NumPy distributed backend. "
"This backend is for simulation and does not support "
"multi-device computation."
)
return NumpyDistributedBackend()
else:
raise ValueError(f"Unknown distributed backend: {backend_name}")
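
A minimal usage sketch of the factory (illustrative only; it assumes at least one of the supported frameworks is installed):

from keras.src.backend.distributed import get_distributed_backend

# "auto" tries JAX, then TensorFlow, then PyTorch; a backend can also be
# requested explicitly with "jax", "tensorflow", "torch", or "numpy".
backend = get_distributed_backend("auto")

info = backend.get_device_info()
print(info["backend"], info["device_count"])  # e.g. "jax" 1

ops = backend.get_communication_ops()
# ops["all_reduce"], ops["all_gather"], ops["broadcast"], and ops["scatter"]
# are either real collectives or single-process simulations, depending on
# how many devices are visible.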
172 changes: 172 additions & 0 deletions keras/src/backend/jax/distributed_backend.py
@@ -0,0 +1,172 @@
import logging
from typing import Any
from typing import List

import jax
import jax.lax as lax
import jax.numpy as jnp
import optax

import keras
from keras.src.backend.distributed.base import BaseDistributedBackend

logger = logging.getLogger(__name__)


class JaxDistributedBackend(BaseDistributedBackend):
"""JAX-specific implementation of distributed operations."""

def get_tensor_lib(self):
return jnp

def convert_to_backend_tensor(self, tensor: Any) -> Any:
if isinstance(tensor, jax.Array):
return tensor
return jnp.array(tensor)

def compute_gradients(
self, loss: Any, trainable_vars: List[Any]
) -> List[Any]:
"""
The JAX backend cannot compute gradients from a pre-computed loss value.

This method returns zero gradients as a fallback. For JAX, gradient
computation must be done via `jax.grad` on a function that computes
the loss from the parameters, which requires a different architecture.
"""
logger.warning(
"JAX backend `compute_gradients` is a fallback and returns "
"zero gradients. A functional `jax.grad` approach should be used "
"for training."
)
return [jnp.zeros_like(var) for var in trainable_vars]

def apply_gradients(
self,
gradients: List[Any],
trainable_vars: List[Any],
learning_rate: float = 0.001,
) -> None:
for grad, var in zip(gradients, trainable_vars):
if grad is not None:
new_value = var - (learning_rate * grad)
if hasattr(var, "assign"):
var.assign(new_value)
else:
logger.warning(
"Applying gradients to a standard JAX array has no "
"effect as JAX arrays are immutable. This operation "
"only works for mutable objects with an `.assign()` "
"method."
)

def create_optimizer(self, optimizer_class: str, **kwargs):
if optimizer_class.lower() == "adam":
return optax.adam(**kwargs)
elif optimizer_class.lower() == "sgd":
return optax.sgd(**kwargs)
else:
kwargs.setdefault("learning_rate", 0.001)
return optax.adam(**kwargs)

def get_device_info(self) -> dict:
info = {"backend": "jax", "devices": [], "device_count": 0}
try:
info["devices"] = [str(d) for d in jax.devices()]
info["device_count"] = jax.local_device_count()
except Exception as e:
logger.warning(f"Could not get device info for JAX: {e}")
info["devices"] = ["cpu"]
info["device_count"] = 1
return info

def is_multi_device_capable(self) -> bool:
return self.get_device_info()["device_count"] > 1

def get_communication_ops(self) -> dict:
try:
if not self.is_multi_device_capable():
raise RuntimeError("JAX is not running on multiple devices.")

logger.info("Using real JAX collective communication ops.")

def all_reduce_jax(x, op="sum", axis_name="data"):
if op == "sum":
return lax.psum(x, axis_name=axis_name)
elif op == "mean":
return lax.pmean(x, axis_name=axis_name)
raise ValueError(f"Unsupported all_reduce op: {op}")

def all_gather_jax(x, axis=0, axis_name="model"):
return lax.all_gather(x, axis_name=axis_name, axis=axis)

def broadcast_jax(x, root=0, axis_name="data"):
return lax.all_gather(x, axis_name=axis_name, axis=0)[root]

def scatter_jax(x, root=0):
logger.warning(
"Scatter is not a native op in JAX pmap; returning the "
"input tensor as a fallback."
)
return x

return {
"all_reduce": all_reduce_jax,
"all_gather": all_gather_jax,
"broadcast": broadcast_jax,
"scatter": scatter_jax,
}
except (ImportError, RuntimeError) as e:
logger.warning(
"JAX collective ops not available or multiple devices not "
f"configured: {e}. Using SIMULATED ops."
)

device_info = self.get_device_info()
simulated_world_size = device_info.get("device_count", 1)
if simulated_world_size == 0:
simulated_world_size = 1

logger.info(
f"Simulating with world_size={simulated_world_size} "
"based on available devices."
)

def all_reduce_simulated(x, op="sum"):
if simulated_world_size <= 1:
return x
if op == "sum":
return keras.ops.multiply(x, simulated_world_size)
elif op == "mean":
return x
else:
raise ValueError(f"Unsupported all_reduce op: {op}")

def all_gather_simulated(x, axis=0):
if simulated_world_size <= 1:
return x
return keras.ops.concatenate(
[x] * simulated_world_size, axis=axis
)

def broadcast_simulated(x, root=0):
return x

def scatter_simulated(x, root=0):
if simulated_world_size <= 1:
return x
if keras.ops.shape(x)[0] % simulated_world_size != 0:
raise ValueError(
"For simulation, the first dimension of tensor must "
f"be divisible by the simulated world size "
f"({simulated_world_size})."
)
chunks = keras.ops.split(x, simulated_world_size, axis=0)
return chunks[0]

return {
"all_reduce": all_reduce_simulated,
"all_gather": all_gather_simulated,
"broadcast": broadcast_simulated,
"scatter": scatter_simulated,
}
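
For reference, a sketch of the functional `jax.grad` pattern that the `compute_gradients` docstring points to (illustrative only; `loss_fn`, `params`, and `batch` are hypothetical names, not part of this PR):

import jax
import jax.numpy as jnp
import optax


def loss_fn(params, batch):
    # Toy linear model; gradients are taken w.r.t. `params` (first argument).
    preds = batch["x"] @ params["w"] + params["b"]
    return jnp.mean((preds - batch["y"]) ** 2)


params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)


@jax.jit
def train_step(params, opt_state, batch):
    grads = jax.grad(loss_fn)(params, batch)
    updates, opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), opt_state

Under `jax.pmap`, the per-replica gradients would typically be averaged with `lax.pmean`, which is what the `all_reduce` op above wraps when `op="mean"`.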