Skip to content

Commit afc52d7

Browse files
authored
[https://nvbugs/5647400] [fix] Enlarged the AllReduce workspace size to 64MB. Added AllReduce strategy to AD config. (#9145)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
1 parent 899fda9 commit afc52d7

File tree

15 files changed

+585
-73
lines changed

15 files changed

+585
-73
lines changed

cpp/tensorrt_llm/thop/allreduceOp.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -959,6 +959,13 @@ class AllreduceOp
959959
// MIN_LATENCY.
960960
if (mStrategy != AllReduceStrategyType::AUTO)
961961
{
962+
// Check TWOSHOT constraint: seq_len >= tp_size
963+
if (mStrategy == AllReduceStrategyType::TWOSHOT && seq_len < mGroup.size())
964+
{
965+
TLLM_LOG_WARNING("TWOSHOT strategy requires seq_len >= tp_size (%zu < %zu), falling back to ONESHOT",
966+
seq_len, mGroup.size());
967+
return AllReduceStrategyType::ONESHOT;
968+
}
962969
return mStrategy;
963970
}
964971
else

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ transforms:
7979
sharding_source: ['factory','heuristic']
8080
support_partial_config: true
8181
sharding_dims: ['tp', 'ep', 'bmm']
82+
allreduce_strategy: 'AUTO'
8283
requires_shape_prop: true
8384
sharding_transform_executor:
8485
stage: sharding

tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,23 @@ def all_gather_fake(tensor, dim=0):
2626

2727

2828
@torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda")
29-
def all_reduce(t: torch.Tensor) -> torch.Tensor:
30-
"""All_reduce across the ranks. Reduction op is SUM.
29+
def all_reduce(t: torch.Tensor, strategy: str) -> torch.Tensor:
30+
"""All_reduce across the ranks. Reduction op is SUM. Strategy is MANDATORY.
31+
32+
Args:
33+
t: Tensor to reduce across ranks
34+
strategy: AllReduce strategy - "AUTO", "NCCL", "ONESHOT", "TWOSHOT", "MIN_LATENCY", etc.
3135
3236
NOTE: this op requires an extra memory copy and should ONLY be used for debugging + testing. For
3337
efficient all_reduce ops one should write/replace it with a fused op.
3438
"""
3539
if trtllm_dist.is_trtllm_op_available():
36-
return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM)
40+
return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM, strategy=strategy)
3741
t_res = t.clone()
3842
dist.all_reduce(t_res, op=dist.ReduceOp.SUM)
3943
return t_res
4044

4145

4246
@all_reduce.register_fake
43-
def all_reduce_fake(tensor):
47+
def all_reduce_fake(tensor, strategy):
4448
return torch.empty_like(tensor)

tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,16 @@ def simple_fake(input, weight, bias):
3434
"auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
3535
)
3636
def fused_linear_all_reduce(
37-
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
37+
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor], strategy: str
3838
) -> torch.Tensor:
39-
"""Fused linear followed by all_reduce on the output."""
39+
"""Fused linear followed by all_reduce on the output. Strategy is MANDATORY."""
4040
output = torch.ops.aten.linear(input, weight, bias)
4141
if trtllm_dist.is_trtllm_op_available():
42-
return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM)
42+
return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM, strategy=strategy)
4343
dist.all_reduce(output, op=dist.ReduceOp.SUM)
4444
return output
4545

4646

4747
@fused_linear_all_reduce.register_fake
48-
def fused_linear_all_reduce_fake(input, weight, bias):
48+
def fused_linear_all_reduce_fake(input, weight, bias, strategy):
4949
return torch.ops.aten.linear(input, weight, bias)

tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ def fp8_linear_fake(
245245
def fused_fp8_linear_all_reduce(
246246
input: torch.Tensor,
247247
weight_fp8: torch.Tensor,
248+
strategy: str,
248249
bias: Optional[torch.Tensor] = None,
249250
input_scale: Optional[torch.Tensor] = None,
250251
weight_scale: Optional[torch.Tensor] = None,
@@ -253,7 +254,7 @@ def fused_fp8_linear_all_reduce(
253254
input, weight_fp8, bias, input_scale, weight_scale
254255
)
255256
if trtllm_dist.is_trtllm_op_available():
256-
return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM)
257+
return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM, strategy=strategy)
257258
dist.all_reduce(out, op=dist.ReduceOp.SUM)
258259
return out
259260

@@ -262,6 +263,7 @@ def fused_fp8_linear_all_reduce(
262263
def fused_fp8_linear_all_reduce_fake(
263264
input: torch.Tensor,
264265
weight_fp8: torch.Tensor,
266+
strategy: str,
265267
bias: Optional[torch.Tensor] = None,
266268
input_scale: Optional[torch.Tensor] = None,
267269
weight_scale: Optional[torch.Tensor] = None,

tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,26 @@ def trtllm_allgather(tensor, dim, sizes=None):
1818
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
1919
return allgather(tensor, p_config, dim=dim, sizes=sizes)
2020

21-
def trtllm_allreduce(tensor, op, all_reduce_params=None):
21+
def trtllm_allreduce(tensor, op, strategy: str, all_reduce_params=None):
2222
rank, world_size = get_rank_world_size()
2323
assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
2424

25-
# Cache key includes rank, world_size, and dtype to handle different configurations
26-
cache_key = (rank, world_size, tensor.dtype)
25+
# Convert string strategy to enum
26+
try:
27+
strategy_enum = getattr(AllReduceStrategy, strategy)
28+
except AttributeError:
29+
raise ValueError(
30+
f"Invalid allreduce strategy: {strategy}. "
31+
f"Valid options: AUTO, NCCL, ONESHOT, TWOSHOT, MIN_LATENCY, "
32+
f"LOWPRECISION, UB, MNNVL, NCCL_SYMMETRIC"
33+
)
34+
35+
# Cache key includes rank, world_size, dtype, and strategy to handle different configurations
36+
cache_key = (rank, world_size, tensor.dtype, strategy_enum)
2737
if cache_key not in _allreduce_cache:
2838
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
29-
# Use Strategy.AUTO for optimal performance
3039
_allreduce_cache[cache_key] = AllReduce(
31-
mapping=p_config, strategy=AllReduceStrategy.NCCL, dtype=tensor.dtype
40+
mapping=p_config, strategy=strategy_enum, dtype=tensor.dtype
3241
)
3342

3443
torch_op = _allreduce_cache[cache_key]
@@ -38,7 +47,11 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
3847
"dist::fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
3948
)
4049
def fused_allreduce_residual_rmsnorm(
41-
tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
50+
tensor: torch.Tensor,
51+
residual: torch.Tensor,
52+
norm_weight: torch.Tensor,
53+
eps: float,
54+
strategy: str = "AUTO",
4255
) -> tuple[torch.Tensor, torch.Tensor]:
4356
"""Fusing allreduce, residual (add), and hf_rms_norm together.
4457
@@ -54,7 +67,9 @@ def fused_allreduce_residual_rmsnorm(
5467
norm_weight=norm_weight,
5568
eps=eps,
5669
)
57-
return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
70+
return trtllm_allreduce(
71+
tensor, ReduceOp.SUM, strategy=strategy, all_reduce_params=all_reduce_params
72+
)
5873
else:
5974
# Fallback: unfused implementation using torch distributed
6075
# This is used in demollm mode without MPI
@@ -79,7 +94,11 @@ def fused_allreduce_residual_rmsnorm(
7994

8095
@fused_allreduce_residual_rmsnorm.register_fake
8196
def fused_allreduce_residual_rmsnorm_fake(
82-
tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
97+
tensor: torch.Tensor,
98+
residual: torch.Tensor,
99+
norm_weight: torch.Tensor,
100+
eps: float,
101+
strategy: str = "AUTO",
83102
) -> tuple[torch.Tensor, torch.Tensor]:
84103
return torch.empty_like(tensor), torch.empty_like(tensor)
85104

@@ -89,7 +108,7 @@ def fused_allreduce_residual_rmsnorm_fake(
89108
def trtllm_allgather(tensor, dim, sizes=None):
90109
raise ImportError("TRT-LLM is not available.")
91110

92-
def trtllm_allreduce(tensor, op):
111+
def trtllm_allreduce(tensor, op, strategy: str, all_reduce_params=None):
93112
raise ImportError("TRT-LLM is not available.")
94113

95114
TRTLLM_OP_AVAILABLE = False

tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def _allreduce_residual_rmsnorm_pattern(
2828
"""
2929

3030
input_dtype = x.dtype
31-
hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x)
31+
hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x, "AUTO")
3232
add = residual + hidden_states
3333

3434
hidden_states = add.to(torch.float32)
@@ -52,7 +52,7 @@ def _allreduce_residual_rmsnorm_pattern2(
5252
"""
5353

5454
input_dtype = x.dtype
55-
hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x)
55+
hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x, "AUTO")
5656
add = hidden_states + residual
5757

5858
hidden_states = add.to(torch.float32)

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
from typing import DefaultDict, Dict, List, Set, Tuple, Type
2323

2424
import torch
25-
from pydantic import Field
25+
from pydantic import Field, field_validator
2626
from torch.fx import GraphModule, Node
2727

28+
from .....functional import AllReduceStrategy
2829
from ...models.factory import ModelFactory, ShardingConfigSource
2930
from ...shim.interface import CachedSequenceInterface
3031
from ...utils.logger import ad_logger
@@ -49,6 +50,7 @@
4950
SplitDimension,
5051
WeightShardingInfo,
5152
get_all_weights_in_subgraph,
53+
validate_allreduce_strategy,
5254
)
5355
from ..interface import (
5456
BaseTransform,
@@ -152,6 +154,18 @@ class ShardingTransformConfig(TransformConfig):
152154
sharding_dims: List[ShardingDim] = Field(
153155
default_factory=lambda: [ShardingDim.SSM, ShardingDim.TP, ShardingDim.EP, ShardingDim.BMM]
154156
)
157+
allreduce_strategy: AllReduceStrategy = Field(
158+
default=AllReduceStrategy.AUTO,
159+
description="AllReduce strategy for distributed operations. "
160+
"Options: AUTO (automatic selection), NCCL, ONESHOT, TWOSHOT, MIN_LATENCY, "
161+
"LOWPRECISION, UB, MNNVL, NCCL_SYMMETRIC",
162+
)
163+
164+
@field_validator("allreduce_strategy", mode="before")
165+
@classmethod
166+
def _validate_allreduce_strategy(cls, v):
167+
"""Convert string names like 'AUTO' to AllReduceStrategy enum."""
168+
return validate_allreduce_strategy(v)
155169

156170

157171
@TransformRegistry.register("detect_sharding")
@@ -199,6 +213,8 @@ def _apply(
199213
sharding_config = shared_config.sharding_config
200214
sharding_config.rank = local_rank
201215
sharding_config.world_size = world_size
216+
sharding_config.allreduce_strategy = self.config.allreduce_strategy
217+
ad_logger.info(f"Using allreduce strategy: {sharding_config.allreduce_strategy.name}")
202218
sharding_config.predefined_config = factory.get_sharding_config() if factory else {}
203219
sharding_config.factory_source = (
204220
sharding_config.predefined_config.get("source", ShardingConfigSource.UNKNOWN)
@@ -573,7 +589,7 @@ def detect_sharding_from_factory_config(
573589
# we have a match. Get the config for this layer
574590
config = tp_plan[key]
575591
if config == "colwise":
576-
sharding_config.weight_sharding_transforms.append(
592+
if sharding_config.add(
577593
WeightShardingInfo.from_node(
578594
lin_node,
579595
split_dim=SplitDimension.COLUMN,
@@ -582,10 +598,10 @@ def detect_sharding_from_factory_config(
582598
dist_op=None,
583599
min_local_shape=min_local_shape,
584600
)
585-
)
586-
num_row_col_shards += 1
601+
):
602+
num_row_col_shards += 1
587603
elif config == "rowwise":
588-
sharding_config.weight_sharding_transforms.append(
604+
if sharding_config.add(
589605
WeightShardingInfo.from_node(
590606
lin_node,
591607
split_dim=SplitDimension.ROW,
@@ -594,10 +610,10 @@ def detect_sharding_from_factory_config(
594610
dist_op="all_reduce",
595611
min_local_shape=min_local_shape,
596612
)
597-
)
598-
num_row_col_shards += 1
613+
):
614+
num_row_col_shards += 1
599615
elif config == "mamba":
600-
sharding_config.weight_sharding_transforms.append(
616+
sharding_config.add(
601617
WeightShardingInfo.from_node(
602618
lin_node,
603619
split_dim=SplitDimension.COLUMN,
@@ -618,7 +634,7 @@ def detect_sharding_from_factory_config(
618634
if "shared" in module_name:
619635
col_row_action = config.replace("local_", "")
620636
if col_row_action == "colwise":
621-
sharding_config.weight_sharding_transforms.append(
637+
sharding_config.add(
622638
WeightShardingInfo(
623639
target_node=lin_node.name,
624640
split_dim=SplitDimension.COLUMN,
@@ -629,7 +645,7 @@ def detect_sharding_from_factory_config(
629645
)
630646
)
631647
elif col_row_action == "rowwise":
632-
sharding_config.weight_sharding_transforms.append(
648+
if sharding_config.add(
633649
WeightShardingInfo(
634650
target_node=lin_node.name,
635651
split_dim=SplitDimension.ROW,
@@ -638,8 +654,8 @@ def detect_sharding_from_factory_config(
638654
dist_op="all_reduce",
639655
min_local_shape=min_local_shape,
640656
)
641-
)
642-
num_row_col_shards += 1
657+
):
658+
num_row_col_shards += 1
643659
else:
644660
ad_logger.warning(f"Unsupported sharding action {config}. Skipping.")
645661
else:
@@ -648,7 +664,7 @@ def detect_sharding_from_factory_config(
648664

649665
elif "gather" in config:
650666
# Simple shard (SplitDimension.COLUMN split + all_gather)
651-
sharding_config.weight_sharding_transforms.append(
667+
if sharding_config.add(
652668
WeightShardingInfo.from_node(
653669
lin_node,
654670
split_dim=SplitDimension.COLUMN,
@@ -657,13 +673,13 @@ def detect_sharding_from_factory_config(
657673
dist_op="all_gather",
658674
min_local_shape=1,
659675
)
660-
)
661-
num_simple_shards += 1
676+
):
677+
num_simple_shards += 1
662678
else:
663679
ad_logger.warning(
664680
f"Unsupported sharding action {config}. Fallback to simple shard"
665681
)
666-
sharding_config.weight_sharding_transforms.append(
682+
sharding_config.add(
667683
WeightShardingInfo.from_node(
668684
lin_node,
669685
split_dim=SplitDimension.COLUMN,
@@ -943,7 +959,7 @@ def detect_column_row_shard(
943959
)
944960

945961
# shard single row node
946-
sharding_config.weight_sharding_transforms.append(
962+
if sharding_config.add(
947963
WeightShardingInfo.from_node(
948964
nodes_to_row_shard[0],
949965
split_dim=SplitDimension.ROW,
@@ -952,9 +968,8 @@ def detect_column_row_shard(
952968
dist_op="all_reduce",
953969
min_local_shape=min_local_shape,
954970
)
955-
)
956-
957-
num_row_col_shards += 1
971+
):
972+
num_row_col_shards += 1
958973

959974
ad_logger.info(
960975
f"Found {num_shards} TP shards (simple: {num_simple_shards}, row-col: {num_row_col_shards})"
@@ -1020,7 +1035,7 @@ def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Tra
10201035
start_idx = remainder + rank * base_size
10211036
end_idx = start_idx + base_size
10221037

1023-
sharding_config.bmm_transforms.append(
1038+
sharding_config.add(
10241039
BMMShardingInfo(
10251040
target_node=node.name,
10261041
rank=rank,
@@ -1064,14 +1079,14 @@ def detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Transfo
10641079
),
10651080
):
10661081
continue
1067-
sharding_config.ep_transforms.append(
1082+
if sharding_config.add(
10681083
EPShardingInfo.from_node(
10691084
node,
10701085
rank=rank,
10711086
world_size=world_size,
10721087
)
1073-
)
1074-
num_moe_patterns += 1
1088+
):
1089+
num_moe_patterns += 1
10751090

10761091
ad_logger.info(f"Found {num_moe_patterns} MoE patterns")
10771092

0 commit comments

Comments
 (0)