Introducing Parameter Sharding and Torch backend for Tensor Parallelism #21724
Draft: buildwithsuhana wants to merge 5 commits into keras-team:master from buildwithsuhana:Tensor_parallel_keras_3
Commits (5)
e8f1a5a  Adding parameter sharding and test (buildwithsuhana)
c01ec8a  Added docstrings to parameter_shardimg (buildwithsuhana)
c22facf  Added docstrings (buildwithsuhana)
276530a  Torch backend added (buildwithsuhana)
6c0189e  Added torch backend and fixed circular import issue (buildwithsuhana)
New file (+220 lines): the PyTorch distributed backend module.
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Literal

import torch
import torch.distributed as dist


def compute_gradients(
    loss: torch.Tensor, trainable_vars: List[torch.Tensor]
) -> List[torch.Tensor]:
    """Computes gradients of the loss with respect to trainable variables.

    This function leverages PyTorch's `autograd.grad` for a stateless,
    functional approach similar to `jax.grad`.

    Args:
        loss (torch.Tensor): The loss value for which to compute gradients.
        trainable_vars (List[torch.Tensor]): A list of variables (tensors
            with `requires_grad=True`) to compute gradients with respect to.

    Returns:
        List[torch.Tensor]: A list of gradients corresponding to the
            trainable variables.
    """
    return list(torch.autograd.grad(loss, trainable_vars))


def apply_gradients(
    gradients: List[torch.Tensor],
    trainable_vars: List[torch.Tensor],
    learning_rate: float = 0.001,
) -> List[torch.Tensor]:
    """Applies gradients and returns the updated variables.

    Updates are performed in-place within a `torch.no_grad()` context
    to prevent the update operation from being part of the computation graph.
    """
    with torch.no_grad():
        updated_vars = []
        for grad, var in zip(gradients, trainable_vars):
            if grad is not None:
                var.sub_(learning_rate * grad)
            updated_vars.append(var)
    return updated_vars


def create_optimizer(optimizer_class: str, **kwargs) -> Dict[str, Any]:
    """Creates a configuration dictionary for a PyTorch optimizer.

    This function returns a dictionary containing the optimizer's
    configuration, maintaining a consistent interface with the JAX backend.
    The user is expected to instantiate the optimizer from this config.

    Args:
        optimizer_class (str): The name of the optimizer to create (e.g.,
            `"adam"`, `"sgd"`).
        **kwargs: Keyword arguments for the optimizer (e.g., `learning_rate`).

    Returns:
        Dict[str, Any]: A dictionary representing the optimizer configuration.
    """
    config = kwargs.copy()
    config["name"] = optimizer_class.lower()
    config.setdefault("learning_rate", 0.001)
    return config


def get_device_info() -> Dict[str, Any]:
    """Retrieves information about the available PyTorch devices.

    Returns:
        Dict[str, Any]: A dictionary containing the backend name, a list of
            available device strings, and the total device count.
    """
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        devices = [torch.cuda.get_device_name(i) for i in range(device_count)]
    else:
        device_count = 1
        devices = ["cpu"]
    return {
        "backend": "pytorch",
        "devices": devices,
        "device_count": device_count,
    }


def is_multi_device_capable() -> bool:
    """Checks if more than one CUDA device is available.

    Returns:
        bool: `True` if PyTorch reports more than one CUDA device, `False`
            otherwise.
    """
    return torch.cuda.device_count() > 1


def get_communication_ops() -> Dict[str, Callable]:
    """Provides a dictionary of PyTorch collective communication operations.

    These operations rely on the `torch.distributed` package. They are
    designed to work in a multi-process, multi-device environment. If the
    distributed package is not initialized, they provide a sensible fallback
    for single-device execution.

    Returns:
        Dict[str, Callable]: A dictionary mapping operation names to their
            PyTorch implementations.
    """

    def _is_distributed() -> bool:
        """Checks if the default process group is initialized."""
        return dist.is_available() and dist.is_initialized()

    def all_reduce(
        x: torch.Tensor,
        op: Literal["sum", "mean"] = "sum",
        axis_name: str = None,
    ) -> torch.Tensor:
        """Reduces a tensor across all devices."""
        if not _is_distributed():
            world_size = (
                torch.cuda.device_count() if torch.cuda.is_available() else 1
            )
            if world_size <= 1:
                return x
            if op == "sum":
                return x * float(world_size)
            elif op == "mean":
                return x
            else:
                raise ValueError(f"Unsupported all_reduce op: {op}")

        reduce_op = {"sum": dist.ReduceOp.SUM, "mean": dist.ReduceOp.AVG}.get(
            op
        )
        if reduce_op is None:
            raise ValueError(f"Unsupported all_reduce op: {op}")

        result = x.clone()
        dist.all_reduce(result, op=reduce_op)
        return result

    def all_gather(
        x: torch.Tensor, axis: int = 0, axis_name: str = None
    ) -> torch.Tensor:
        """Gathers tensors from all devices and concatenates them."""
        if not _is_distributed():
            world_size = (
                torch.cuda.device_count() if torch.cuda.is_available() else 1
            )
            if world_size <= 1:
                return x
            return torch.cat([x] * world_size, dim=axis)

        world_size = dist.get_world_size()
        tensor_list = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(tensor_list, x)
        return torch.cat(tensor_list, dim=axis)

    def broadcast(
        x: torch.Tensor, root: int = 0, axis_name: str = None
    ) -> torch.Tensor:
        """Broadcasts a tensor from a root device to all other devices."""
        if not _is_distributed():
            return x

        dist.broadcast(x, src=root)
        return x

    def scatter(
        x: torch.Tensor,
        root: int = 0,
        axis: int = 0,
        axis_name: str = None,
    ) -> torch.Tensor:
        """Scatters a tensor from a root device to all devices."""
        if not _is_distributed():
            world_size = (
                torch.cuda.device_count() if torch.cuda.is_available() else 1
            )
            if world_size <= 1:
                return x
            if x.shape[axis] % world_size != 0:
                raise ValueError(
                    f"Tensor with shape {x.shape} cannot be scattered along "
                    f"axis {axis} across {world_size} devices."
                )
            return torch.chunk(x, world_size, dim=axis)[0]

        world_size = dist.get_world_size()
        rank = dist.get_rank()

        if x.shape[axis] % world_size != 0:
            raise ValueError(
                f"Tensor with shape {x.shape} cannot be scattered along "
                f"axis {axis} across {world_size} devices."
            )

        if rank == root:
            scatter_list = list(torch.chunk(x, world_size, dim=axis))
        else:
            scatter_list = None

        chunk_shape = list(x.shape)
        chunk_shape[axis] //= world_size
        local_chunk = torch.empty(chunk_shape, dtype=x.dtype, device=x.device)

        dist.scatter(local_chunk, scatter_list, src=root)
        return local_chunk

    return {
        "all_reduce": all_reduce,
        "all_gather": all_gather,
        "broadcast": broadcast,
        "scatter": scatter,
    }
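For orientation, here is a minimal single-process usage sketch of the helpers above. It is illustrative only: the import path is taken from the accompanying test file, it assumes the torch backend is the active Keras backend, and it exercises the single-device fallback paths since `torch.distributed` is never initialized.

import torch

# Illustrative usage of the helpers above. The import path follows the test
# file below and assumes the torch backend is the active Keras backend.
from keras.src.backend import distributed_backend

w = torch.tensor([2.0, 3.0], requires_grad=True)
x = torch.tensor([4.0, 5.0])
loss = (torch.dot(w, x) - 20.0) ** 2

# Stateless gradient computation via torch.autograd.grad.
grads = distributed_backend.compute_gradients(loss, [w])

# In-place SGD-style update of the variable.
(w,) = distributed_backend.apply_gradients(grads, [w], learning_rate=0.01)

# With no process group initialized, the collective ops fall back to their
# single-device behavior; "mean" returns the input unchanged here.
ops = distributed_backend.get_communication_ops()
averaged = ops["all_reduce"](grads[0], op="mean")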
New file (+129 lines): unit tests for the PyTorch distributed backend.
import pytest
import torch

from keras.src import backend
from keras.src.backend import distributed_backend


@pytest.mark.skipif(
    backend.backend() != "torch",
    reason="Torch backend specific test",
)
class TestPytorchDistributedFunctions:
    """Unit tests for the PyTorch distributed backend standalone functions."""

    def test_compute_gradients_computes_correctly(self):
        """Test that compute_gradients returns correct gradients."""
        w = torch.tensor([2.0, 3.0], requires_grad=True)
        b = torch.tensor(1.0, requires_grad=True)
        x = torch.tensor([4.0, 5.0])
        y_true = torch.tensor(25.0)

        # loss = (w.x + b - y_true)^2 = ((2*4 + 3*5 + 1) - 25)^2 = (24-25)^2 = 1
        y_pred = torch.dot(w, x) + b
        loss = (y_pred - y_true) ** 2

        trainable_vars = [w, b]
        gradients = distributed_backend.compute_gradients(loss, trainable_vars)

        # d_loss/d_w = 2*(y_pred - y_true)*x = 2*(-1)*[4, 5] = [-8, -10]
        # d_loss/d_b = 2*(y_pred - y_true)*1 = 2*(-1)*1 = -2
        expected_grad_w = torch.tensor([-8.0, -10.0])
        expected_grad_b = torch.tensor(-2.0)

        assert len(gradients) == 2
        torch.testing.assert_close(gradients[0], expected_grad_w)
        torch.testing.assert_close(gradients[1], expected_grad_b)

    def test_apply_gradients(self):
        """Test the application of gradients to PyTorch tensors."""
        var1 = torch.tensor([1.0, 2.0], requires_grad=True)
        var2 = torch.tensor(5.0, requires_grad=True)
        trainable_vars = [var1, var2]
        grad1 = torch.tensor([0.1, 0.2])
        grad2 = torch.tensor(0.5)
        gradients = [grad1, grad2]
        learning_rate = 0.1

        original_var1 = var1.clone()
        original_var2 = var2.clone()

        updated_vars = distributed_backend.apply_gradients(
            gradients, trainable_vars, learning_rate
        )

        assert updated_vars[0] is var1
        assert updated_vars[1] is var2

        expected_var1 = original_var1 - (grad1 * learning_rate)
        expected_var2 = original_var2 - (grad2 * learning_rate)
        torch.testing.assert_close(updated_vars[0], expected_var1)
        torch.testing.assert_close(updated_vars[1], expected_var2)

    def test_create_optimizer(self):
        """Test optimizer configuration creation."""
        adam_config = distributed_backend.create_optimizer(
            "adam", learning_rate=0.01
        )
        assert isinstance(adam_config, dict)
        assert adam_config["name"] == "adam"
        assert adam_config["learning_rate"] == 0.01

        sgd_config = distributed_backend.create_optimizer(
            "sgd", learning_rate=0.1, momentum=0.9
        )
        assert isinstance(sgd_config, dict)
        assert sgd_config["name"] == "sgd"
        assert sgd_config["learning_rate"] == 0.1
        assert sgd_config["momentum"] == 0.9

    def test_get_device_info(self):
        """Test retrieving device information from the PyTorch backend."""
        info = distributed_backend.get_device_info()
        assert info["backend"] == "pytorch"
        assert isinstance(info["devices"], list)
        assert isinstance(info["device_count"], int)
        assert info["device_count"] > 0
        assert len(info["devices"]) == info["device_count"]
        if torch.cuda.is_available():
            assert info["device_count"] == torch.cuda.device_count()
        else:
            assert info["device_count"] == 1
            assert info["devices"] == ["cpu"]

    def test_is_multi_device_capable(self):
        """Test the boolean check for multi-device capability."""
        assert isinstance(distributed_backend.is_multi_device_capable(), bool)

    def test_communication_ops_simulation_logic(self):
        """Test the simulated communication ops in a single-device context."""
        comm_ops = distributed_backend.get_communication_ops()
        device_info = distributed_backend.get_device_info()
        world_size = device_info.get("device_count", 1)

        # Test all_reduce
        x_reduce = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
        reduced = comm_ops["all_reduce"](x_reduce, op="sum")
        expected_reduce = (
            x_reduce * float(world_size) if world_size > 1 else x_reduce
        )
        torch.testing.assert_close(reduced, expected_reduce)

        # Test all_gather
        x_gather = torch.tensor([[1.0, 2.0]])
        gathered = comm_ops["all_gather"](x_gather, axis=0)
        expected_gather = torch.cat([x_gather] * world_size, dim=0)
        torch.testing.assert_close(gathered, expected_gather)

        # Test broadcast
        x_broadcast = torch.tensor([5.0, 6.0])
        broadcasted = comm_ops["broadcast"](x_broadcast)
        torch.testing.assert_close(broadcasted, x_broadcast)

        # Test scatter
        if world_size > 0:
            scatter_data = torch.arange(world_size * 4, dtype=torch.float32)
            x_scatter = scatter_data.reshape(world_size * 2, 2)
            scattered = comm_ops["scatter"](x_scatter, axis=0)
            expected_scatter = torch.chunk(x_scatter, world_size, dim=0)[0]
            torch.testing.assert_close(scattered, expected_scatter)
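These tests only exercise the single-process fallback paths. For context, a rough sketch of a real multi-process run follows; it is not part of this PR, assumes KERAS_BACKEND=torch, a gloo process group, and a launch such as `torchrun --nproc_per_node=2 script.py`, and uses only the op names introduced in this diff.

import torch
import torch.distributed as dist

from keras.src.backend import distributed_backend

# torchrun provides MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE, so the default
# env:// initialization works here.
dist.init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()

ops = distributed_backend.get_communication_ops()

# Each rank contributes its own value; all_reduce sums them across ranks.
local = torch.tensor([float(rank + 1)])
total = ops["all_reduce"](local, op="sum")

# all_gather concatenates one row per rank along axis 0.
gathered = ops["all_gather"](torch.tensor([[float(rank)]]), axis=0)

# Rank 0 scatters equal chunks along axis 0; every rank keeps one chunk.
full = torch.arange(world_size * 2, dtype=torch.float32).reshape(world_size, 2)
chunk = ops["scatter"](full, root=0, axis=0)

dist.destroy_process_group()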
Review comment: The apply_gradients function implements a simple SGD update. As a public function in the distributed backend, this could be misleading for users who might expect it to integrate with their configured Keras optimizer. If this is only intended for testing purposes, consider making it a private function (e.g., _apply_gradients) or moving it into the test suite to avoid confusion.
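To illustrate the reviewer's point, the snippet below is a hypothetical sketch (not code from this PR) of the same update rule expressed through plain torch.optim, which is closer to what users may expect "applying gradients" to mean.

import torch

# Hypothetical sketch: the in-place SGD step that apply_gradients hard-codes,
# expressed via a standard optimizer instead of a public backend helper.
w = torch.tensor([2.0, 3.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=0.001)

loss = (w.sum() - 1.0) ** 2
loss.backward()   # populates w.grad
opt.step()        # w -= lr * w.grad, the same rule as apply_gradients
opt.zero_grad()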