@@ -1,11 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 """
-Symmetric Memory AllReduce for H100+ GPUs
+Symmetric Memory AllReduce
 
 This module provides PyTorch Symmetric Memory-based allreduce operations,
-leveraging H100's MULTIMEM hardware instructions for 3x faster performance
-compared to custom CUDA kernels on supported configurations.
+leveraging MULTIMEM hardware instructions.
 """
 
 from typing import Optional
@@ -19,6 +18,7 @@
 
 try:
     import torch.distributed._symmetric_memory as torch_symm_mem
+
    SYMM_MEM_AVAILABLE = True
except ImportError:
    SYMM_MEM_AVAILABLE = False
@@ -30,21 +30,18 @@
 class SymmetricMemoryAllReduce(nn.Module):
     """
     AllReduce implementation using PyTorch's symmetric memory operations.
-
-    This leverages H100's MULTIMEM hardware instructions for significantly faster
-    allreduce operations compared to software implementations.
+    This leverages MULTIMEM hardware instructions for faster allreduce operations.
 
     Supported configurations (world_size):
-    - SM 9.0 (H100): 4, 6, 8 GPUs
-    - SM 10.0 (future): 6, 8 GPUs
+    - SM 9.0: 4, 6, 8 GPUs
+    - SM 10.0: 6, 8 GPUs
 
-    Based on vLLM's implementation but integrated into TensorRT-LLM.
     """
 
     # World sizes that support MULTIMEM instructions
     _WORLD_SIZES_MULTIMEM = {
-        "9.0": [4, 6, 8],  # H100
-        "10.0": [6, 8],  # Future architectures
+        "9.0": [4, 6, 8],
+        "10.0": [6, 8],
     }
 
     # Maximum buffer sizes for symmetric memory (bytes)
@@ -57,7 +54,7 @@ class SymmetricMemoryAllReduce(nn.Module):
         "10.0": {
             6: 8 * 1024 * 1024,
             8: 6 * 1024 * 1024,
-        }
+        },
     }
 
     def __init__(
@@ -74,8 +71,7 @@ def __init__(
         self.world_size = mapping.tp_size
 
         if not SYMM_MEM_AVAILABLE:
-            logger.warning(
-                "SymmetricMemoryAllReduce: PyTorch symm_mem not available")
+            logger.warning("SymmetricMemoryAllReduce: PyTorch symm_mem not available")
             return
 
         if not torch.cuda.is_available():
@@ -97,7 +93,8 @@ def __init__(
         if self.world_size not in self._MAX_SIZES[self.device_capability]:
             logger.info(
                 f"SymmetricMemoryAllReduce: World size {self.world_size} not supported "
-                f"for SM {self.device_capability}")
+                f"for SM {self.device_capability}"
+            )
             return
 
         # Get max buffer size for this configuration
@@ -109,17 +106,13 @@ def __init__(
             # For TP parallelism, we need ranks [0, 1, 2, ..., tp_size-1] globally
             # NOT starting from tp_rank!
             if not dist.is_initialized():
-                logger.warning(
-                    "SymmetricMemoryAllReduce: torch.distributed not initialized"
-                )
+                logger.warning("SymmetricMemoryAllReduce: torch.distributed not initialized")
                 self.disabled = True
                 return
 
-            # Assume contiguous TP ranks for now
-            # TODO: Get actual TP group from mapping if available
-            tp_group_ranks = list(range(mapping.tp_size))
-            self.group = dist.new_group(tp_group_ranks) if len(
-                tp_group_ranks) > 1 else None
+            # Get actual TP group ranks from mapping
+            tp_group_ranks = mapping.tp_group()
+            self.group = dist.new_group(tp_group_ranks) if len(tp_group_ranks) > 1 else None
         else:
             self.group = group
 
@@ -136,30 +129,38 @@ def __init__(
                 dtype=self.dtype,
             )
             # Pass group_name (string) not the group object
-            handle = torch_symm_mem.rendezvous(self.buffer,
-                                               self.group.group_name)
+            handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
 
             if handle.multicast_ptr == 0:
                 logger.warning(
                     "SymmetricMemoryAllReduce: MULTIMEM operations not supported (multicast_ptr is 0)"
                 )
                 return
 
-            # Determine which algorithm to use
-            self.use_multimem = (self.world_size
-                                 in self._WORLD_SIZES_MULTIMEM.get(
-                                     self.device_capability, []))
+            # Only enable if MULTIMEM is supported
+            # Otherwise, no benefit over existing TensorRT-LLM strategies
+            use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM.get(
+                self.device_capability, []
+            )
+
+            if not use_multimem:
+                logger.info(
+                    f"SymmetricMemoryAllReduce: MULTIMEM not supported for "
+                    f"world_size={self.world_size}, SM={self.device_capability}. "
+                    f"Falling back to standard allreduce strategies."
+                )
+                return
 
             self.disabled = False
-            logger.info(f"SymmetricMemoryAllReduce initialized: "
-                        f"world_size={self.world_size}, "
-                        f"max_size={self.max_size}, "
-                        f"SM={self.device_capability}, "
-                        f"use_multimem={self.use_multimem}")
+            logger.info(
+                f"SymmetricMemoryAllReduce (MULTIMEM) initialized: "
+                f"world_size={self.world_size}, "
+                f"max_size={self.max_size}, "
+                f"SM={self.device_capability}"
+            )
 
         except Exception as e:
-            logger.warning(
-                f"SymmetricMemoryAllReduce initialization failed: {e}")
+            logger.warning(f"SymmetricMemoryAllReduce initialization failed: {e}")
             return
 
     def should_use_symm_mem(self, inp: torch.Tensor) -> bool:
@@ -197,25 +198,16 @@ def forward(
         out = torch.empty_like(inp)
 
         # Copy input to symmetric memory buffer
-        self.buffer[:inp.numel()].copy_(inp.view(-1))
-
-        # Perform allreduce using appropriate algorithm
-        if self.use_multimem:
-            # Use MULTIMEM hardware instructions (faster)
-            torch.ops.symm_mem.multimem_all_reduce_(
-                self.buffer[:inp.numel()],
-                "sum",
-                self.group.group_name,
-            )
-        else:
-            # Use two-shot algorithm (fallback)
-            torch.ops.symm_mem.two_shot_all_reduce_(
-                self.buffer[:inp.numel()],
-                "sum",
-                self.group.group_name,
-            )
+        self.buffer[: inp.numel()].copy_(inp.view(-1))
+
+        # Perform MULTIMEM allreduce
+        torch.ops.symm_mem.multimem_all_reduce_(
+            self.buffer[: inp.numel()],
+            "sum",
+            self.group.group_name,
+        )
 
         # Copy result back
-        out.copy_(self.buffer[:inp.numel()].view(out.shape))
+        out.copy_(self.buffer[: inp.numel()].view(out.shape))
 
         return out
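
For context, a minimal usage sketch of the new module; it is not part of this commit. It assumes a pure tensor-parallel job (tp_size equal to the world size), that the TensorRT-LLM Mapping can be built directly from the torch.distributed world, and that the constructor takes just that mapping. Only SymmetricMemoryAllReduce, its disabled flag, should_use_symm_mem(), and forward() come from the diff above; the import paths and constructor arguments below are assumptions and may differ from the real code.

import torch
import torch.distributed as dist

from tensorrt_llm.mapping import Mapping  # assumed import path for Mapping
from symm_mem_allreduce import SymmetricMemoryAllReduce  # hypothetical module path; not shown in the diff

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Pure tensor-parallel layout: tp_size == world_size (assumption of this sketch).
mapping = Mapping(
    world_size=dist.get_world_size(),
    rank=dist.get_rank(),
    tp_size=dist.get_world_size(),
)

# Assumed single-argument constructor; the real __init__ signature is elided in the diff.
symm_allreduce = SymmetricMemoryAllReduce(mapping)

# Assumes the input dtype matches the symmetric-memory buffer dtype.
x = torch.randn(8192, 4096, dtype=torch.bfloat16, device="cuda")
if not symm_allreduce.disabled and symm_allreduce.should_use_symm_mem(x):
    x = symm_allreduce(x)  # MULTIMEM symmetric-memory allreduce path
else:
    dist.all_reduce(x)  # fall back to a standard allreduce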