Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ The table below lists the operators ordered by their backend.
| `torch.ops.auto_deploy.torch_attention_grouped_sdpa` | Grouped SDPA implementation |
| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |
| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation |
| `torch.ops.auto_deploy.torch_all_gather` | Distributed all-gather operation |
| `torch.ops.auto_deploy.torch_all_reduce` | Distributed all-reduce operation |
| `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation |
| `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
Expand All @@ -39,4 +39,4 @@ The table below lists the operators ordered by their backend.
| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
| `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
| `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT-LLM fused MoE implementation |
| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | TensorRT-LLM fused linear layer followed by all-reduce operation |
| `torch.ops.auto_deploy.torch_fused_linear_all_reduce` | fused linear layer followed by all-reduce operation |
31 changes: 15 additions & 16 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""Custom ops and make sure they are all registered."""
"""AutoDeploy's custom ops library.

from ._triton_attention_internal import *
from .dist import *
from .flashinfer_attention import *
from .flashinfer_rope import *
from .linear import *
from .mla import *
from .quant import *
from .rms_norm import *
from .torch_attention import *
from .torch_backend_attention import *
from .torch_moe import *
from .torch_rope import *
from .triton_attention import *
from .triton_rope import *
from .trtllm_moe import *
This file ensures that all publicly listed files/custom ops in the custom_ops folder are
auto-imported and the corresponding custom ops are registered.
"""

import importlib
import pkgutil

__all__ = []

for _, module_name, is_pkg in pkgutil.iter_modules(__path__):
if module_name.startswith("_"):
continue
__all__.append(module_name)
importlib.import_module(f"{__name__}.{module_name}")
25 changes: 0 additions & 25 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@

import torch

from ..distributed import common as dist
from ..distributed import trtllm as trtllm_dist


@torch.library.custom_op("auto_deploy::torch_linear_simple", mutates_args=())
def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
Expand All @@ -24,26 +21,4 @@ def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso
@simple.register_fake
def simple_fake(input, weight, bias):
"""Fake implementation of simple_linear."""
# return torch.empty(
# input.shape[:-1] + (weight.shape[-1],), dtype=input.dtype, device=input.device
# )
return torch.ops.aten.linear(input, weight, bias)


@torch.library.custom_op(
"auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
)
def fused_linear_all_reduce(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
"""Fused linear followed by all_reduce on the output."""
output = torch.ops.aten.linear(input, weight, bias)
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM)
dist.all_reduce(output, op=dist.ReduceOp.SUM)
return output


@fused_linear_all_reduce.register_fake
def fused_linear_all_reduce_fake(input, weight, bias):
return torch.ops.aten.linear(input, weight, bias)
3 changes: 0 additions & 3 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from tensorrt_llm._torch.autotuner import autotune

from ..distributed import common as dist
from ..distributed import trtllm as trtllm_dist
from .torch_libs.float8_python_api import addmm_float8_unwrapped

TRTLLM_FP4_OP_AVAILABLE = True
Expand Down Expand Up @@ -119,8 +118,6 @@ def fused_fp8_linear_all_reduce(
out = torch.ops.auto_deploy.torch_quant_fp8_linear(
input, weight_fp8, bias, input_scale, weight_scale
)
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM)
dist.all_reduce(out, op=dist.ReduceOp.SUM)
return out

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,13 @@
import torch

from ..distributed import common as dist
from ..distributed import trtllm as trtllm_dist


@torch.library.custom_op("auto_deploy::torch_dist_all_gather", mutates_args=(), device_types="cuda")
@torch.library.custom_op("auto_deploy::torch_all_gather", mutates_args=(), device_types="cuda")
def all_gather(
tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
) -> torch.Tensor:
"""All gather followed by concat in dim = 0. This is the default nccl behavior."""
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allgather(tensor, dim=dim, sizes=sizes)
tl = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
dist.all_gather(tl, tensor)
return torch.cat(tl, dim=dim)
Expand All @@ -25,15 +22,13 @@ def all_gather_fake(tensor, dim=0):
return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim)


@torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda")
@torch.library.custom_op("auto_deploy::torch_all_reduce", mutates_args=(), device_types="cuda")
def all_reduce(t: torch.Tensor) -> torch.Tensor:
"""All_reduce across the ranks. Reduction op is SUM.

NOTE: this op requires an extra memory copy and should ONLY be used for debugging + testing. For
efficient all_reduce ops one should write/replace it with a fused op.
"""
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM)
t_res = t.clone()
dist.all_reduce(t_res, op=dist.ReduceOp.SUM)
return t_res
Expand All @@ -42,3 +37,20 @@ def all_reduce(t: torch.Tensor) -> torch.Tensor:
@all_reduce.register_fake
def all_reduce_fake(tensor):
return torch.empty_like(tensor)


@torch.library.custom_op(
"auto_deploy::torch_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
)
def fused_linear_all_reduce(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
"""Fused linear followed by in-place all_reduce on the output."""
output = torch.ops.aten.linear(input, weight, bias)
dist.all_reduce(output, op=dist.ReduceOp.SUM)
return output


@fused_linear_all_reduce.register_fake
def fused_linear_all_reduce_fake(input, weight, bias):
return torch.ops.aten.linear(input, weight, bias)
85 changes: 85 additions & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/trtllm_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""TRT-LLM optimized dist ops."""

from typing import List, Optional

import torch

from ....functional import AllReduceParams
from ....mapping import Mapping
from ...distributed import AllReduce, allgather
from ...modules.linear import AllReduceFusionOp, AllReduceStrategy
from ..distributed.common import get_rank_world_size


@torch.library.custom_op("auto_deploy::trtllm_all_gather", mutates_args=(), device_types="cuda")
def all_gather(
tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
) -> torch.Tensor:
"""TRT-LLM all gather followed by concat in dim = 0."""
rank, world_size = get_rank_world_size()
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
result = allgather(tensor, p_config, dim=dim, sizes=sizes)
assert isinstance(result, torch.Tensor), "Expected tensor result from allgather"
return result


@all_gather.register_fake
def all_gather_fake(tensor, dim=0, sizes=None):
rank, world_size = get_rank_world_size()
return torch.cat([torch.empty_like(tensor) for _ in range(world_size)], dim=dim)


@torch.library.custom_op("auto_deploy::trtllm_all_reduce", mutates_args=(), device_types="cuda")
def all_reduce(tensor: torch.Tensor, strategy: int = int(AllReduceStrategy.AUTO)) -> torch.Tensor:
"""TRT-LLM all_reduce across the ranks."""
rank, world_size = get_rank_world_size()
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy(strategy))
result = torch_op(tensor)
assert isinstance(result, torch.Tensor), "Expected tensor result from allreduce"
return result


@all_reduce.register_fake
def all_reduce_fake(tensor):
return torch.empty_like(tensor)


@torch.library.custom_op(
"auto_deploy::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
)
def trtllm_fused_allreduce_residual_rmsnorm(
tensor: torch.Tensor,
residual: torch.Tensor,
norm_weight: torch.Tensor,
eps: float,
strategy: int = int(AllReduceStrategy.AUTO),
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fusing allreduce, residual (add), and hf_rms_norm together."""
rank, world_size = get_rank_world_size()
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)

# Use AllReduceParams like the old implementation
all_reduce_params = AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
bias=None,
residual=residual,
norm_weight=norm_weight,
eps=eps,
)

torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy(strategy))
output = torch_op(tensor, all_reduce_params=all_reduce_params)
assert len(output) == 2, "Expected 2 outputs from trtllm_fused_allreduce_residual_rmsnorm"
return output[0], output[1]


@trtllm_fused_allreduce_residual_rmsnorm.register_fake
def trtllm_fused_allreduce_residual_rmsnorm_fake(
tensor: torch.Tensor,
residual: torch.Tensor,
norm_weight: torch.Tensor,
eps: float,
strategy: int = int(AllReduceStrategy.AUTO),
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(tensor), torch.empty_like(tensor)
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/auto_deploy/distributed/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .common import *
26 changes: 20 additions & 6 deletions tensorrt_llm/_torch/auto_deploy/distributed/common.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
"""Common utilities for distributed inference."""
"""Common utilities for distributed inference.

These utilities serve two purposes:

1. Drop-in replacement for torch.distributed to ensure that any function in torch.distributed works
out of the box.
2. Provide a simple interface to spawn multiple processes and communicate with them. We support
three supports:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We support three supports: ==> three modes?

a. MPI initialized via TRT-LLM runtime or mpirun.
b. TorchElastic initialized via torchrun.
c. Simple python multiprocessing started explicitly.
"""

import atexit
import os
Expand All @@ -12,8 +23,6 @@

from ..utils.logger import ad_logger

# TODO: check to what extend we can reuse _torch/distributed.py


class _DistGroup:
"""Global instance to set/get the default process group for distributed ops."""
Expand Down Expand Up @@ -93,17 +102,22 @@ def initialize_or_skip(*args, **kwargs) -> Tuple[int, int]:
return get_rank(), get_world_size()


def is_ompi():
def is_ompi() -> bool:
"""Check whether multi-processing was initialized with explicitly calling mpirun."""
return "OMPI_COMM_WORLD_SIZE" in os.environ


def is_torchelastic():
def is_trtllm_dist_available() -> bool:
"""Check if TRT-LLM distributed operations are available (they only work with MPI!)."""
return is_ompi()


def is_torchelastic() -> bool:
"""Check whether multi-processing was initialized with torchelastic."""
return "TORCHELASTIC_RUN_ID" in os.environ


def cleanup():
def cleanup() -> None:
"""Destroy process group when the program exits."""
if dist.is_initialized():
ad_logger.info("Destroying process group")
Expand Down
59 changes: 0 additions & 59 deletions tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py

This file was deleted.

Loading