
Commit c9771eb

[#9198][feat] Refactor dist ops in AutoDeploy (#9301)
Signed-off-by: Eran Geva <[email protected]>
1 parent 0a2104d commit c9771eb

File tree: 19 files changed (+631, -326 lines)


tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ transforms:
     support_partial_config: true
     sharding_dims: ['tp', 'ep', 'bmm']
     allreduce_strategy: 'AUTO'
+    dist_backend: auto
     requires_shape_prop: true
   sharding_transform_executor:
     stage: sharding
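
The new `dist_backend` key selects which family of distributed custom ops the sharding transforms emit. Below is a minimal sketch of how an override might look in a user config; only the `auto` value is confirmed by this diff, and the `detect_sharding` parent key plus the alternative value names are assumptions for illustration.

```yaml
# Sketch of a user override (assumed shape; not part of this diff).
transforms:
  detect_sharding:          # assumed parent entry; the hunk header only shows 'transforms:'
    allreduce_strategy: 'AUTO'
    dist_backend: auto      # confirmed default; 'trtllm' / 'torch' are plausible but unconfirmed alternatives
```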

tensorrt_llm/_torch/auto_deploy/custom_ops/README.md

Lines changed: 5 additions & 4 deletions
@@ -17,13 +17,12 @@ The table below lists the operators ordered by their backend.
 | `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layout supported |
 | `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
 | `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
-| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |
-| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation |
+| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation (PyTorch backend, demollm mode) |
+| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation (PyTorch backend, demollm mode) |
 | `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation |
 | `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation |
 | `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
 | `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values |
-| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce operation |
 | `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | FP4 quantized linear layer |
 | `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer |
 | `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies |
@@ -38,4 +37,6 @@ The table below lists the operators ordered by their backend.
 | `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
 | `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
 | `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT LLM fused MoE implementation |
-| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | TensorRT LLM fused linear layer followed by all-reduce operation |
+| `torch.ops.auto_deploy.trtllm_dist_all_gather` | Distributed all-gather operation (TRT-LLM backend, MPI mode) |
+| `torch.ops.auto_deploy.trtllm_dist_all_reduce` | Distributed all-reduce operation (TRT-LLM backend, MPI mode) |
+| `torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm` | Fused all-reduce + residual add + RMSNorm (TRT-LLM backend, MPI mode) |
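
The TRT-LLM-backend entries added at the bottom of the table correspond to the custom ops registered later in this commit. A hedged usage sketch follows; shapes and values are made up, and it assumes the op modules are imported (so the ops are registered) and the process is running under MPI, the mode these ops target.

```python
import torch

# Illustrative only: 1024-wide activations on this rank; real inputs come from the sharded graph.
x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)        # partial result to be summed
residual = torch.randn_like(x)
norm_weight = torch.ones(1024, device="cuda", dtype=torch.float16)

# Atomic all-reduce (SUM) with an explicit strategy string, e.g. "AUTO".
y = torch.ops.auto_deploy.trtllm_dist_all_reduce(x, "AUTO")

# Fused all-reduce + residual add + RMSNorm returns (normed_output, updated_residual).
normed, new_residual = torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm(
    x, residual, norm_weight, 1e-5, "AUTO"
)
```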

tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py

Lines changed: 0 additions & 25 deletions
@@ -4,9 +4,6 @@

 import torch

-from ..distributed import common as dist
-from ..distributed import trtllm as trtllm_dist
-

 @torch.library.custom_op("auto_deploy::torch_linear_simple", mutates_args=())
 def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
@@ -24,26 +21,4 @@ def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso
 @simple.register_fake
 def simple_fake(input, weight, bias):
     """Fake implementation of simple_linear."""
-    # return torch.empty(
-    #     input.shape[:-1] + (weight.shape[-1],), dtype=input.dtype, device=input.device
-    # )
-    return torch.ops.aten.linear(input, weight, bias)
-
-
-@torch.library.custom_op(
-    "auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
-)
-def fused_linear_all_reduce(
-    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], strategy: str
-) -> torch.Tensor:
-    """Fused linear followed by all_reduce on the output. Strategy is MANDATORY."""
-    output = torch.ops.aten.linear(input, weight, bias)
-    if trtllm_dist.is_trtllm_op_available():
-        return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM, strategy=strategy)
-    dist.all_reduce(output, op=dist.ReduceOp.SUM)
-    return output
-
-
-@fused_linear_all_reduce.register_fake
-def fused_linear_all_reduce_fake(input, weight, bias, strategy):
     return torch.ops.aten.linear(input, weight, bias)
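
The deleted `trtllm_dist_fused_linear_all_reduce` op chose its backend inside the op body. After this refactor the same computation is presumably expressed as a plain linear followed by one of the atomic all-reduce ops, with the backend decided when the graph is transformed rather than at run time. A hedged sketch of that decomposition (not code from this diff):

```python
from typing import Optional

import torch

# Sketch only: how AutoDeploy's sharding transforms actually rewrite the graph is
# not shown in this diff; this just illustrates the op-level equivalence.
def linear_then_all_reduce(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    strategy: str,
    use_trtllm: bool,
) -> torch.Tensor:
    out = torch.ops.aten.linear(input, weight, bias)
    if use_trtllm:
        # MPI mode: TRT-LLM optimized all-reduce
        return torch.ops.auto_deploy.trtllm_dist_all_reduce(out, strategy)
    # demollm mode: torch.distributed all-reduce
    return torch.ops.auto_deploy.torch_dist_all_reduce(out, strategy)
```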

tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py

Lines changed: 0 additions & 35 deletions
@@ -7,8 +7,6 @@
 from flashinfer import bmm_fp8
 from torch import nn

-from ..distributed import common as dist
-from ..distributed import trtllm as trtllm_dist
 from .torch_libs.float8_python_api import addmm_float8_unwrapped

 TRTLLM_FP4_OP_AVAILABLE = True
@@ -238,39 +236,6 @@ def fp8_linear_fake(
     return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias)


-@torch.library.custom_op("auto_deploy::torch_quant_fused_fp8_linear_all_reduce", mutates_args=())
-@torch.compile(dynamic=True)
-def fused_fp8_linear_all_reduce(
-    input: torch.Tensor,
-    weight_fp8: torch.Tensor,
-    strategy: str,
-    bias: Optional[torch.Tensor] = None,
-    input_scale: Optional[torch.Tensor] = None,
-    weight_scale: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    out = torch.ops.auto_deploy.torch_quant_fp8_linear(
-        input, weight_fp8, bias, input_scale, weight_scale
-    )
-    if trtllm_dist.is_trtllm_op_available():
-        return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM, strategy=strategy)
-    dist.all_reduce(out, op=dist.ReduceOp.SUM)
-    return out
-
-
-@fused_fp8_linear_all_reduce.register_fake
-def fused_fp8_linear_all_reduce_fake(
-    input: torch.Tensor,
-    weight_fp8: torch.Tensor,
-    strategy: str,
-    bias: Optional[torch.Tensor] = None,
-    input_scale: Optional[torch.Tensor] = None,
-    weight_scale: Optional[torch.Tensor] = None,
-) -> torch.Tensor:
-    return torch.ops.auto_deploy.torch_quant_fp8_linear(
-        input, weight_fp8, bias, input_scale, weight_scale
-    )
-
-
 class FP8Linear(nn.Linear):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
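
Likewise, the removed `torch_quant_fused_fp8_linear_all_reduce` can presumably be recovered by composing the surviving `torch_quant_fp8_linear` op with an atomic all-reduce. A minimal hedged sketch (not code from this diff; argument names follow the removed op):

```python
import torch

def fp8_linear_then_all_reduce(input, weight_fp8, bias, input_scale, weight_scale, strategy):
    # FP8 GEMM, called exactly as the removed fused op called it.
    out = torch.ops.auto_deploy.torch_quant_fp8_linear(
        input, weight_fp8, bias, input_scale, weight_scale
    )
    # Atomic all-reduce; swap in torch.ops.auto_deploy.torch_dist_all_reduce for demollm mode.
    return torch.ops.auto_deploy.trtllm_dist_all_reduce(out, strategy)
```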
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+"""Custom ops required for implementing tensor parallelism.
+
+This module defines atomic distributed ops - each op uses a specific backend
+(torch.distributed or TRT-LLM) without internal dispatch logic.
+"""
+
+from typing import List, Optional
+
+import torch
+
+from ..distributed import common as dist
+
+# ============================================================================
+# PyTorch Distributed Backend Ops (demollm mode)
+# ============================================================================
+
+
+@torch.library.custom_op("auto_deploy::torch_dist_all_gather", mutates_args=(), device_types="cuda")
+def torch_dist_all_gather(
+    tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
+) -> torch.Tensor:
+    """All gather using PyTorch distributed backend.
+
+    This op always uses torch.distributed.all_gather and is used in demollm mode.
+    """
+    tl = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
+    dist.all_gather(tl, tensor)
+    return torch.cat(tl, dim=dim)
+
+
+@torch_dist_all_gather.register_fake
+def torch_dist_all_gather_fake(tensor, dim=0, sizes=None):
+    return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim)
+
+
+@torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda")
+def torch_dist_all_reduce(t: torch.Tensor, strategy: str) -> torch.Tensor:
+    """All_reduce using PyTorch distributed backend. Reduction op is SUM.
+
+    This op always uses torch.distributed.all_reduce and is used in demollm mode.
+
+    NOTE: this op requires an extra memory copy and should ONLY be used for debugging + testing. For
+    efficient all_reduce ops one should write/replace it with a fused op.
+    """
+    t_res = t.clone()
+    dist.all_reduce(t_res, op=dist.ReduceOp.SUM)
+    return t_res
+
+
+@torch_dist_all_reduce.register_fake
+def torch_dist_all_reduce_fake(tensor, strategy):
+    return torch.empty_like(tensor)
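
Since these ops wrap plain `torch.distributed` collectives, they slot directly into the usual tensor-parallel linear patterns. A hedged sketch follows; it assumes the process group was already initialized by AutoDeploy's demollm launcher, that this module has been imported so the ops are registered, and that all tensors live on CUDA (the ops are registered with `device_types="cuda"`).

```python
import torch

def column_parallel_linear(x: torch.Tensor, w_shard: torch.Tensor) -> torch.Tensor:
    """Each rank holds a column shard of the weight; gather output shards along the last dim."""
    local = torch.ops.aten.linear(x, w_shard, None)
    return torch.ops.auto_deploy.torch_dist_all_gather(local, -1)

def row_parallel_linear(x_shard: torch.Tensor, w_shard: torch.Tensor) -> torch.Tensor:
    """Each rank computes a partial full-size output; sum-reduce across ranks."""
    partial = torch.ops.aten.linear(x_shard, w_shard, None)
    # The strategy string exists for signature parity with the TRT-LLM op;
    # the torch backend above does not use it.
    return torch.ops.auto_deploy.torch_dist_all_reduce(partial, "AUTO")
```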
Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+"""TRT-LLM distributed operations and fused kernels.
+
+This module defines atomic TRT-LLM-specific ops that use optimized kernels.
+The torch fallback variants are defined separately to enable multi-pattern matching.
+"""
+
+from typing import List, Optional
+
+import torch
+
+# use trtllm distributed ops to improve TP performance if possible
+from ....mapping import Mapping
+from ...distributed import AllReduce, allgather
+from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy
+from ..distributed.common import ReduceOp, get_rank_world_size, get_world_size, is_ompi
+
+# Cache AllReduce modules to avoid recreating on every call
+# This is critical for CUDA graph compatibility - recreating modules during
+# warmup causes hangs due to workspace allocation with CPU synchronization
+_allreduce_cache = {}
+
+
+def trtllm_allgather(tensor, dim, sizes=None):
+    rank, world_size = get_rank_world_size()
+    p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
+    return allgather(tensor, p_config, dim=dim, sizes=sizes)
+
+
+def trtllm_allreduce(tensor, op, strategy: str, all_reduce_params=None):
+    rank, world_size = get_rank_world_size()
+    assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
+
+    # Convert string strategy to enum
+    try:
+        strategy_enum = getattr(AllReduceStrategy, strategy)
+    except AttributeError:
+        raise ValueError(
+            f"Invalid allreduce strategy: {strategy}. "
+            f"Valid options: AUTO, NCCL, ONESHOT, TWOSHOT, MIN_LATENCY, "
+            f"LOWPRECISION, UB, MNNVL, NCCL_SYMMETRIC"
+        )
+
+    # Cache key includes rank, world_size, dtype, and strategy to handle different configurations
+    cache_key = (rank, world_size, tensor.dtype, strategy_enum)
+    if cache_key not in _allreduce_cache:
+        p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
+        _allreduce_cache[cache_key] = AllReduce(
+            mapping=p_config, strategy=strategy_enum, dtype=tensor.dtype
+        )
+
+    torch_op = _allreduce_cache[cache_key]
+    return torch_op(tensor, all_reduce_params=all_reduce_params)
+
+
+# ============================================================================
+# TRT-LLM Backend Ops (MPI mode)
+# ============================================================================
+
+
+@torch.library.custom_op(
+    "auto_deploy::trtllm_dist_all_gather", mutates_args=(), device_types="cuda"
+)
+def trtllm_dist_all_gather(
+    tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
+) -> torch.Tensor:
+    """All gather using TRT-LLM optimized backend.
+
+    This op always uses TRT-LLM's optimized allgather and is used in MPI mode.
+    """
+    return trtllm_allgather(tensor, dim=dim, sizes=sizes)
+
+
+@trtllm_dist_all_gather.register_fake
+def trtllm_dist_all_gather_fake(tensor, dim=0, sizes=None):
+    return torch.cat([torch.empty_like(tensor) for _ in range(get_world_size())], dim=dim)
+
+
+@torch.library.custom_op(
+    "auto_deploy::trtllm_dist_all_reduce", mutates_args=(), device_types="cuda"
+)
+def trtllm_dist_all_reduce(t: torch.Tensor, strategy: str) -> torch.Tensor:
+    """All_reduce using TRT-LLM optimized backend. Reduction op is SUM.
+
+    This op always uses TRT-LLM's optimized allreduce and is used in MPI mode.
+    """
+    return trtllm_allreduce(t, op=ReduceOp.SUM, strategy=strategy)
+
+
+@trtllm_dist_all_reduce.register_fake
+def trtllm_dist_all_reduce_fake(tensor, strategy):
+    return torch.empty_like(tensor)
+
+
+# TRT-LLM fused op (atomic - always uses TRT-LLM backend)
+@torch.library.custom_op(
+    "dist::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
+)
+def trtllm_fused_allreduce_residual_rmsnorm(
+    tensor: torch.Tensor,
+    residual: torch.Tensor,
+    norm_weight: torch.Tensor,
+    eps: float,
+    strategy: str,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Fused allreduce + residual + rmsnorm using TRT-LLM optimized kernel.
+
+    This op always uses TRT-LLM's fused kernel and is used in MPI mode.
+    """
+    all_reduce_params = AllReduceParams(
+        fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
+        bias=None,
+        residual=residual,
+        norm_weight=norm_weight,
+        eps=eps,
+    )
+    return trtllm_allreduce(
+        tensor, ReduceOp.SUM, strategy=strategy, all_reduce_params=all_reduce_params
+    )
+
+
+@trtllm_fused_allreduce_residual_rmsnorm.register_fake
+def trtllm_fused_allreduce_residual_rmsnorm_fake(
+    tensor: torch.Tensor,
+    residual: torch.Tensor,
+    norm_weight: torch.Tensor,
+    eps: float,
+    strategy: str,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.empty_like(tensor), torch.empty_like(tensor)
+
+
+def is_trtllm_op_available():
+    """Check if TRT-LLM ops are available and running with MPI."""
+    return is_ompi()
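
`is_trtllm_op_available()` keys off an MPI launch, which is presumably what the new `dist_backend: auto` setting resolves against when the sharding transforms decide which op family to insert. A hedged sketch of such a selection step (the real logic lives in the transforms, which are not part of this diff; the non-`auto` value names are assumptions):

```python
import torch

def pick_all_reduce_op(dist_backend: str, is_trtllm_op_available):
    """Return the all-reduce custom op to place in the graph (illustrative only)."""
    if dist_backend == "auto":
        use_trtllm = is_trtllm_op_available()  # True when launched under MPI (is_ompi)
    else:
        use_trtllm = dist_backend == "trtllm"  # assumed value name
    if use_trtllm:
        return torch.ops.auto_deploy.trtllm_dist_all_reduce  # MPI mode: optimized kernel
    return torch.ops.auto_deploy.torch_dist_all_reduce       # demollm mode: torch.distributed
```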
