forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 1
[AutoDeploy] dist_ops revisited #96
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
lucaslie
wants to merge
3
commits into
feat/ad-2025-07-07
Choose a base branch
from
ll/dist_ops_revisited
base: feat/ad-2025-07-07
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,17 +1,16 @@ | ||
| """Custom ops and make sure they are all registered.""" | ||
| """AutoDeploy's custom ops library. | ||
|
|
||
| from ._triton_attention_internal import * | ||
| from .dist import * | ||
| from .flashinfer_attention import * | ||
| from .flashinfer_rope import * | ||
| from .linear import * | ||
| from .mla import * | ||
| from .quant import * | ||
| from .rms_norm import * | ||
| from .torch_attention import * | ||
| from .torch_backend_attention import * | ||
| from .torch_moe import * | ||
| from .torch_rope import * | ||
| from .triton_attention import * | ||
| from .triton_rope import * | ||
| from .trtllm_moe import * | ||
| This file ensures that all publicly listed files/custom ops in the custom_ops folder are | ||
| auto-imported and the corresponding custom ops are registered. | ||
| """ | ||
|
|
||
| import importlib | ||
| import pkgutil | ||
|
|
||
| __all__ = [] | ||
|
|
||
| for _, module_name, is_pkg in pkgutil.iter_modules(__path__): | ||
| if module_name.startswith("_"): | ||
| continue | ||
| __all__.append(module_name) | ||
| importlib.import_module(f"{__name__}.{module_name}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| """TRT-LLM optimized dist ops.""" | ||
|
|
||
| from typing import List, Optional | ||
|
|
||
| import torch | ||
|
|
||
| from ....functional import AllReduceParams | ||
| from ....mapping import Mapping | ||
| from ...distributed import AllReduce, allgather | ||
| from ...modules.linear import AllReduceFusionOp, AllReduceStrategy | ||
| from ..distributed.common import get_rank_world_size | ||
|
|
||
|
|
||
| @torch.library.custom_op("auto_deploy::trtllm_all_gather", mutates_args=(), device_types="cuda") | ||
| def all_gather( | ||
| tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None | ||
| ) -> torch.Tensor: | ||
| """TRT-LLM all gather followed by concat in dim = 0.""" | ||
| rank, world_size = get_rank_world_size() | ||
| p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) | ||
| result = allgather(tensor, p_config, dim=dim, sizes=sizes) | ||
| assert isinstance(result, torch.Tensor), "Expected tensor result from allgather" | ||
| return result | ||
|
|
||
|
|
||
| @all_gather.register_fake | ||
| def all_gather_fake(tensor, dim=0, sizes=None): | ||
| rank, world_size = get_rank_world_size() | ||
| return torch.cat([torch.empty_like(tensor) for _ in range(world_size)], dim=dim) | ||
|
|
||
|
|
||
| @torch.library.custom_op("auto_deploy::trtllm_all_reduce", mutates_args=(), device_types="cuda") | ||
| def all_reduce(tensor: torch.Tensor, strategy: int = int(AllReduceStrategy.AUTO)) -> torch.Tensor: | ||
| """TRT-LLM all_reduce across the ranks.""" | ||
| rank, world_size = get_rank_world_size() | ||
| p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) | ||
| torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy(strategy)) | ||
| result = torch_op(tensor) | ||
| assert isinstance(result, torch.Tensor), "Expected tensor result from allreduce" | ||
| return result | ||
|
|
||
|
|
||
| @all_reduce.register_fake | ||
| def all_reduce_fake(tensor): | ||
| return torch.empty_like(tensor) | ||
|
|
||
|
|
||
| @torch.library.custom_op( | ||
| "auto_deploy::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda" | ||
| ) | ||
| def trtllm_fused_allreduce_residual_rmsnorm( | ||
| tensor: torch.Tensor, | ||
| residual: torch.Tensor, | ||
| norm_weight: torch.Tensor, | ||
| eps: float, | ||
| strategy: int = int(AllReduceStrategy.AUTO), | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| """Fusing allreduce, residual (add), and hf_rms_norm together.""" | ||
| rank, world_size = get_rank_world_size() | ||
| p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) | ||
|
|
||
| # Use AllReduceParams like the old implementation | ||
| all_reduce_params = AllReduceParams( | ||
| fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, | ||
| bias=None, | ||
| residual=residual, | ||
| norm_weight=norm_weight, | ||
| eps=eps, | ||
| ) | ||
|
|
||
| torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy(strategy)) | ||
| output = torch_op(tensor, all_reduce_params=all_reduce_params) | ||
| assert len(output) == 2, "Expected 2 outputs from trtllm_fused_allreduce_residual_rmsnorm" | ||
| return output[0], output[1] | ||
|
|
||
|
|
||
| @trtllm_fused_allreduce_residual_rmsnorm.register_fake | ||
| def trtllm_fused_allreduce_residual_rmsnorm_fake( | ||
| tensor: torch.Tensor, | ||
| residual: torch.Tensor, | ||
| norm_weight: torch.Tensor, | ||
| eps: float, | ||
| strategy: int = int(AllReduceStrategy.AUTO), | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| return torch.empty_like(tensor), torch.empty_like(tensor) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from .common import * |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,15 @@ | ||
| """Common utilities for distributed inference.""" | ||
| """Common utilities for distributed inference. | ||
|
|
||
| These utilities serve two purposes: | ||
|
|
||
| 1. Drop-in replacement for torch.distributed to ensure that any function in torch.distributed works | ||
| out of the box. | ||
| 2. Provide a simple interface to spawn multiple processes and communicate with them. We support | ||
| three supports: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| a. MPI initialized via TRT-LLM runtime or mpirun. | ||
| b. TorchElastic initialized via torchrun. | ||
| c. Simple python multiprocessing started explicitly. | ||
| """ | ||
|
|
||
| import atexit | ||
| import os | ||
|
|
@@ -12,8 +23,6 @@ | |
|
|
||
| from ..utils.logger import ad_logger | ||
|
|
||
| # TODO: check to what extend we can reuse _torch/distributed.py | ||
|
|
||
|
|
||
| class _DistGroup: | ||
| """Global instance to set/get the default process group for distributed ops.""" | ||
|
|
@@ -93,17 +102,22 @@ def initialize_or_skip(*args, **kwargs) -> Tuple[int, int]: | |
| return get_rank(), get_world_size() | ||
|
|
||
|
|
||
| def is_ompi(): | ||
| def is_ompi() -> bool: | ||
| """Check whether multi-processing was initialized with explicitly calling mpirun.""" | ||
| return "OMPI_COMM_WORLD_SIZE" in os.environ | ||
|
|
||
|
|
||
| def is_torchelastic(): | ||
| def is_trtllm_dist_available() -> bool: | ||
| """Check if TRT-LLM distributed operations are available (they only work with MPI!).""" | ||
| return is_ompi() | ||
|
|
||
|
|
||
| def is_torchelastic() -> bool: | ||
| """Check whether multi-processing was initialized with torchelastic.""" | ||
| return "TORCHELASTIC_RUN_ID" in os.environ | ||
|
|
||
|
|
||
| def cleanup(): | ||
| def cleanup() -> None: | ||
| """Destroy process group when the program exits.""" | ||
| if dist.is_initialized(): | ||
| ad_logger.info("Destroying process group") | ||
|
|
||
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.