 from ...distributed import AllReduce, allgather
 from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy

+# Cache AllReduce modules to avoid recreating on every call
+# This is critical for CUDA graph compatibility - recreating modules during
+# warmup causes hangs due to workspace allocation with CPU synchronization
+_allreduce_cache = {}
+
 def trtllm_allgather(tensor, dim, sizes=None):
     rank, world_size = get_rank_world_size()
     p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
@@ -16,9 +21,17 @@ def trtllm_allgather(tensor, dim, sizes=None):
 def trtllm_allreduce(tensor, op, all_reduce_params=None):
     rank, world_size = get_rank_world_size()
     assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
-    p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-    # Use Strategy.NCCL until https://nvbugspro.nvidia.com/bug/5331013 is fixed, then change to Strategy.AUTO
-    torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.NCCL)
+
+    # Cache key includes rank, world_size, and dtype to handle different configurations
+    cache_key = (rank, world_size, tensor.dtype)
+    if cache_key not in _allreduce_cache:
+        p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
+        # Use Strategy.AUTO for optimal performance
+        _allreduce_cache[cache_key] = AllReduce(
+            mapping=p_config, strategy=AllReduceStrategy.AUTO, dtype=tensor.dtype
+        )
+
+    torch_op = _allreduce_cache[cache_key]
     return torch_op(tensor, all_reduce_params=all_reduce_params)

 @torch.library.custom_op(
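For illustration only, below is a minimal standalone sketch of the caching pattern this change introduces: build the collective op once per (rank, world_size, dtype) configuration and reuse it, so the expensive workspace allocation never happens inside a CUDA-graph warmup region. The names ExpensiveCollectiveOp and get_cached_op are hypothetical stand-ins, not part of TRT-LLM or this PR.

    # Sketch of module-level op caching keyed by (rank, world_size, dtype),
    # mirroring _allreduce_cache in the diff above. Hypothetical names.
    import torch

    _op_cache = {}

    class ExpensiveCollectiveOp:
        """Illustrative stand-in: construction is costly, calls are cheap."""

        def __init__(self, rank: int, world_size: int, dtype: torch.dtype):
            self.rank = rank
            self.world_size = world_size
            self.dtype = dtype
            # A real op would allocate its communication workspace here
            # (with a CPU sync); the cache keeps this out of graph warmup.

        def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
            # Placeholder for the actual collective operation.
            return tensor

    def get_cached_op(rank: int, world_size: int, dtype: torch.dtype) -> ExpensiveCollectiveOp:
        key = (rank, world_size, dtype)
        if key not in _op_cache:
            _op_cache[key] = ExpensiveCollectiveOp(rank, world_size, dtype)
        return _op_cache[key]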