Skip to content

Commit 3192b19

Browse files
committed
Added symm mem strategy
Signed-off-by: Eran Geva <[email protected]>
1 parent 70e4d72 commit 3192b19

File tree

3 files changed

+270
-5
lines changed

3 files changed

+270
-5
lines changed

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import torch
88
from torch import nn
99

10+
from tensorrt_llm._torch.distributed.symm_mem_allreduce import \
11+
SymmetricMemoryAllReduce
1012
from tensorrt_llm._utils import mpi_comm, mpi_disabled
1113
from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
1214
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
@@ -516,13 +518,19 @@ def __init__(self,
516518
strategy (AllReduceStrategy):
517519
The following all-reduce strategies are supported:
518520
521+
- SYMM_MEM: Uses PyTorch's symmetric memory with MULTIMEM hardware instructions (H100+).
522+
Provides 3x faster performance on supported configurations (4/6/8 GPUs on H100).
523+
Currently only supports plain allreduce (NONE fusion op). Falls back automatically
524+
if not supported.
525+
519526
- UB: AllReduce uses user-buffer based all-reduce kernel.
520527
521528
- NCCL: Use NCCL allreduce.
522529
523530
- MIN_LATENCY: AllReduce uses MIN_LATENCY mode kernel.
524531
525-
- AUTO: AUTO chooses between NCCL and MIN_LATENCY mode based on a heuristic policy.
532+
- AUTO: AUTO chooses the best available strategy. Will try SYMM_MEM first (if available),
533+
then MNNVL, then choose between NCCL and MIN_LATENCY based on a heuristic policy.
526534
527535
- LOWPRECISION: AllReduce quantizes data to lower precision for transmission.
528536
Should only be used on topologies with PCIe switches and without NVLink.
@@ -551,6 +559,7 @@ def __init__(self,
551559
self.workspace = None
552560
self.strategy = strategy
553561
self.mnnvl_allreduce = None
562+
self.symm_mem_allreduce = None
554563
self._disable_mpi = mpi_disabled()
555564

556565
self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce
@@ -562,6 +571,29 @@ def __init__(self,
562571
allocate_low_presicion_allreduce_workspace(self.mapping)
563572
self.workspace = get_allreduce_workspace(self.mapping)
564573

574+
# Initialize Symmetric Memory AllReduce if needed (H100+ hardware acceleration)
575+
if self.strategy in (AllReduceStrategy.AUTO,
576+
AllReduceStrategy.SYMM_MEM):
577+
try:
578+
symm_mem = SymmetricMemoryAllReduce(
579+
self.mapping,
580+
dtype=dtype if dtype else torch.bfloat16,
581+
)
582+
if not symm_mem.disabled:
583+
self.symm_mem_allreduce = symm_mem
584+
logger.info(
585+
f"SymmetricMemoryAllReduce (MULTIMEM) is enabled for world_size={self.mapping.tp_size}"
586+
)
587+
else:
588+
logger.debug(
589+
f"SymmetricMemoryAllReduce is disabled (not supported or unavailable)"
590+
)
591+
except Exception as e:
592+
logger.debug(
593+
f"Symmetric Memory AllReduce can't be enabled due to {e}."
594+
)
595+
self.symm_mem_allreduce = None
596+
565597
# Initialize MNNVL AllReduce if needed
566598
if self.strategy in (AllReduceStrategy.AUTO,
567599
AllReduceStrategy.MNNVL):
@@ -626,16 +658,27 @@ def forward(
626658
if all_reduce_params is None:
627659
all_reduce_params = AllReduceParams()
628660

629-
# Try MNNVL AllReduce first if available
661+
# Try Symmetric Memory AllReduce first if available (H100+ hardware acceleration)
662+
# Note: Currently only supports NONE fusion op (plain allreduce)
663+
if self.symm_mem_allreduce and all_reduce_params.fusion_op == AllReduceFusionOp.NONE:
664+
symm_mem_output = self.symm_mem_allreduce(input)
665+
if symm_mem_output is not None:
666+
logger.debug(
667+
f"Using SymmetricMemoryAllReduce (MULTIMEM) for input shape {input.shape}"
668+
)
669+
return symm_mem_output
670+
671+
# Try MNNVL AllReduce if symm_mem didn't handle it
630672
if self.mnnvl_allreduce:
631673
mnnvl_output = self.mnnvl_allreduce(
632674
input, all_reduce_params=all_reduce_params)
633675
if mnnvl_output is not None:
634676
return mnnvl_output
635677

636-
# Fall back to regular AllReduce if MNNVL is not available or not applicable
637-
# Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL
638-
if allreduce_strategy == AllReduceStrategy.MNNVL:
678+
# Fall back to regular AllReduce if specialized methods are not available or not applicable
679+
# Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL/SYMM_MEM
680+
if allreduce_strategy in (AllReduceStrategy.MNNVL,
681+
AllReduceStrategy.SYMM_MEM):
639682
allreduce_strategy = AllReduceStrategy.AUTO
640683

641684
additional_args = {}
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
"""
4+
Symmetric Memory AllReduce for H100+ GPUs
5+
6+
This module provides PyTorch Symmetric Memory-based allreduce operations,
7+
leveraging H100's MULTIMEM hardware instructions for 3x faster performance
8+
compared to custom CUDA kernels on supported configurations.
9+
"""
10+
11+
from typing import Optional
12+
13+
import torch
14+
import torch.distributed as dist
15+
from torch import nn
16+
17+
from tensorrt_llm.logger import logger
18+
from tensorrt_llm.mapping import Mapping
19+
20+
try:
21+
import torch.distributed._symmetric_memory as torch_symm_mem
22+
SYMM_MEM_AVAILABLE = True
23+
except ImportError:
24+
SYMM_MEM_AVAILABLE = False
25+
logger.warning(
26+
"PyTorch symmetric memory not available. Install PyTorch >= 2.8 for MULTIMEM support."
27+
)
28+
29+
30+
class SymmetricMemoryAllReduce(nn.Module):
31+
"""
32+
AllReduce implementation using PyTorch's symmetric memory operations.
33+
34+
This leverages H100's MULTIMEM hardware instructions for significantly faster
35+
allreduce operations compared to software implementations.
36+
37+
Supported configurations (world_size):
38+
- SM 9.0 (H100): 4, 6, 8 GPUs
39+
- SM 10.0 (future): 6, 8 GPUs
40+
41+
Based on vLLM's implementation but integrated into TensorRT-LLM.
42+
"""
43+
44+
# World sizes that support MULTIMEM instructions
45+
_WORLD_SIZES_MULTIMEM = {
46+
"9.0": [4, 6, 8], # H100
47+
"10.0": [6, 8], # Future architectures
48+
}
49+
50+
# Maximum buffer sizes for symmetric memory (bytes)
51+
_MAX_SIZES = {
52+
"9.0": {
53+
4: 8 * 1024 * 1024, # 8MB for 4 GPUs
54+
6: 6 * 1024 * 1024, # 6MB for 6 GPUs
55+
8: 4 * 1024 * 1024, # 4MB for 8 GPUs
56+
},
57+
"10.0": {
58+
6: 8 * 1024 * 1024,
59+
8: 6 * 1024 * 1024,
60+
}
61+
}
62+
63+
def __init__(
64+
self,
65+
mapping: Mapping,
66+
dtype: torch.dtype = torch.bfloat16,
67+
group: Optional[dist.ProcessGroup] = None,
68+
):
69+
super().__init__()
70+
71+
self.disabled = True
72+
self.mapping = mapping
73+
self.dtype = dtype
74+
self.world_size = mapping.tp_size
75+
76+
if not SYMM_MEM_AVAILABLE:
77+
logger.warning(
78+
"SymmetricMemoryAllReduce: PyTorch symm_mem not available")
79+
return
80+
81+
if not torch.cuda.is_available():
82+
logger.warning("SymmetricMemoryAllReduce: CUDA not available")
83+
return
84+
85+
# Get device capability
86+
device = torch.device(f"cuda:{mapping.tp_rank}")
87+
capability = torch.cuda.get_device_capability(device)
88+
self.device_capability = f"{capability[0]}.{capability[1]}"
89+
90+
# Check if this configuration is supported
91+
if self.device_capability not in self._MAX_SIZES:
92+
logger.warning(
93+
f"SymmetricMemoryAllReduce: Device capability {self.device_capability} not supported"
94+
)
95+
return
96+
97+
if self.world_size not in self._MAX_SIZES[self.device_capability]:
98+
logger.info(
99+
f"SymmetricMemoryAllReduce: World size {self.world_size} not supported "
100+
f"for SM {self.device_capability}")
101+
return
102+
103+
# Get max buffer size for this configuration
104+
self.max_size = self._MAX_SIZES[self.device_capability][self.world_size]
105+
106+
# Set up process group
107+
if group is None:
108+
# Get or create TP group with correct ranks
109+
# For TP parallelism, we need ranks [0, 1, 2, ..., tp_size-1] globally
110+
# NOT starting from tp_rank!
111+
if not dist.is_initialized():
112+
logger.warning(
113+
"SymmetricMemoryAllReduce: torch.distributed not initialized"
114+
)
115+
self.disabled = True
116+
return
117+
118+
# Assume contiguous TP ranks for now
119+
# TODO: Get actual TP group from mapping if available
120+
tp_group_ranks = list(range(mapping.tp_size))
121+
self.group = dist.new_group(tp_group_ranks) if len(
122+
tp_group_ranks) > 1 else None
123+
else:
124+
self.group = group
125+
126+
if self.group is None:
127+
logger.warning("SymmetricMemoryAllReduce: No valid process group")
128+
self.disabled = True
129+
return
130+
131+
# Allocate symmetric memory buffer
132+
try:
133+
self.buffer = torch_symm_mem.empty(
134+
self.max_size // self.dtype.itemsize,
135+
device=device,
136+
dtype=self.dtype,
137+
)
138+
# Pass group_name (string) not the group object
139+
handle = torch_symm_mem.rendezvous(self.buffer,
140+
self.group.group_name)
141+
142+
if handle.multicast_ptr == 0:
143+
logger.warning(
144+
"SymmetricMemoryAllReduce: MULTIMEM operations not supported (multicast_ptr is 0)"
145+
)
146+
return
147+
148+
# Determine which algorithm to use
149+
self.use_multimem = (self.world_size
150+
in self._WORLD_SIZES_MULTIMEM.get(
151+
self.device_capability, []))
152+
153+
self.disabled = False
154+
logger.info(f"SymmetricMemoryAllReduce initialized: "
155+
f"world_size={self.world_size}, "
156+
f"max_size={self.max_size}, "
157+
f"SM={self.device_capability}, "
158+
f"use_multimem={self.use_multimem}")
159+
160+
except Exception as e:
161+
logger.warning(
162+
f"SymmetricMemoryAllReduce initialization failed: {e}")
163+
return
164+
165+
def should_use_symm_mem(self, inp: torch.Tensor) -> bool:
166+
"""Check if symmetric memory can be used for this tensor."""
167+
if self.disabled:
168+
return False
169+
if inp.dtype != self.dtype:
170+
return False
171+
inp_size = inp.numel() * inp.element_size()
172+
if inp_size % 4 != 0:
173+
return False
174+
if inp_size >= self.max_size:
175+
return False
176+
return True
177+
178+
def forward(
179+
self,
180+
inp: torch.Tensor,
181+
out: Optional[torch.Tensor] = None,
182+
) -> torch.Tensor:
183+
"""
184+
Perform allreduce using symmetric memory operations.
185+
186+
Args:
187+
inp: Input tensor to reduce
188+
out: Optional output tensor (if None, will be allocated)
189+
190+
Returns:
191+
Reduced tensor
192+
"""
193+
if not self.should_use_symm_mem(inp):
194+
return None # Caller should fall back to other strategy
195+
196+
if out is None:
197+
out = torch.empty_like(inp)
198+
199+
# Copy input to symmetric memory buffer
200+
self.buffer[:inp.numel()].copy_(inp.view(-1))
201+
202+
# Perform allreduce using appropriate algorithm
203+
if self.use_multimem:
204+
# Use MULTIMEM hardware instructions (faster)
205+
torch.ops.symm_mem.multimem_all_reduce_(
206+
self.buffer[:inp.numel()],
207+
"sum",
208+
self.group.group_name,
209+
)
210+
else:
211+
# Use two-shot algorithm (fallback)
212+
torch.ops.symm_mem.two_shot_all_reduce_(
213+
self.buffer[:inp.numel()],
214+
"sum",
215+
self.group.group_name,
216+
)
217+
218+
# Copy result back
219+
out.copy_(self.buffer[:inp.numel()].view(out.shape))
220+
221+
return out

tensorrt_llm/functional.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3883,6 +3883,7 @@ class AllReduceStrategy(IntEnum):
38833883
LOWPRECISION = 6
38843884
MNNVL = 7
38853885
NCCL_SYMMETRIC = 8
3886+
SYMM_MEM = 9 # PyTorch symmetric memory with MULTIMEM (H100+)
38863887

38873888

38883889
class AllReduceFusionOp(IntEnum):

0 commit comments

Comments
 (0)