5 changes: 4 additions & 1 deletion keras/src/backend/__init__.py
@@ -37,24 +37,27 @@
if backend() == "tensorflow":
from keras.src.backend.tensorflow import * # noqa: F403
from keras.src.backend.tensorflow.core import Variable as BackendVariable

distributed_backend = None
elif backend() == "jax":
from keras.src.backend.jax import * # noqa: F403
from keras.src.backend.jax.core import Variable as BackendVariable
elif backend() == "torch":
from keras.src.backend.torch import * # noqa: F403
from keras.src.backend.torch.core import Variable as BackendVariable

distribution_lib = None
elif backend() == "numpy":
from keras.src.backend.numpy import * # noqa: F403
from keras.src.backend.numpy.core import Variable as BackendVariable

distribution_lib = None
distributed_backend = None
elif backend() == "openvino":
from keras.src.backend.openvino import * # noqa: F403
from keras.src.backend.openvino.core import Variable as BackendVariable

distribution_lib = None
distributed_backend = None
else:
raise ValueError(f"Unable to import backend : {backend()}")

2 changes: 2 additions & 0 deletions keras/src/backend/torch/__init__.py
@@ -16,6 +16,8 @@

from keras.src.backend.common.name_scope import name_scope
from keras.src.backend.torch import core
from keras.src.backend.torch import distributed_backend
from keras.src.backend.torch import distribution_lib
from keras.src.backend.torch import image
from keras.src.backend.torch import linalg
from keras.src.backend.torch import math
220 changes: 220 additions & 0 deletions keras/src/backend/torch/distributed_backend.py
@@ -0,0 +1,220 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Literal

import torch
import torch.distributed as dist


def compute_gradients(
loss: torch.Tensor, trainable_vars: List[torch.Tensor]
) -> List[torch.Tensor]:
"""Computes gradients of the loss with respect to trainable variables.

This function leverages PyTorch's `autograd.grad` for a stateless,
functional approach similar to `jax.grad`.

Args:
loss (torch.Tensor): The loss value for which to compute gradients.
trainable_vars (List[torch.Tensor]): A list of variables (tensors with
`requires_grad=True`) to compute gradients with respect to.

Returns:
List[torch.Tensor]: A list of gradients corresponding to the
trainable variables.
"""
return list(torch.autograd.grad(loss, trainable_vars))


def apply_gradients(
gradients: List[torch.Tensor],
trainable_vars: List[torch.Tensor],
learning_rate: float = 0.001,
) -> List[torch.Tensor]:
"""Applies gradients and returns the updated variables.

Updates are performed in-place within a `torch.no_grad()` context
to prevent the update operation from being part of the computation graph.
"""
with torch.no_grad():
updated_vars = []
for grad, var in zip(gradients, trainable_vars):
if grad is not None:
var.sub_(learning_rate * grad)
updated_vars.append(var)
return updated_vars
Comment on lines +31 to +47
Contributor (severity: medium)

The apply_gradients function implements a simple SGD update. As a public function in the distributed backend, this could be misleading for users who might expect it to integrate with their configured Keras optimizer. If this is only intended for testing purposes, consider making it a private function (e.g., _apply_gradients) or moving it into the test suite to avoid confusion.
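
For context, a minimal sketch of the alternative the comment alludes to — letting a standard torch.optim optimizer own the update instead of a hand-rolled SGD rule. The parameter and loss values are made up for illustration; none of this is part of the PR:

import torch

# Hypothetical alternative: a configured torch.optim optimizer performs the
# update, so the distributed backend needs no SGD rule of its own.
params = [torch.nn.Parameter(torch.tensor([2.0, 3.0]))]
optimizer = torch.optim.SGD(params, lr=0.001)

loss = (params[0] * torch.tensor([4.0, 5.0])).sum() ** 2
optimizer.zero_grad()
loss.backward()   # populates param.grad for every parameter
optimizer.step()  # in-place update; with no momentum this is var -= lr * grad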



def create_optimizer(optimizer_class: str, **kwargs) -> Dict[str, Any]:
"""Creates a configuration dictionary for a PyTorch optimizer.

This function returns a dictionary containing the optimizer's configuration,
maintaining a consistent interface with the JAX backend. The user is
expected to instantiate the optimizer from this config.

Args:
optimizer_class (str): The name of the optimizer to create (e.g.,
`"adam"`, `"sgd"`).
**kwargs: Keyword arguments for the optimizer (e.g., `learning_rate`).

Returns:
Dict[str, Any]: A dictionary representing the optimizer configuration.
"""
config = kwargs.copy()
config["name"] = optimizer_class.lower()
config.setdefault("learning_rate", 0.001)
return config


def get_device_info() -> Dict[str, Any]:
"""Retrieves information about the available PyTorch devices.

Returns:
Dict[str, Any]: A dictionary containing the backend name, a list of
available device strings, and the total device count.
"""
if torch.cuda.is_available():
device_count = torch.cuda.device_count()
devices = [torch.cuda.get_device_name(i) for i in range(device_count)]
else:
device_count = 1
devices = ["cpu"]
return {
"backend": "pytorch",
"devices": devices,
"device_count": device_count,
}


def is_multi_device_capable() -> bool:
"""Checks if more than one CUDA device is available.

Returns:
bool: `True` if PyTorch reports more than one CUDA device, `False`
otherwise.
"""
return torch.cuda.device_count() > 1


def get_communication_ops() -> Dict[str, Callable]:
"""Provides a dictionary of PyTorch collective communication operations.

These operations rely on the `torch.distributed` package. They are
designed to work in a multi-process, multi-device environment. If the
distributed package is not initialized, they provide a sensible fallback
for single-device execution.

Returns:
Dict[str, Callable]: A dictionary mapping operation names to their
PyTorch implementations.
"""

def _is_distributed() -> bool:
"""Checks if the default process group is initialized."""
return dist.is_available() and dist.is_initialized()

def all_reduce(
x: torch.Tensor,
op: Literal["sum", "mean"] = "sum",
axis_name: str = None,
) -> torch.Tensor:
"""Reduces a tensor across all devices."""
if not _is_distributed():
world_size = (
torch.cuda.device_count() if torch.cuda.is_available() else 1
)
if world_size <= 1:
return x
if op == "sum":
return x * float(world_size)
elif op == "mean":
return x
else:
raise ValueError(f"Unsupported all_reduce op: {op}")

reduce_op = {"sum": dist.ReduceOp.SUM, "mean": dist.ReduceOp.AVG}.get(
op
)
if reduce_op is None:
raise ValueError(f"Unsupported all_reduce op: {op}")

result = x.clone()
dist.all_reduce(result, op=reduce_op)
return result

def all_gather(
x: torch.Tensor, axis: int = 0, axis_name: str = None
) -> torch.Tensor:
"""Gathers tensors from all devices and concatenates them."""
if not _is_distributed():
world_size = (
torch.cuda.device_count() if torch.cuda.is_available() else 1
)
if world_size <= 1:
return x
return torch.cat([x] * world_size, dim=axis)

world_size = dist.get_world_size()
tensor_list = [torch.empty_like(x) for _ in range(world_size)]
dist.all_gather(tensor_list, x)
return torch.cat(tensor_list, dim=axis)

def broadcast(
x: torch.Tensor, root: int = 0, axis_name: str = None
) -> torch.Tensor:
"""Broadcasts a tensor from a root device to all other devices."""
if not _is_distributed():
return x

dist.broadcast(x, src=root)
return x

def scatter(
x: torch.Tensor,
root: int = 0,
axis: int = 0,
axis_name: str = None,
) -> torch.Tensor:
"""Scatters a tensor from a root device to all devices."""
if not _is_distributed():
world_size = (
torch.cuda.device_count() if torch.cuda.is_available() else 1
)
if world_size <= 1:
return x
if x.shape[axis] % world_size != 0:
raise ValueError(
f"Tensor with shape {x.shape} cannot be scattered along "
f"axis {axis} across {world_size} devices."
)
return torch.chunk(x, world_size, dim=axis)[0]

world_size = dist.get_world_size()
rank = dist.get_rank()

if x.shape[axis] % world_size != 0:
raise ValueError(
f"Tensor with shape {x.shape} cannot be scattered along "
f"axis {axis} across {world_size} devices."
)

if rank == root:
scatter_list = list(torch.chunk(x, world_size, dim=axis))
else:
scatter_list = None

chunk_shape = list(x.shape)
chunk_shape[axis] //= world_size
local_chunk = torch.empty(chunk_shape, dtype=x.dtype, device=x.device)

dist.scatter(local_chunk, scatter_list, src=root)
return local_chunk

return {
"all_reduce": all_reduce,
"all_gather": all_gather,
"broadcast": broadcast,
"scatter": scatter,
}
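
Since create_optimizer only returns a config dict, instantiation is left to the caller. A minimal usage sketch, assuming the caller maps the name key onto a torch.optim class and already holds the model parameters (the mapping and parameters below are hypothetical, not something this module provides):

import torch

from keras.src.backend.torch import distributed_backend

config = distributed_backend.create_optimizer("adam", learning_rate=0.01)
# config == {"learning_rate": 0.01, "name": "adam"}

# Hypothetical name-to-class mapping; not part of the backend module.
optimizer_classes = {"adam": torch.optim.Adam, "sgd": torch.optim.SGD}
params = [torch.nn.Parameter(torch.zeros(3))]
optimizer = optimizer_classes[config["name"]](
    params, lr=config["learning_rate"]
)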
129 changes: 129 additions & 0 deletions keras/src/backend/torch/distributed_backend_test.py
@@ -0,0 +1,129 @@
import pytest
import torch

from keras.src import backend
from keras.src.backend import distributed_backend


@pytest.mark.skipif(
backend.backend() != "torch",
reason="Jax Backend specific test",
)
Comment on lines +8 to +11
Contributor (severity: medium)

The skipif reason "Jax Backend specific test" appears to be a copy-paste error. This test class is for the PyTorch distributed backend and should be labeled as such.

Suggested change
 @pytest.mark.skipif(
     backend.backend() != "torch",
-    reason="Jax Backend specific test",
+    reason="Torch Backend specific test",
 )

class TestPytorchDistributedFunctions:
"""Unit tests for the PyTorch distributed backend standalone functions."""

def test_compute_gradients_computes_correctly(self):
"""Test that compute_gradients returns correct gradients."""
w = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)
x = torch.tensor([4.0, 5.0])
y_true = torch.tensor(25.0)

# loss = (w.x + b - y_true)^2 = ((2*4 + 3*5 + 1) - 25)^2 = (24-25)^2 = 1
y_pred = torch.dot(w, x) + b
loss = (y_pred - y_true) ** 2

trainable_vars = [w, b]
gradients = distributed_backend.compute_gradients(loss, trainable_vars)

# d_loss/d_w = 2*(y_pred - y_true)*x = 2*(-1)*[4, 5] = [-8, -10]
# d_loss/d_b = 2*(y_pred - y_true)*1 = 2*(-1)*1 = -2
expected_grad_w = torch.tensor([-8.0, -10.0])
expected_grad_b = torch.tensor(-2.0)

assert len(gradients) == 2
torch.testing.assert_close(gradients[0], expected_grad_w)
torch.testing.assert_close(gradients[1], expected_grad_b)

def test_apply_gradients(self):
"""Test the application of gradients to PyTorch tensors."""
var1 = torch.tensor([1.0, 2.0], requires_grad=True)
var2 = torch.tensor(5.0, requires_grad=True)
trainable_vars = [var1, var2]
grad1 = torch.tensor([0.1, 0.2])
grad2 = torch.tensor(0.5)
gradients = [grad1, grad2]
learning_rate = 0.1

original_var1 = var1.clone()
original_var2 = var2.clone()

updated_vars = distributed_backend.apply_gradients(
gradients, trainable_vars, learning_rate
)

assert updated_vars[0] is var1
assert updated_vars[1] is var2

expected_var1 = original_var1 - (grad1 * learning_rate)
expected_var2 = original_var2 - (grad2 * learning_rate)
torch.testing.assert_close(updated_vars[0], expected_var1)
torch.testing.assert_close(updated_vars[1], expected_var2)

def test_create_optimizer(self):
"""Test optimizer configuration creation."""
adam_config = distributed_backend.create_optimizer(
"adam", learning_rate=0.01
)
assert isinstance(adam_config, dict)
assert adam_config["name"] == "adam"
assert adam_config["learning_rate"] == 0.01

sgd_config = distributed_backend.create_optimizer(
"sgd", learning_rate=0.1, momentum=0.9
)
assert isinstance(sgd_config, dict)
assert sgd_config["name"] == "sgd"
assert sgd_config["learning_rate"] == 0.1
assert sgd_config["momentum"] == 0.9

def test_get_device_info(self):
"""Test retrieving device information from the PyTorch backend."""
info = distributed_backend.get_device_info()
assert info["backend"] == "pytorch"
assert isinstance(info["devices"], list)
assert isinstance(info["device_count"], int)
assert info["device_count"] > 0
assert len(info["devices"]) == info["device_count"]
if torch.cuda.is_available():
assert info["device_count"] == torch.cuda.device_count()
else:
assert info["device_count"] == 1
assert info["devices"] == ["cpu"]

def test_is_multi_device_capable(self):
"""Test the boolean check for multi-device capability."""
assert isinstance(distributed_backend.is_multi_device_capable(), bool)

def test_communication_ops_simulation_logic(self):
"""Test the simulated communication ops in a single-device context."""
comm_ops = distributed_backend.get_communication_ops()
device_info = distributed_backend.get_device_info()
world_size = device_info.get("device_count", 1)

# Test all_reduce
x_reduce = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
reduced = comm_ops["all_reduce"](x_reduce, op="sum")
expected_reduce = (
x_reduce * float(world_size) if world_size > 1 else x_reduce
)
torch.testing.assert_close(reduced, expected_reduce)

# Test all_gather
x_gather = torch.tensor([[1.0, 2.0]])
gathered = comm_ops["all_gather"](x_gather, axis=0)
expected_gather = torch.cat([x_gather] * world_size, dim=0)
torch.testing.assert_close(gathered, expected_gather)

# Test broadcast
x_broadcast = torch.tensor([5.0, 6.0])
broadcasted = comm_ops["broadcast"](x_broadcast)
torch.testing.assert_close(broadcasted, x_broadcast)

# Test scatter
if world_size > 0:
scatter_data = torch.arange(world_size * 4, dtype=torch.float32)
x_scatter = scatter_data.reshape(world_size * 2, 2)
scattered = comm_ops["scatter"](x_scatter, axis=0)
expected_scatter = torch.chunk(x_scatter, world_size, dim=0)[0]
torch.testing.assert_close(scattered, expected_scatter)
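
The tests above only exercise the single-process fallbacks. To reach the real torch.distributed branch of these ops, they would have to run under an initialized process group — a rough sketch, assuming a CPU gloo group launched with torchrun (which supplies the env:// rendezvous variables); this launcher setup is not part of the PR:

import torch
import torch.distributed as dist

from keras.src.backend.torch import distributed_backend

# Assumes: torchrun --nproc_per_node=2 this_script.py
# torchrun sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for env://.
dist.init_process_group(backend="gloo")

ops = distributed_backend.get_communication_ops()
x = torch.tensor([float(dist.get_rank() + 1)])

total = ops["all_reduce"](x, op="sum")   # elementwise sum across all ranks
gathered = ops["all_gather"](x, axis=0)  # every rank's x, concatenated

dist.destroy_process_group()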