4 changes: 4 additions & 0 deletions keras/src/backend/__init__.py
@@ -37,6 +37,8 @@
if backend() == "tensorflow":
from keras.src.backend.tensorflow import * # noqa: F403
from keras.src.backend.tensorflow.core import Variable as BackendVariable

distributed_backend = None
elif backend() == "jax":
from keras.src.backend.jax import * # noqa: F403
from keras.src.backend.jax.core import Variable as BackendVariable
@@ -50,11 +52,13 @@
from keras.src.backend.numpy.core import Variable as BackendVariable

distribution_lib = None
distributed_backend = None
elif backend() == "openvino":
from keras.src.backend.openvino import * # noqa: F403
from keras.src.backend.openvino.core import Variable as BackendVariable

distribution_lib = None
distributed_backend = None
else:
raise ValueError(f"Unable to import backend : {backend()}")

1 change: 1 addition & 0 deletions keras/src/backend/torch/__init__.py
@@ -16,6 +16,7 @@

from keras.src.backend.common.name_scope import name_scope
from keras.src.backend.torch import core
from keras.src.backend.torch import distributed_backend
from keras.src.backend.torch import image
from keras.src.backend.torch import linalg
from keras.src.backend.torch import math
257 changes: 257 additions & 0 deletions keras/src/backend/torch/distributed_backend.py
@@ -0,0 +1,257 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Literal

import torch
import torch.distributed as dist


def compute_gradients(
loss: torch.Tensor, trainable_vars: List[torch.Tensor]
) -> List[torch.Tensor]:
"""Computes gradients of the loss with respect to trainable variables.

This function leverages PyTorch's `autograd.grad` for a stateless,
functional approach similar to `jax.grad`.

Args:
loss (torch.Tensor): The loss value for which to compute gradients.
trainable_vars (List[torch.Tensor]): A list of variables (tensors with
`requires_grad=True`) to compute gradients with respect to.

Returns:
List[torch.Tensor]: A list of gradients corresponding to the
trainable variables.
"""
return list(torch.autograd.grad(loss, trainable_vars))
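
A minimal usage sketch for this helper, assuming the new module is importable as keras.src.backend.torch.distributed_backend and that the loss is a scalar tensor:

import torch
from keras.src.backend.torch import distributed_backend as db

w = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (w ** 2).sum()
grads = db.compute_gradients(loss, [w])
# grads == [tensor([2., 4.])], i.e. d(sum(w**2))/dw = 2 * w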


def apply_gradients(
gradients: List[torch.Tensor],
trainable_vars: List[torch.Tensor],
learning_rate: float = 0.001,
) -> List[torch.Tensor]:
"""Applies gradients and returns the updated variables.

Updates are performed in-place within a `torch.no_grad()` context
to prevent the update operation from being part of the computation graph.
"""
with torch.no_grad():
updated_vars = []
for grad, var in zip(gradients, trainable_vars):
if grad is not None:
var.sub_(learning_rate * grad)
updated_vars.append(var)
return updated_vars
Contributor comment on lines +31 to +47 (severity: medium):
The apply_gradients function implements a simple SGD update. As a public function in the distributed backend, this could be misleading for users who might expect it to integrate with their configured Keras optimizer. If this is only intended for testing purposes, consider making it a private function (e.g., _apply_gradients) or moving it into the test suite to avoid confusion.
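
For context, a minimal sketch of one update step with these two helpers, again assuming the module import above; the update is a plain in-place SGD step:

import torch
from keras.src.backend.torch import distributed_backend as db

w = torch.tensor([1.0, 2.0], requires_grad=True)
loss = (w ** 2).sum()
grads = db.compute_gradients(loss, [w])  # [tensor([2., 4.])]
(w_updated,) = db.apply_gradients(grads, [w], learning_rate=0.1)
# w is updated in place: tensor([0.8000, 1.6000], requires_grad=True)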



def create_optimizer(optimizer_class: str, **kwargs) -> Dict[str, Any]:
"""Creates a configuration dictionary for a PyTorch optimizer.

This function returns a dictionary containing the optimizer's configuration,
maintaining a consistent interface with the JAX backend. The user is
expected to instantiate the optimizer from this config.

Args:
optimizer_class (str): The name of the optimizer to create (e.g.,
`"adam"`, `"sgd"`).
**kwargs: Keyword arguments for the optimizer (e.g., `learning_rate`).

Returns:
Dict[str, Any]: A dictionary representing the optimizer configuration.
"""
config = kwargs.copy()
config["name"] = optimizer_class.lower()
config.setdefault("learning_rate", 0.001)
return config
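
A short sketch of how the returned config might be consumed, assuming the caller instantiates the torch.optim class themselves:

import torch
from keras.src.backend.torch import distributed_backend as db

config = db.create_optimizer("adam", learning_rate=0.01)
# config == {"learning_rate": 0.01, "name": "adam"}
params = [torch.zeros(3, requires_grad=True)]
optimizer = torch.optim.Adam(params, lr=config["learning_rate"])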


def get_device_info() -> Dict[str, Any]:
"""Retrieves information about the available PyTorch devices.

Returns:
Dict[str, Any]: A dictionary containing the backend name, a list of
available device strings, and the total device count.
"""
if torch.cuda.is_available():
device_count = torch.cuda.device_count()
devices = [torch.cuda.get_device_name(i) for i in range(device_count)]
else:
device_count = 1
devices = ["cpu"]
return {
"backend": "pytorch",
"devices": devices,
"device_count": device_count,
}


def is_multi_device_capable() -> bool:
"""Checks if more than one CUDA device is available.

Returns:
bool: `True` if PyTorch reports more than one CUDA device, `False`
otherwise.
"""
return torch.cuda.device_count() > 1
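
A small sketch of what these two introspection helpers return; the exact device list and count depend on the host:

from keras.src.backend.torch import distributed_backend as db

info = db.get_device_info()
# CPU-only host: {"backend": "pytorch", "devices": ["cpu"], "device_count": 1}
print(db.is_multi_device_capable())
# False unless more than one CUDA device is visible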


def get_communication_ops() -> Dict[str, Callable]:
"""Provides a dictionary of PyTorch collective communication operations.

These operations rely on the `torch.distributed` package. They are
designed to work in a multi-process, multi-device environment. If the
distributed package is not initialized, they provide a sensible fallback
for single-device execution.

Returns:
Dict[str, Callable]: A dictionary mapping operation names to their
PyTorch implementations.
"""

def _is_distributed() -> bool:
"""Checks if the default process group is initialized."""
return dist.is_available() and dist.is_initialized()

def all_reduce(
x: torch.Tensor,
op: Literal["sum", "mean"] = "sum",
) -> torch.Tensor:
"""Reduces a tensor across all devices.

Args:
x (torch.Tensor): The tensor to reduce.
op (Literal["sum", "mean"], optional): The reduction operation.
Defaults to "sum".

Returns:
torch.Tensor: The reduced tensor.
"""
if not _is_distributed():
world_size = (
torch.cuda.device_count() if torch.cuda.is_available() else 1
)
if world_size <= 1:
return x
if op == "sum":
return x * float(world_size)
elif op == "mean":
return x
else:
raise ValueError(f"Unsupported all_reduce op: {op}")

reduce_op = {"sum": dist.ReduceOp.SUM, "mean": dist.ReduceOp.AVG}.get(
op
)
if reduce_op is None:
raise ValueError(f"Unsupported all_reduce op: {op}")

result = x.clone()
dist.all_reduce(result, op=reduce_op)
return result

def all_gather(x: torch.Tensor, axis: int = 0) -> torch.Tensor:
"""Gathers tensors from all devices and concatenates them.

Args:
x (torch.Tensor): The local tensor to gather.
axis (int, optional): The axis along which to concatenate.
Defaults to 0.

Returns:
torch.Tensor: The concatenated tensor from all devices.
"""
if not _is_distributed():
world_size = (
torch.cuda.device_count() if torch.cuda.is_available() else 1
)
if world_size <= 1:
return x
return torch.cat([x] * world_size, dim=axis)

world_size = dist.get_world_size()
tensor_list = [torch.empty_like(x) for _ in range(world_size)]
dist.all_gather(tensor_list, x)
return torch.cat(tensor_list, dim=axis)

def broadcast(x: torch.Tensor, root: int = 0) -> torch.Tensor:
"""Broadcasts a tensor from a root device to all other devices.

Args:
x (torch.Tensor): The tensor to broadcast.
root (int, optional): The rank of the source device. Defaults to 0.

Returns:
torch.Tensor: The tensor received from the root device.
"""
if not _is_distributed():
return x

# `dist.broadcast` is in-place.
dist.broadcast(x, src=root)
return x

def scatter(
x: torch.Tensor,
root: int = 0,
axis: int = 0,
) -> torch.Tensor:
"""Scatters a tensor from a root device to all devices.

Note: `dist.scatter` expects the root process to supply a list of
chunks (one per rank) while non-root processes pass `None`. This
wrapper handles that splitting automatically on the root process.

Args:
x (torch.Tensor): The tensor on the root device to be scattered.
root (int, optional): The rank of the device holding the tensor.
Defaults to 0.
axis (int, optional): The axis along which to split the tensor.
Defaults to 0.

Returns:
torch.Tensor: The chunk of the tensor for the local device.
"""
if not _is_distributed():
world_size = (
torch.cuda.device_count() if torch.cuda.is_available() else 1
)
if world_size <= 1:
return x
if x.shape[axis] % world_size != 0:
raise ValueError(
f"Tensor with shape {x.shape} cannot be scattered along "
f"axis {axis} across {world_size} devices."
)
return torch.chunk(x, world_size, dim=axis)[0]

world_size = dist.get_world_size()
rank = dist.get_rank()

if x.shape[axis] % world_size != 0:
raise ValueError(
f"Tensor with shape {x.shape} cannot be scattered along "
f"axis {axis} across {world_size} devices."
)

if rank == root:
scatter_list = list(torch.chunk(x, world_size, dim=axis))
else:
scatter_list = None

chunk_shape = list(x.shape)
chunk_shape[axis] //= world_size
local_chunk = torch.empty(chunk_shape, dtype=x.dtype, device=x.device)

dist.scatter(local_chunk, scatter_list, src=root)
return local_chunk

return {
"all_reduce": all_reduce,
"all_gather": all_gather,
"broadcast": broadcast,
"scatter": scatter,
}
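
A usage sketch for the returned dictionary, assuming a single process with no initialized process group, so every op takes its single-device fallback path:

import torch
from keras.src.backend.torch import distributed_backend as db

ops = db.get_communication_ops()
x = torch.ones(4, 2)
# Without an initialized process group (and with at most one device),
# each op simply returns x.
y = ops["all_reduce"](x, op="sum")
g = ops["all_gather"](x, axis=0)
b = ops["broadcast"](x, root=0)
s = ops["scatter"](x, root=0, axis=0)
# In a real multi-process job, torch.distributed.init_process_group() must be
# called first so the collective code paths are used instead.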