Introducing Parameter Sharding and Torch backend for Tensor Parallelism #21724
Draft
buildwithsuhana wants to merge 5 commits into keras-team:master from buildwithsuhana:Tensor_parallel_keras_3
Changes from 4 commits

Commits (5):
e8f1a5a: Adding parameter sharding and test (buildwithsuhana)
c01ec8a: Added docstrings to parameter_sharding (buildwithsuhana)
c22facf: Added docstrings (buildwithsuhana)
276530a: Torch backend added (buildwithsuhana)
6c0189e: Added torch backend and fixed circular import issue (buildwithsuhana)
@@ -0,0 +1,257 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Literal

import torch
import torch.distributed as dist


def compute_gradients(
    loss: torch.Tensor, trainable_vars: List[torch.Tensor]
) -> List[torch.Tensor]:
    """Computes gradients of the loss with respect to trainable variables.

    This function leverages PyTorch's `autograd.grad` for a stateless,
    functional approach similar to `jax.grad`.

    Args:
        loss (torch.Tensor): The loss value for which to compute gradients.
        trainable_vars (List[torch.Tensor]): A list of variables (tensors with
            `requires_grad=True`) to compute gradients with respect to.

    Returns:
        List[torch.Tensor]: A list of gradients corresponding to the
            trainable variables.
    """
    return list(torch.autograd.grad(loss, trainable_vars))


def apply_gradients(
    gradients: List[torch.Tensor],
    trainable_vars: List[torch.Tensor],
    learning_rate: float = 0.001,
) -> List[torch.Tensor]:
    """Applies gradients and returns the updated variables.

    Updates are performed in-place within a `torch.no_grad()` context
    to prevent the update operation from being part of the computation graph.
    """
    with torch.no_grad():
        updated_vars = []
        for grad, var in zip(gradients, trainable_vars):
            if grad is not None:
                var.sub_(learning_rate * grad)
            updated_vars.append(var)
    return updated_vars


def create_optimizer(optimizer_class: str, **kwargs) -> Dict[str, Any]:
    """Creates a configuration dictionary for a PyTorch optimizer.

    This function returns a dictionary containing the optimizer's
    configuration, maintaining a consistent interface with the JAX backend.
    The user is expected to instantiate the optimizer from this config.

    Args:
        optimizer_class (str): The name of the optimizer to create (e.g.,
            `"adam"`, `"sgd"`).
        **kwargs: Keyword arguments for the optimizer (e.g., `learning_rate`).

    Returns:
        Dict[str, Any]: A dictionary representing the optimizer configuration.
    """
    config = kwargs.copy()
    config["name"] = optimizer_class.lower()
    config.setdefault("learning_rate", 0.001)
    return config


def get_device_info() -> Dict[str, Any]:
    """Retrieves information about the available PyTorch devices.

    Returns:
        Dict[str, Any]: A dictionary containing the backend name, a list of
            available device strings, and the total device count.
    """
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        devices = [torch.cuda.get_device_name(i) for i in range(device_count)]
    else:
        device_count = 1
        devices = ["cpu"]
    return {
        "backend": "pytorch",
        "devices": devices,
        "device_count": device_count,
    }


def is_multi_device_capable() -> bool:
    """Checks if more than one CUDA device is available.

    Returns:
        bool: `True` if PyTorch reports more than one CUDA device, `False`
            otherwise.
    """
    return torch.cuda.device_count() > 1


def get_communication_ops() -> Dict[str, Callable]:
    """Provides a dictionary of PyTorch collective communication operations.

    These operations rely on the `torch.distributed` package. They are
    designed to work in a multi-process, multi-device environment. If the
    distributed package is not initialized, they provide a sensible fallback
    for single-device execution.

    Returns:
        Dict[str, Callable]: A dictionary mapping operation names to their
            PyTorch implementations.
    """

    def _is_distributed() -> bool:
        """Checks if the default process group is initialized."""
        return dist.is_available() and dist.is_initialized()

    def all_reduce(
        x: torch.Tensor,
        op: Literal["sum", "mean"] = "sum",
    ) -> torch.Tensor:
        """Reduces a tensor across all devices.

        Args:
            x (torch.Tensor): The tensor to reduce.
            op (Literal["sum", "mean"], optional): The reduction operation.
                Defaults to "sum".

        Returns:
            torch.Tensor: The reduced tensor.
        """
        if not _is_distributed():
            world_size = (
                torch.cuda.device_count() if torch.cuda.is_available() else 1
            )
            if world_size <= 1:
                return x
            if op == "sum":
                return x * float(world_size)
            elif op == "mean":
                return x
            else:
                raise ValueError(f"Unsupported all_reduce op: {op}")

        reduce_op = {"sum": dist.ReduceOp.SUM, "mean": dist.ReduceOp.AVG}.get(
            op
        )
        if reduce_op is None:
            raise ValueError(f"Unsupported all_reduce op: {op}")

        result = x.clone()
        dist.all_reduce(result, op=reduce_op)
        return result

    def all_gather(x: torch.Tensor, axis: int = 0) -> torch.Tensor:
        """Gathers tensors from all devices and concatenates them.

        Args:
            x (torch.Tensor): The local tensor to gather.
            axis (int, optional): The axis along which to concatenate.
                Defaults to 0.

        Returns:
            torch.Tensor: The concatenated tensor from all devices.
        """
        if not _is_distributed():
            world_size = (
                torch.cuda.device_count() if torch.cuda.is_available() else 1
            )
            if world_size <= 1:
                return x
            return torch.cat([x] * world_size, dim=axis)

        world_size = dist.get_world_size()
        tensor_list = [torch.empty_like(x) for _ in range(world_size)]
        dist.all_gather(tensor_list, x)
        return torch.cat(tensor_list, dim=axis)

    def broadcast(x: torch.Tensor, root: int = 0) -> torch.Tensor:
        """Broadcasts a tensor from a root device to all other devices.

        Args:
            x (torch.Tensor): The tensor to broadcast.
            root (int, optional): The rank of the source device. Defaults to 0.

        Returns:
            torch.Tensor: The tensor received from the root device.
        """
        if not _is_distributed():
            return x

        # `dist.broadcast` is in-place.
        dist.broadcast(x, src=root)
        return x

    def scatter(
        x: torch.Tensor,
        root: int = 0,
        axis: int = 0,
    ) -> torch.Tensor:
        """Scatters a tensor from a root device to all devices.

        Note: The current implementation of `dist.scatter` requires the input
        tensor `x` to be organized differently for the root process. This
        wrapper simplifies it by handling the splitting automatically on the
        root process.

        Args:
            x (torch.Tensor): The tensor on the root device to be scattered.
            root (int, optional): The rank of the device holding the tensor.
                Defaults to 0.
            axis (int, optional): The axis along which to split the tensor.
                Defaults to 0.

        Returns:
            torch.Tensor: The chunk of the tensor for the local device.
        """
        if not _is_distributed():
            world_size = (
                torch.cuda.device_count() if torch.cuda.is_available() else 1
            )
            if world_size <= 1:
                return x
            if x.shape[axis] % world_size != 0:
                raise ValueError(
                    f"Tensor with shape {x.shape} cannot be scattered along "
                    f"axis {axis} across {world_size} devices."
                )
            return torch.chunk(x, world_size, dim=axis)[0]

        world_size = dist.get_world_size()
        rank = dist.get_rank()

        if x.shape[axis] % world_size != 0:
            raise ValueError(
                f"Tensor with shape {x.shape} cannot be scattered along "
                f"axis {axis} across {world_size} devices."
            )

        if rank == root:
            scatter_list = list(torch.chunk(x, world_size, dim=axis))
        else:
            scatter_list = None

        chunk_shape = list(x.shape)
        chunk_shape[axis] //= world_size
        local_chunk = torch.empty(chunk_shape, dtype=x.dtype, device=x.device)

        dist.scatter(local_chunk, scatter_list, src=root)
        return local_chunk

    return {
        "all_reduce": all_reduce,
        "all_gather": all_gather,
        "broadcast": broadcast,
        "scatter": scatter,
    }
Review comment:
The apply_gradients function implements a simple SGD update. As a public function in the distributed backend, this could be misleading for users who might expect it to integrate with their configured Keras optimizer. If this is only intended for testing purposes, consider making it a private function (e.g., _apply_gradients) or moving it into the test suite to avoid confusion.
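A rough sketch of the kind of integration the comment alludes to: handing externally computed gradients to a configured torch.optim optimizer instead of applying a hard-coded SGD step. This is illustrative only and not part of the PR; the helper name _apply_gradients_with_optimizer is hypothetical.

import torch

def _apply_gradients_with_optimizer(gradients, trainable_vars, optimizer):
    """Hand externally computed gradients to a torch.optim optimizer."""
    with torch.no_grad():
        for grad, var in zip(gradients, trainable_vars):
            var.grad = grad  # the optimizer reads .grad during step()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return trainable_vars

# Usage sketch with a plain SGD optimizer standing in for whatever the
# user configured.
w = torch.randn(4, 1, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.001)
loss = (w**2).sum()
grads = list(torch.autograd.grad(loss, [w]))
_apply_gradients_with_optimizer(grads, [w], opt)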