Commit 3cb8ff4

CR fixes

Signed-off-by: Eran Geva <[email protected]>
1 parent 3192b19 · commit 3cb8ff4

File tree: 3 files changed, 51 additions & 61 deletions


tensorrt_llm/_torch/distributed/ops.py

Lines changed: 4 additions & 6 deletions
@@ -518,10 +518,8 @@ def __init__(self,
             strategy (AllReduceStrategy):
                 The following all-reduce strategies are supported:
 
-            - SYMM_MEM: Uses PyTorch's symmetric memory with MULTIMEM hardware instructions (H100+).
-              Provides 3x faster performance on supported configurations (4/6/8 GPUs on H100).
-              Currently only supports plain allreduce (NONE fusion op). Falls back automatically
-              if not supported.
+            - SYMM_MEM: Uses PyTorch's symmetric memory with MULTIMEM hardware instructions.
+              Falls back automatically if not supported.
 
             - UB: AllReduce uses user-buffer based all-reduce kernel.
 
@@ -571,7 +569,7 @@ def __init__(self,
             allocate_low_presicion_allreduce_workspace(self.mapping)
         self.workspace = get_allreduce_workspace(self.mapping)
 
-        # Initialize Symmetric Memory AllReduce if needed (H100+ hardware acceleration)
+        # Initialize Symmetric Memory AllReduce if needed
         if self.strategy in (AllReduceStrategy.AUTO,
                              AllReduceStrategy.SYMM_MEM):
             try:
@@ -658,7 +656,7 @@ def forward(
         if all_reduce_params is None:
             all_reduce_params = AllReduceParams()
 
-        # Try Symmetric Memory AllReduce first if available (H100+ hardware acceleration)
+        # Try Symmetric Memory AllReduce first if available
         # Note: Currently only supports NONE fusion op (plain allreduce)
         if self.symm_mem_allreduce and all_reduce_params.fusion_op == AllReduceFusionOp.NONE:
             symm_mem_output = self.symm_mem_allreduce(input)
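
For context, the dispatch pattern these ops.py hunks describe can be sketched as below: the symmetric-memory path is attempted only for a plain allreduce (NONE fusion op), and anything it cannot handle falls through to the existing strategies. This is a minimal illustration, not the TensorRT-LLM code; `FusionOp`, `dispatch_allreduce`, and the "returns None when unsupported" convention are hypothetical stand-ins for `AllReduceFusionOp` and `AllReduce.forward`.

from enum import IntEnum
from typing import Callable, Optional

import torch


class FusionOp(IntEnum):
    # Illustrative stand-in; the real enum is AllReduceFusionOp in tensorrt_llm.functional.
    NONE = 0
    RESIDUAL_RMS_NORM = 1


def dispatch_allreduce(
    inp: torch.Tensor,
    fusion_op: FusionOp,
    symm_mem_allreduce: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]],
    fallback_allreduce: Callable[[torch.Tensor], torch.Tensor],
) -> torch.Tensor:
    # Try the symmetric-memory path only for a plain allreduce (NONE fusion op);
    # everything else goes to the existing strategies, mirroring the check above.
    if symm_mem_allreduce is not None and fusion_op == FusionOp.NONE:
        out = symm_mem_allreduce(inp)
        if out is not None:  # assumption: the wrapper signals "unsupported input" by returning None
            return out
    return fallback_allreduce(inp)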

tensorrt_llm/_torch/distributed/symm_mem_allreduce.py

Lines changed: 46 additions & 54 deletions
@@ -1,11 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 """
-Symmetric Memory AllReduce for H100+ GPUs
+Symmetric Memory AllReduce
 
 This module provides PyTorch Symmetric Memory-based allreduce operations,
-leveraging H100's MULTIMEM hardware instructions for 3x faster performance
-compared to custom CUDA kernels on supported configurations.
+leveraging MULTIMEM hardware instructions.
 """
 
 from typing import Optional
@@ -19,6 +18,7 @@
 
 try:
     import torch.distributed._symmetric_memory as torch_symm_mem
+
     SYMM_MEM_AVAILABLE = True
 except ImportError:
     SYMM_MEM_AVAILABLE = False
@@ -30,21 +30,18 @@
 class SymmetricMemoryAllReduce(nn.Module):
     """
     AllReduce implementation using PyTorch's symmetric memory operations.
-
-    This leverages H100's MULTIMEM hardware instructions for significantly faster
-    allreduce operations compared to software implementations.
+    This leverages MULTIMEM hardware instructions for faster allreduce operations.
 
     Supported configurations (world_size):
-    - SM 9.0 (H100): 4, 6, 8 GPUs
-    - SM 10.0 (future): 6, 8 GPUs
+    - SM 9.0: 4, 6, 8 GPUs
+    - SM 10.0: 6, 8 GPUs
 
-    Based on vLLM's implementation but integrated into TensorRT-LLM.
     """
 
     # World sizes that support MULTIMEM instructions
     _WORLD_SIZES_MULTIMEM = {
-        "9.0": [4, 6, 8],  # H100
-        "10.0": [6, 8],  # Future architectures
+        "9.0": [4, 6, 8],
+        "10.0": [6, 8],
     }
 
     # Maximum buffer sizes for symmetric memory (bytes)
@@ -57,7 +54,7 @@ class SymmetricMemoryAllReduce(nn.Module):
         "10.0": {
             6: 8 * 1024 * 1024,
             8: 6 * 1024 * 1024,
-        }
+        },
     }
 
     def __init__(
@@ -74,8 +71,7 @@ def __init__(
         self.world_size = mapping.tp_size
 
         if not SYMM_MEM_AVAILABLE:
-            logger.warning(
-                "SymmetricMemoryAllReduce: PyTorch symm_mem not available")
+            logger.warning("SymmetricMemoryAllReduce: PyTorch symm_mem not available")
             return
 
         if not torch.cuda.is_available():
@@ -97,7 +93,8 @@ def __init__(
         if self.world_size not in self._MAX_SIZES[self.device_capability]:
             logger.info(
                 f"SymmetricMemoryAllReduce: World size {self.world_size} not supported "
-                f"for SM {self.device_capability}")
+                f"for SM {self.device_capability}"
+            )
             return
 
         # Get max buffer size for this configuration
@@ -109,17 +106,13 @@ def __init__(
             # For TP parallelism, we need ranks [0, 1, 2, ..., tp_size-1] globally
             # NOT starting from tp_rank!
             if not dist.is_initialized():
-                logger.warning(
-                    "SymmetricMemoryAllReduce: torch.distributed not initialized"
-                )
+                logger.warning("SymmetricMemoryAllReduce: torch.distributed not initialized")
                 self.disabled = True
                 return
 
-            # Assume contiguous TP ranks for now
-            # TODO: Get actual TP group from mapping if available
-            tp_group_ranks = list(range(mapping.tp_size))
-            self.group = dist.new_group(tp_group_ranks) if len(
-                tp_group_ranks) > 1 else None
+            # Get actual TP group ranks from mapping
+            tp_group_ranks = mapping.tp_group()
+            self.group = dist.new_group(tp_group_ranks) if len(tp_group_ranks) > 1 else None
         else:
             self.group = group
 
@@ -136,30 +129,38 @@ def __init__(
                 dtype=self.dtype,
             )
            # Pass group_name (string) not the group object
-            handle = torch_symm_mem.rendezvous(self.buffer,
-                                               self.group.group_name)
+            handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name)
 
             if handle.multicast_ptr == 0:
                 logger.warning(
                     "SymmetricMemoryAllReduce: MULTIMEM operations not supported (multicast_ptr is 0)"
                 )
                 return
 
-            # Determine which algorithm to use
-            self.use_multimem = (self.world_size
-                                 in self._WORLD_SIZES_MULTIMEM.get(
-                                     self.device_capability, []))
+            # Only enable if MULTIMEM is supported
+            # Otherwise, no benefit over existing TensorRT-LLM strategies
+            use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM.get(
+                self.device_capability, []
+            )
+
+            if not use_multimem:
+                logger.info(
+                    f"SymmetricMemoryAllReduce: MULTIMEM not supported for "
+                    f"world_size={self.world_size}, SM={self.device_capability}. "
+                    f"Falling back to standard allreduce strategies."
+                )
+                return
 
             self.disabled = False
-            logger.info(f"SymmetricMemoryAllReduce initialized: "
-                        f"world_size={self.world_size}, "
-                        f"max_size={self.max_size}, "
-                        f"SM={self.device_capability}, "
-                        f"use_multimem={self.use_multimem}")
+            logger.info(
+                f"SymmetricMemoryAllReduce (MULTIMEM) initialized: "
+                f"world_size={self.world_size}, "
+                f"max_size={self.max_size}, "
+                f"SM={self.device_capability}"
+            )
 
         except Exception as e:
-            logger.warning(
-                f"SymmetricMemoryAllReduce initialization failed: {e}")
+            logger.warning(f"SymmetricMemoryAllReduce initialization failed: {e}")
             return
 
     def should_use_symm_mem(self, inp: torch.Tensor) -> bool:
@@ -197,25 +198,16 @@ def forward(
         out = torch.empty_like(inp)
 
         # Copy input to symmetric memory buffer
-        self.buffer[:inp.numel()].copy_(inp.view(-1))
-
-        # Perform allreduce using appropriate algorithm
-        if self.use_multimem:
-            # Use MULTIMEM hardware instructions (faster)
-            torch.ops.symm_mem.multimem_all_reduce_(
-                self.buffer[:inp.numel()],
-                "sum",
-                self.group.group_name,
-            )
-        else:
-            # Use two-shot algorithm (fallback)
-            torch.ops.symm_mem.two_shot_all_reduce_(
-                self.buffer[:inp.numel()],
-                "sum",
-                self.group.group_name,
-            )
+        self.buffer[: inp.numel()].copy_(inp.view(-1))
+
+        # Perform MULTIMEM allreduce
+        torch.ops.symm_mem.multimem_all_reduce_(
+            self.buffer[: inp.numel()],
+            "sum",
+            self.group.group_name,
+        )
 
         # Copy result back
-        out.copy_(self.buffer[:inp.numel()].view(out.shape))
+        out.copy_(self.buffer[: inp.numel()].view(out.shape))
 
         return out
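
As a standalone illustration of the PyTorch primitives this module wraps (rendezvous on a symmetric-memory buffer, multicast check, in-place MULTIMEM allreduce), the sketch below runs the same sequence outside TensorRT-LLM. It assumes a multi-GPU job launched with torchrun, a supported world size per the table above, and a recent PyTorch build; `torch_symm_mem.empty` and the buffer size/dtype are assumptions, since only the rendezvous and the `multimem_all_reduce_` call appear in this diff.

# Minimal sketch, assuming torchrun with one GPU per rank and a PyTorch build
# exposing torch.distributed._symmetric_memory and the symm_mem ops.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as torch_symm_mem


def multimem_allreduce_demo() -> None:
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    # Mirror the module: an explicit group over all ranks, addressed by group_name.
    group = dist.new_group(list(range(dist.get_world_size())))
    numel = 4096  # illustrative size, well under the per-SM buffer limits above

    # Allocate the symmetric-memory buffer and rendezvous, as in __init__ above.
    buf = torch_symm_mem.empty(numel, dtype=torch.bfloat16, device=f"cuda:{rank}")
    handle = torch_symm_mem.rendezvous(buf, group.group_name)
    if handle.multicast_ptr == 0:
        raise RuntimeError("MULTIMEM not supported on this system")

    # Fill the buffer and run the one-shot MULTIMEM allreduce in place.
    buf.fill_(float(rank))
    torch.ops.symm_mem.multimem_all_reduce_(buf, "sum", group.group_name)

    # Every rank now holds sum(range(world_size)) in each element.
    expected = float(sum(range(dist.get_world_size())))
    assert torch.allclose(buf, torch.full_like(buf, expected))

    dist.destroy_process_group()


if __name__ == "__main__":
    multimem_allreduce_demo()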

tensorrt_llm/functional.py

Lines changed: 1 addition & 1 deletion
@@ -3883,7 +3883,7 @@ class AllReduceStrategy(IntEnum):
     LOWPRECISION = 6
     MNNVL = 7
     NCCL_SYMMETRIC = 8
-    SYMM_MEM = 9  # PyTorch symmetric memory with MULTIMEM (H100+)
+    SYMM_MEM = 9  # PyTorch symmetric memory with MULTIMEM
 
 
 class AllReduceFusionOp(IntEnum):
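
Since AllReduceStrategy is an IntEnum, the new member can be selected and compared like the existing ones; a minimal check (the value 9 is taken from the hunk above, and the AUTO/SYMM_MEM note reflects the ops.py change):

from tensorrt_llm.functional import AllReduceStrategy

# SYMM_MEM sits alongside the existing strategies; per the ops.py hunk,
# the symmetric-memory path is initialized for both AUTO and SYMM_MEM.
strategy = AllReduceStrategy.SYMM_MEM
assert strategy == 9 and strategy.name == "SYMM_MEM"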
