Commit f877823
[#8781][fix] Cache the AllReduce wrapper to avoid re-allocating workspace which caused a hang (#8803)
Signed-off-by: Eran Geva <[email protected]>
1 parent da73410 commit f877823

File tree

  • tensorrt_llm/_torch/auto_deploy/distributed

1 file changed: +16 -3 lines changed

tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py

Lines changed: 16 additions & 3 deletions
@@ -8,6 +8,11 @@
 from ...distributed import AllReduce, allgather
 from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy

+# Cache AllReduce modules to avoid recreating on every call
+# This is critical for CUDA graph compatibility - recreating modules during
+# warmup causes hangs due to workspace allocation with CPU synchronization
+_allreduce_cache = {}
+
 def trtllm_allgather(tensor, dim, sizes=None):
     rank, world_size = get_rank_world_size()
     p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
@@ -16,9 +21,17 @@ def trtllm_allgather(tensor, dim, sizes=None):
 def trtllm_allreduce(tensor, op, all_reduce_params=None):
     rank, world_size = get_rank_world_size()
     assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
-    p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-    # Use Strategy.NCCL until https://nvbugspro.nvidia.com/bug/5331013 is fixed, then change to Strategy.AUTO
-    torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.NCCL)
+
+    # Cache key includes rank, world_size, and dtype to handle different configurations
+    cache_key = (rank, world_size, tensor.dtype)
+    if cache_key not in _allreduce_cache:
+        p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
+        # Use Strategy.AUTO for optimal performance
+        _allreduce_cache[cache_key] = AllReduce(
+            mapping=p_config, strategy=AllReduceStrategy.AUTO, dtype=tensor.dtype
+        )
+
+    torch_op = _allreduce_cache[cache_key]
     return torch_op(tensor, all_reduce_params=all_reduce_params)

 @torch.library.custom_op(
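For context, here is a minimal, self-contained sketch of the caching pattern the diff applies. The names FakeAllReduce and fake_trtllm_allreduce are illustrative stand-ins, not TRT-LLM APIs: the wrapper whose constructor does the expensive workspace allocation is built once per (rank, world_size, dtype) and reused on every subsequent call, so the allocation never happens inside CUDA graph warmup.

import torch

_fake_allreduce_cache = {}  # (rank, world_size, dtype) -> reusable wrapper

class FakeAllReduce:
    """Stand-in for an op whose __init__ allocates a communication workspace."""

    def __init__(self, rank, world_size, dtype):
        # The real AllReduce allocates its workspace here (with a CPU sync),
        # which is why it must not be re-created during CUDA graph warmup.
        self.rank, self.world_size, self.dtype = rank, world_size, dtype

    def __call__(self, tensor):
        # The real op performs a SUM all-reduce across ranks; this stand-in
        # just returns the input so the sketch runs on a single process.
        return tensor

def fake_trtllm_allreduce(tensor, rank, world_size):
    key = (rank, world_size, tensor.dtype)
    if key not in _fake_allreduce_cache:
        # One-time construction per configuration; cached for all later calls.
        _fake_allreduce_cache[key] = FakeAllReduce(rank, world_size, tensor.dtype)
    return _fake_allreduce_cache[key](tensor)

# Usage: the second call reuses the cached wrapper instead of rebuilding it.
out = fake_trtllm_allreduce(torch.ones(4), rank=0, world_size=1)
out = fake_trtllm_allreduce(torch.ones(4), rank=0, world_size=1)

Keying the cache on dtype as well as rank and world_size mirrors the diff, since the commit constructs AllReduce with the tensor's dtype.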
