4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
@@ -26,9 +26,9 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
cache_key = (rank, world_size, tensor.dtype)
if cache_key not in _allreduce_cache:
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
# Use Strategy.AUTO for optimal performance
# Use SYMM_MEM strategy (tries symmetric memory first, falls back to AUTO if needed)
_allreduce_cache[cache_key] = AllReduce(
mapping=p_config, strategy=AllReduceStrategy.AUTO, dtype=tensor.dtype
mapping=p_config, strategy=AllReduceStrategy.SYMM_MEM, dtype=tensor.dtype
Collaborator comment:
@MrGeva: we should make this configurable, and users should be able to specify the choice through the inference optimizer config. Would you be able to take on the task of plumbing the AllReduceStrategy from the top-level default.yaml -> sharding transformation -> inserting the allreduce op with the specified strategy?

)

torch_op = _allreduce_cache[cache_key]
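Following the collaborator's suggestion above, here is a minimal, hypothetical sketch of the requested plumbing. The config key name and helper function are assumptions for illustration, not part of this PR; only `AllReduceStrategy` itself comes from tensorrt_llm.functional (see the enum change at the end of this diff).

    from tensorrt_llm.functional import AllReduceStrategy

    def resolve_allreduce_strategy(config: dict) -> AllReduceStrategy:
        # Hypothetical config key, e.g. default.yaml carrying `allreduce_strategy: SYMM_MEM`.
        name = config.get("allreduce_strategy", "SYMM_MEM")
        try:
            return AllReduceStrategy[name]  # IntEnum lookup by member name
        except KeyError:
            return AllReduceStrategy.AUTO   # unknown names fall back to AUTO

The resolved value would then replace the hard-coded AllReduceStrategy.SYMM_MEM in trtllm_allreduce above.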
70 changes: 63 additions & 7 deletions tensorrt_llm/_torch/distributed/ops.py
@@ -7,6 +7,8 @@
import torch
from torch import nn

from tensorrt_llm._torch.distributed.symm_mem_allreduce import \
SymmetricMemoryAllReduce
from tensorrt_llm._utils import mpi_comm, mpi_disabled
from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
@@ -566,13 +568,17 @@ def __init__(self,
strategy (AllReduceStrategy):
The following all-reduce strategies are supported:
- SYMM_MEM: Uses PyTorch's symmetric memory with MULTIMEM hardware instructions.
Falls back automatically if not supported.
- UB: AllReduce uses user-buffer based all-reduce kernel.
- NCCL: Use NCCL allreduce.
- MIN_LATENCY: AllReduce uses MIN_LATENCY mode kernel.
- AUTO: AUTO chooses between NCCL and MIN_LATENCY mode based on a heuristic policy.
- AUTO: AUTO chooses the best available strategy. Will try MNNVL,
then choose between NCCL and MIN_LATENCY based on a heuristic policy.
- LOWPRECISION: AllReduce quantizes data to lower precision for transmission.
Should only be used on topologies with PCIe switches and without NVLink.
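As a brief illustration of the strategy list above, a hedged construction-time sketch (the `mapping` setup is assumed; import paths are taken from this diff): when SYMM_MEM cannot be enabled on the current topology, `__init__` downgrades the instance to AUTO, so call sites do not need to change.

    import torch
    from tensorrt_llm.functional import AllReduceStrategy
    from tensorrt_llm._torch.distributed.ops import AllReduce

    # `mapping` is assumed to be a valid tensor-parallel Mapping for this rank.
    ar = AllReduce(mapping=mapping, strategy=AllReduceStrategy.SYMM_MEM,
                   dtype=torch.bfloat16)
    # Unsupported world size, SM version, or missing MULTIMEM support is logged
    # and ar.strategy is reset to AllReduceStrategy.AUTO during __init__.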
@@ -601,20 +607,51 @@ def __init__(self,
self.workspace = None
self.strategy = strategy
self.mnnvl_allreduce = None
self.symm_mem_allreduce = None
self._disable_mpi = mpi_disabled()

self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce

if self.mapping.tp_size > 1:
# When Strategy is UB, it is guaranteed that the workspace is not used.
# Initialize Symmetric Memory AllReduce if needed (before workspace allocation)
if self.strategy == AllReduceStrategy.SYMM_MEM:
try:
symm_mem = SymmetricMemoryAllReduce(
self.mapping,
dtype=dtype if dtype else torch.bfloat16,
)
if not symm_mem.disabled:
self.symm_mem_allreduce = symm_mem
logger.info(
f"SymmetricMemoryAllReduce (MULTIMEM) is enabled with fallback support for world_size={self.mapping.tp_size}"
)
# Keep SYMM_MEM strategy but allocate workspace for fallback to regular allreduce
else:
logger.info(
f"SymmetricMemoryAllReduce is disabled (not supported or unavailable), falling back to AUTO strategy"
)
# Fall back to AUTO if SYMM_MEM can't be enabled
self.strategy = AllReduceStrategy.AUTO
except Exception as e:
logger.info(
f"Symmetric Memory AllReduce can't be enabled due to {e}, falling back to AUTO strategy"
)
self.symm_mem_allreduce = None
# Fall back to AUTO if SYMM_MEM initialization fails
self.strategy = AllReduceStrategy.AUTO

# Allocate workspace for strategies that need it
# Note: SYMM_MEM now also needs workspace for fallback scenarios (fused ops, etc.)
# Only UB doesn't need workspace
if self.strategy != AllReduceStrategy.UB:
if self.strategy == AllReduceStrategy.LOWPRECISION:
allocate_low_presicion_allreduce_workspace(self.mapping)
self.workspace = get_allreduce_workspace(self.mapping)

# Initialize MNNVL AllReduce if needed
# Initialize MNNVL if using AUTO or MNNVL strategy
if self.strategy in (AllReduceStrategy.AUTO,
AllReduceStrategy.MNNVL):
# Try to initialize MNNVL
if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
# ALWAYS capture the exception when creating this instance
try:
@@ -670,20 +707,39 @@ def forward(
if all_reduce_params is None:
all_reduce_params = AllReduceParams()

# Try MNNVL AllReduce first if available
# Try Symmetric Memory AllReduce first if available
# Note: Currently only supports NONE fusion op (plain allreduce)
if self.symm_mem_allreduce and all_reduce_params.fusion_op == AllReduceFusionOp.NONE:
symm_mem_output = self.symm_mem_allreduce(input)
if symm_mem_output is not None:
logger.debug(
f"Using SymmetricMemoryAllReduce (MULTIMEM) for input shape {input.shape}"
)
return symm_mem_output
elif self.symm_mem_allreduce and all_reduce_params.fusion_op != AllReduceFusionOp.NONE:
# Log once per rank that we're skipping symm_mem due to fusion
if not hasattr(self, '_logged_fusion_skip'):
logger.debug(
f"Skipping SymmetricMemoryAllReduce for fused operation (fusion_op={all_reduce_params.fusion_op}), using regular allreduce"
)
self._logged_fusion_skip = True

# Try MNNVL AllReduce if symm_mem didn't handle it
if self.mnnvl_allreduce:
mnnvl_output = self.mnnvl_allreduce(
input, all_reduce_params=all_reduce_params)
if mnnvl_output is not None:
return mnnvl_output

# Fall back to regular AllReduce if MNNVL is not available or not applicable
# Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL
if allreduce_strategy == AllReduceStrategy.MNNVL:
# Fall back to regular AllReduce if specialized methods are not available or not applicable
# Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL/SYMM_MEM
if allreduce_strategy in (AllReduceStrategy.MNNVL,
AllReduceStrategy.SYMM_MEM):
allreduce_strategy = AllReduceStrategy.AUTO

additional_args = {}
if self._disable_mpi:
# Get ProcessGroup from mapping
pg = self.mapping.tp_group_pg
assert pg is not None, "TP ProcessGroup not initialised"
additional_args = {
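To make the dispatch order above concrete, a hedged usage sketch (setup assumed, not part of the PR): only plain, un-fused allreduce calls are candidates for the symmetric-memory fast path; fused requests, oversized tensors, and dtype mismatches fall through to MNNVL or the regular kernel, which is why the workspace is still allocated under SYMM_MEM.

    import torch
    from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
    from tensorrt_llm._torch.distributed.ops import AllReduce

    ar = AllReduce(mapping=mapping, strategy=AllReduceStrategy.SYMM_MEM,
                   dtype=torch.bfloat16)

    # Plain allreduce (fusion_op is NONE by default): eligible for the MULTIMEM
    # path when the tensor matches the symm-mem dtype, is 4-byte aligned, and
    # fits within the symmetric buffer.
    out = ar(hidden_states)

    # Any request with fusion_op != AllReduceFusionOp.NONE skips
    # SymmetricMemoryAllReduce and runs the regular workspace-backed kernel.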
246 changes: 246 additions & 0 deletions tensorrt_llm/_torch/distributed/symm_mem_allreduce.py
@@ -0,0 +1,246 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
"""
Symmetric Memory AllReduce

This module provides PyTorch Symmetric Memory-based allreduce operations,
leveraging MULTIMEM hardware instructions.
"""

from typing import Optional

import torch
import torch.distributed as dist
from torch import nn

from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping

try:
import torch.distributed._symmetric_memory as torch_symm_mem

SYMM_MEM_AVAILABLE = True
except ImportError:
SYMM_MEM_AVAILABLE = False
logger.warning(
"PyTorch symmetric memory not available. Install PyTorch >= 2.8 for MULTIMEM support."
)


class SymmetricMemoryAllReduce(nn.Module):
"""
AllReduce implementation using PyTorch's symmetric memory operations.
This leverages MULTIMEM hardware instructions for faster allreduce operations.

Supported configurations (world_size):
- SM 9.0: 4, 6, 8 GPUs
- SM 10.0: 6, 8 GPUs

"""

# World sizes that support MULTIMEM instructions
_WORLD_SIZES_MULTIMEM = {
"9.0": [4, 6, 8],
"10.0": [6, 8],
}

# Maximum buffer sizes for symmetric memory (bytes)
_MAX_SIZES = {
"9.0": {
4: 8 * 1024 * 1024, # 8MB for 4 GPUs
6: 6 * 1024 * 1024, # 6MB for 6 GPUs
8: 4 * 1024 * 1024, # 4MB for 8 GPUs
},
"10.0": {
6: 8 * 1024 * 1024,
8: 6 * 1024 * 1024,
},
}

def __init__(
self,
mapping: Mapping,
dtype: torch.dtype = torch.bfloat16,
group: Optional[dist.ProcessGroup] = None,
):
super().__init__()

self.disabled = True
self.mapping = mapping
self.dtype = dtype
self.world_size = mapping.tp_size

if not SYMM_MEM_AVAILABLE:
logger.warning("SymmetricMemoryAllReduce: PyTorch symm_mem not available")
return

if not torch.cuda.is_available():
logger.warning("SymmetricMemoryAllReduce: CUDA not available")
return

# Get device capability
device = torch.device(f"cuda:{mapping.tp_rank}")
capability = torch.cuda.get_device_capability(device)
self.device_capability = f"{capability[0]}.{capability[1]}"

# Check if this configuration is supported
if self.device_capability not in self._MAX_SIZES:
logger.warning(
f"SymmetricMemoryAllReduce: Device capability {self.device_capability} not supported"
)
return

if self.world_size not in self._MAX_SIZES[self.device_capability]:
logger.info(
f"SymmetricMemoryAllReduce: World size {self.world_size} not supported "
f"for SM {self.device_capability}"
)
return

# Get max buffer size for this configuration
self.max_size = self._MAX_SIZES[self.device_capability][self.world_size]

# Set up process group
if group is None:
# Get or create TP group with correct ranks
# For TP parallelism, we need ranks [0, 1, 2, ..., tp_size-1] globally
# NOT starting from tp_rank!
if not dist.is_initialized():
logger.warning("SymmetricMemoryAllReduce: torch.distributed not initialized")
self.disabled = True
return

# Get actual TP group ranks from mapping (tp_group is a property, not a method)
tp_group_ranks = mapping.tp_group
self.group = dist.new_group(tp_group_ranks) if len(tp_group_ranks) > 1 else None
else:
self.group = group

if self.group is None:
logger.warning("SymmetricMemoryAllReduce: No valid process group")
self.disabled = True
return

# Enable symmetric memory for this group
try:
# Get group_name - this may fail if ProcessGroup doesn't have group_name set
if not hasattr(self.group, "group_name"):
logger.warning(
"SymmetricMemoryAllReduce: ProcessGroup does not have group_name attribute"
)
self.disabled = True
return

group_name_str = str(self.group.group_name)
torch_symm_mem.enable_symm_mem_for_group(group_name_str)
logger.debug(
f"SymmetricMemoryAllReduce: Enabled symmetric memory for group {group_name_str}"
)
except Exception as e:
logger.warning(
f"SymmetricMemoryAllReduce: Failed to enable symmetric memory for group: {e}"
)
import traceback

logger.debug(traceback.format_exc())
self.disabled = True
return

# Allocate symmetric memory buffer
try:
self.buffer = torch_symm_mem.empty(
self.max_size // self.dtype.itemsize,
device=device,
dtype=self.dtype,
)
# Pass group name string
group_name_str = str(self.group.group_name)
handle = torch_symm_mem.rendezvous(self.buffer, group_name_str)

if handle.multicast_ptr == 0:
logger.warning(
"SymmetricMemoryAllReduce: MULTIMEM operations not supported (multicast_ptr is 0)"
)
return

# Only enable if MULTIMEM is supported
# Otherwise, no benefit over existing TensorRT-LLM strategies
use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM.get(
self.device_capability, []
)

if not use_multimem:
logger.info(
f"SymmetricMemoryAllReduce: MULTIMEM not supported for "
f"world_size={self.world_size}, SM={self.device_capability}. "
f"Falling back to standard allreduce strategies."
)
return

self.disabled = False
logger.info(
f"SymmetricMemoryAllReduce (MULTIMEM) initialized: "
f"world_size={self.world_size}, "
f"max_size={self.max_size}, "
f"SM={self.device_capability}"
)

except Exception as e:
logger.warning(f"SymmetricMemoryAllReduce initialization failed: {e}")
return

@property
def process_group(self) -> Optional[dist.ProcessGroup]:
"""Expose the ProcessGroup for use in fallback scenarios."""
return self.group if not self.disabled else None

def should_use_symm_mem(self, inp: torch.Tensor) -> bool:
"""Check if symmetric memory can be used for this tensor."""
if self.disabled:
return False
if inp.dtype != self.dtype:
return False
inp_size = inp.numel() * inp.element_size()
if inp_size % 4 != 0:
return False
if inp_size >= self.max_size:
return False
return True

def forward(
self,
inp: torch.Tensor,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Perform allreduce using symmetric memory operations.

Args:
inp: Input tensor to reduce
out: Optional output tensor (if None, will be allocated)

Returns:
Reduced tensor
"""
if not self.should_use_symm_mem(inp):
return None # Caller should fall back to other strategy

if out is None:
out = torch.empty_like(inp)

# Copy input to symmetric memory buffer
self.buffer[: inp.numel()].copy_(inp.view(-1))

# Perform MULTIMEM allreduce
# Pass group name string (matching vLLM's implementation)
group_name_str = str(self.group.group_name)
torch.ops.symm_mem.multimem_all_reduce_(
self.buffer[: inp.numel()],
"sum",
group_name_str,
)

# Copy result back
out.copy_(self.buffer[: inp.numel()].view(out.shape))

return out
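A hedged standalone sketch of the new module (setup assumed; import path from this diff): `forward` returns `None` whenever a tensor is not eligible, so callers must keep a regular allreduce as fallback, which is how `AllReduce.forward` in ops.py consumes it.

    import torch
    import torch.distributed as dist
    from tensorrt_llm._torch.distributed.symm_mem_allreduce import \
        SymmetricMemoryAllReduce

    # `mapping` is assumed to be a valid Mapping; torch.distributed must already
    # be initialized for the TP group, otherwise the module stays disabled.
    symm = SymmetricMemoryAllReduce(mapping, dtype=torch.bfloat16)

    def allreduce_with_fallback(t: torch.Tensor) -> torch.Tensor:
        if not symm.disabled:
            out = symm(t)  # None if dtype, alignment, or size checks fail
            if out is not None:
                return out
        dist.all_reduce(t)  # generic fallback used for this sketch only
        return t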
1 change: 1 addition & 0 deletions tensorrt_llm/functional.py
@@ -3883,6 +3883,7 @@ class AllReduceStrategy(IntEnum):
LOWPRECISION = 6
MNNVL = 7
NCCL_SYMMETRIC = 8
SYMM_MEM = 9 # PyTorch symmetric memory with MULTIMEM


class AllReduceFusionOp(IntEnum):