diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
index 10b2bf4e251..f7c3343be22 100644
--- a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
+++ b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
@@ -26,9 +26,9 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
         cache_key = (rank, world_size, tensor.dtype)
         if cache_key not in _allreduce_cache:
             p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
-            # Use Strategy.AUTO for optimal performance
+            # Use SYMM_MEM strategy (tries symmetric memory first, falls back to AUTO if needed)
             _allreduce_cache[cache_key] = AllReduce(
-                mapping=p_config, strategy=AllReduceStrategy.AUTO, dtype=tensor.dtype
+                mapping=p_config, strategy=AllReduceStrategy.SYMM_MEM, dtype=tensor.dtype
             )
 
         torch_op = _allreduce_cache[cache_key]
diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py
index 4d7853794b2..4406b236f58 100644
--- a/tensorrt_llm/_torch/distributed/ops.py
+++ b/tensorrt_llm/_torch/distributed/ops.py
@@ -7,6 +7,8 @@
 import torch
 from torch import nn
 
+from tensorrt_llm._torch.distributed.symm_mem_allreduce import \
+    SymmetricMemoryAllReduce
 from tensorrt_llm._utils import mpi_comm, mpi_disabled
 from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer
 from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
@@ -566,13 +568,17 @@ def __init__(self,
             strategy (AllReduceStrategy):
                 The following all-reduce strategies are supported:
 
+                - SYMM_MEM: Uses PyTorch's symmetric memory with MULTIMEM hardware instructions.
+                  Falls back automatically if not supported.
+
                 - UB: AllReduce uses user-buffer based all-reduce kernel.
 
                 - NCCL: Use NCCL allreduce.
 
                 - MIN_LATENCY: AllReduce uses MIN_LATENCY mode kernel.
 
-                - AUTO: AUTO chooses between NCCL and MIN_LATENCY mode based on a heuristic policy.
+                - AUTO: Chooses the best available strategy: MNNVL is tried first,
+                  then NCCL or MIN_LATENCY is selected based on a heuristic policy.
 
                 - LOWPRECISION: AllReduce quantizes data to lower precision for transmission.
                   Should only be used on topologies with PCIe switches and without NVLink.
@@ -601,20 +607,51 @@ def __init__(self,
         self.workspace = None
         self.strategy = strategy
         self.mnnvl_allreduce = None
+        self.symm_mem_allreduce = None
 
         self._disable_mpi = mpi_disabled()
         self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce
 
         if self.mapping.tp_size > 1:
-            # When Strategy is UB, it is guaranteed that the workspace is not used.
+            # Initialize Symmetric Memory AllReduce if needed (before workspace allocation)
+            if self.strategy == AllReduceStrategy.SYMM_MEM:
+                try:
+                    symm_mem = SymmetricMemoryAllReduce(
+                        self.mapping,
+                        dtype=dtype if dtype else torch.bfloat16,
+                    )
+                    if not symm_mem.disabled:
+                        self.symm_mem_allreduce = symm_mem
+                        logger.info(
+                            f"SymmetricMemoryAllReduce (MULTIMEM) is enabled with fallback support for world_size={self.mapping.tp_size}"
+                        )
+                        # Keep SYMM_MEM strategy but allocate workspace for fallback to regular allreduce
+                    else:
+                        logger.info(
+                            "SymmetricMemoryAllReduce is disabled (not supported or unavailable), falling back to AUTO strategy"
+                        )
+                        # Fall back to AUTO if SYMM_MEM can't be enabled
+                        self.strategy = AllReduceStrategy.AUTO
+                except Exception as e:
+                    logger.info(
+                        f"Symmetric Memory AllReduce can't be enabled due to {e}, falling back to AUTO strategy"
+                    )
+                    self.symm_mem_allreduce = None
+                    # Fall back to AUTO if SYMM_MEM initialization fails
+                    self.strategy = AllReduceStrategy.AUTO
+
+            # Allocate workspace for strategies that need it
+            # Note: SYMM_MEM also needs the workspace for fallback scenarios (fused ops, etc.)
+            # Only UB doesn't need workspace
             if self.strategy != AllReduceStrategy.UB:
                 if self.strategy == AllReduceStrategy.LOWPRECISION:
                     allocate_low_presicion_allreduce_workspace(self.mapping)
                 self.workspace = get_allreduce_workspace(self.mapping)
 
-            # Initialize MNNVL AllReduce if needed
+            # Initialize MNNVL if using AUTO or MNNVL strategy
             if self.strategy in (AllReduceStrategy.AUTO,
                                  AllReduceStrategy.MNNVL):
+                # Try to initialize MNNVL
                 if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
                     # ALWAYS capture the exception when creating this instance
                     try:
@@ -670,20 +707,39 @@ def forward(
         if all_reduce_params is None:
             all_reduce_params = AllReduceParams()
 
-        # Try MNNVL AllReduce first if available
+        # Try Symmetric Memory AllReduce first if available
+        # Note: Currently only supports the NONE fusion op (plain allreduce)
+        if self.symm_mem_allreduce and all_reduce_params.fusion_op == AllReduceFusionOp.NONE:
+            symm_mem_output = self.symm_mem_allreduce(input)
+            if symm_mem_output is not None:
+                logger.debug(
+                    f"Using SymmetricMemoryAllReduce (MULTIMEM) for input shape {input.shape}"
+                )
+                return symm_mem_output
+        elif self.symm_mem_allreduce and all_reduce_params.fusion_op != AllReduceFusionOp.NONE:
+            # Log once per rank that we're skipping symm_mem due to fusion
+            if not hasattr(self, '_logged_fusion_skip'):
+                logger.debug(
+                    f"Skipping SymmetricMemoryAllReduce for fused operation (fusion_op={all_reduce_params.fusion_op}), using regular allreduce"
+                )
+                self._logged_fusion_skip = True
+
+        # Try MNNVL AllReduce if symm_mem didn't handle it
         if self.mnnvl_allreduce:
             mnnvl_output = self.mnnvl_allreduce(
                 input, all_reduce_params=all_reduce_params)
             if mnnvl_output is not None:
                 return mnnvl_output
 
-        # Fall back to regular AllReduce if MNNVL is not available or not applicable
-        # Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL
-        if allreduce_strategy == AllReduceStrategy.MNNVL:
+        # Fall back to regular AllReduce if specialized methods are not available or not applicable
+        # Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL/SYMM_MEM
+        if allreduce_strategy in (AllReduceStrategy.MNNVL,
+                                  AllReduceStrategy.SYMM_MEM):
             allreduce_strategy = AllReduceStrategy.AUTO
 
         additional_args = {}
         if self._disable_mpi:
+            # Get ProcessGroup from mapping
             pg = self.mapping.tp_group_pg
             assert pg is not None, "TP ProcessGroup not initialised"
             additional_args = {
diff --git a/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py b/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py
new file mode 100644
index 00000000000..f02f20b786e
--- /dev/null
+++ b/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py
@@ -0,0 +1,246 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+"""
+Symmetric Memory AllReduce
+
+This module provides PyTorch Symmetric Memory-based allreduce operations,
+leveraging MULTIMEM hardware instructions.
+"""
+
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+from torch import nn
+
+from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping
+
+try:
+    import torch.distributed._symmetric_memory as torch_symm_mem
+
+    SYMM_MEM_AVAILABLE = True
+except ImportError:
+    SYMM_MEM_AVAILABLE = False
+    logger.warning(
+        "PyTorch symmetric memory not available. Install PyTorch >= 2.8 for MULTIMEM support."
+    )
+
+
+class SymmetricMemoryAllReduce(nn.Module):
+    """
+    AllReduce implementation using PyTorch's symmetric memory operations.
+    This leverages MULTIMEM hardware instructions for faster allreduce operations.
+
+    Supported configurations (world_size):
+    - SM 9.0: 4, 6, 8 GPUs
+    - SM 10.0: 6, 8 GPUs
+
+    """
+
+    # World sizes that support MULTIMEM instructions
+    _WORLD_SIZES_MULTIMEM = {
+        "9.0": [4, 6, 8],
+        "10.0": [6, 8],
+    }
+
+    # Maximum buffer sizes for symmetric memory (bytes)
+    _MAX_SIZES = {
+        "9.0": {
+            4: 8 * 1024 * 1024,  # 8MB for 4 GPUs
+            6: 6 * 1024 * 1024,  # 6MB for 6 GPUs
+            8: 4 * 1024 * 1024,  # 4MB for 8 GPUs
+        },
+        "10.0": {
+            6: 8 * 1024 * 1024,
+            8: 6 * 1024 * 1024,
+        },
+    }
+
+    def __init__(
+        self,
+        mapping: Mapping,
+        dtype: torch.dtype = torch.bfloat16,
+        group: Optional[dist.ProcessGroup] = None,
+    ):
+        super().__init__()
+
+        self.disabled = True
+        self.mapping = mapping
+        self.dtype = dtype
+        self.world_size = mapping.tp_size
+
+        if not SYMM_MEM_AVAILABLE:
+            logger.warning("SymmetricMemoryAllReduce: PyTorch symm_mem not available")
+            return
+
+        if not torch.cuda.is_available():
+            logger.warning("SymmetricMemoryAllReduce: CUDA not available")
+            return
+
+        # Get device capability
+        device = torch.device(f"cuda:{mapping.tp_rank}")
+        capability = torch.cuda.get_device_capability(device)
+        self.device_capability = f"{capability[0]}.{capability[1]}"
+
+        # Check if this configuration is supported
+        if self.device_capability not in self._MAX_SIZES:
+            logger.warning(
+                f"SymmetricMemoryAllReduce: Device capability {self.device_capability} not supported"
+            )
+            return
+
+        if self.world_size not in self._MAX_SIZES[self.device_capability]:
+            logger.info(
+                f"SymmetricMemoryAllReduce: World size {self.world_size} not supported "
+                f"for SM {self.device_capability}"
+            )
+            return
+
+        # Get max buffer size for this configuration
+        self.max_size = self._MAX_SIZES[self.device_capability][self.world_size]
+
+        # Set up process group
+        if group is None:
+            # Get or create the TP group with the correct ranks:
+            # the group must use the global ranks from mapping.tp_group,
+            # not ranks offset by tp_rank.
+            if not dist.is_initialized():
+                logger.warning("SymmetricMemoryAllReduce: torch.distributed not initialized")
+                self.disabled = True
+                return
+
+            # Get actual TP group ranks from mapping (tp_group is a property, not a method)
+            tp_group_ranks = mapping.tp_group
+            self.group = dist.new_group(tp_group_ranks) if len(tp_group_ranks) > 1 else None
+        else:
+            self.group = group
+
+        if self.group is None:
+            logger.warning("SymmetricMemoryAllReduce: No valid process group")
+            self.disabled = True
+            return
+
+        # Enable symmetric memory for this group
+        try:
+            # Get group_name - this may fail if the ProcessGroup doesn't have group_name set
+            if not hasattr(self.group, "group_name"):
+                logger.warning(
+                    "SymmetricMemoryAllReduce: ProcessGroup does not have group_name attribute"
+                )
+                self.disabled = True
+                return
+
+            group_name_str = str(self.group.group_name)
+            torch_symm_mem.enable_symm_mem_for_group(group_name_str)
+            logger.debug(
+                f"SymmetricMemoryAllReduce: Enabled symmetric memory for group {group_name_str}"
+            )
+        except Exception as e:
+            logger.warning(
+                f"SymmetricMemoryAllReduce: Failed to enable symmetric memory for group: {e}"
+            )
+            import traceback
+
+            logger.debug(traceback.format_exc())
+            self.disabled = True
+            return
+
+        # Allocate symmetric memory buffer
+        try:
+            self.buffer = torch_symm_mem.empty(
+                self.max_size // self.dtype.itemsize,
+                device=device,
+                dtype=self.dtype,
+            )
+            # Pass the group name string
+            group_name_str = str(self.group.group_name)
+            handle = torch_symm_mem.rendezvous(self.buffer, group_name_str)
+
+            if handle.multicast_ptr == 0:
+                logger.warning(
+                    "SymmetricMemoryAllReduce: MULTIMEM operations not supported (multicast_ptr is 0)"
+                )
+                return
+
+            # Only enable if MULTIMEM is supported
+            # Otherwise, no benefit over existing TensorRT-LLM strategies
+            use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM.get(
+                self.device_capability, []
+            )
+
+            if not use_multimem:
+                logger.info(
+                    f"SymmetricMemoryAllReduce: MULTIMEM not supported for "
+                    f"world_size={self.world_size}, SM={self.device_capability}. "
+                    f"Falling back to standard allreduce strategies."
+                )
+                return
+
+            self.disabled = False
+            logger.info(
+                f"SymmetricMemoryAllReduce (MULTIMEM) initialized: "
+                f"world_size={self.world_size}, "
+                f"max_size={self.max_size}, "
+                f"SM={self.device_capability}"
+            )
+
+        except Exception as e:
+            logger.warning(f"SymmetricMemoryAllReduce initialization failed: {e}")
+            return
+
+    @property
+    def process_group(self) -> Optional[dist.ProcessGroup]:
+        """Expose the ProcessGroup for use in fallback scenarios."""
+        return self.group if not self.disabled else None
+
+    def should_use_symm_mem(self, inp: torch.Tensor) -> bool:
+        """Check if symmetric memory can be used for this tensor."""
+        if self.disabled:
+            return False
+        if inp.dtype != self.dtype:
+            return False
+        inp_size = inp.numel() * inp.element_size()
+        if inp_size % 4 != 0:
+            return False
+        if inp_size >= self.max_size:
+            return False
+        return True
+
+    def forward(
+        self,
+        inp: torch.Tensor,
+        out: Optional[torch.Tensor] = None,
+    ) -> Optional[torch.Tensor]:
+        """
+        Perform allreduce using symmetric memory operations.
+
+        Args:
+            inp: Input tensor to reduce
+            out: Optional output tensor (if None, will be allocated)
+
+        Returns:
+            Reduced tensor, or None if symmetric memory cannot be used (caller should fall back)
+        """
+        if not self.should_use_symm_mem(inp):
+            return None  # Caller should fall back to another strategy
+
+        if out is None:
+            out = torch.empty_like(inp)
+
+        # Copy input to symmetric memory buffer
+        self.buffer[: inp.numel()].copy_(inp.view(-1))
+
+        # Perform MULTIMEM allreduce
+        # Pass group name string (matching vLLM's implementation)
+        group_name_str = str(self.group.group_name)
+        torch.ops.symm_mem.multimem_all_reduce_(
+            self.buffer[: inp.numel()],
+            "sum",
+            group_name_str,
+        )
+
+        # Copy result back
+        out.copy_(self.buffer[: inp.numel()].view(out.shape))
+
+        return out
diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py
index 282febd262e..129cc33b183 100755
--- a/tensorrt_llm/functional.py
+++ b/tensorrt_llm/functional.py
@@ -3883,6 +3883,7 @@ class AllReduceStrategy(IntEnum):
     LOWPRECISION = 6
     MNNVL = 7
    NCCL_SYMMETRIC = 8
+    SYMM_MEM = 9  # PyTorch symmetric memory with MULTIMEM
 
 
 class AllReduceFusionOp(IntEnum):
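
Usage note (illustrative sketch, not part of the patch): opting in from caller code follows the same pattern the trtllm.py hunk applies, i.e. construct AllReduce with AllReduceStrategy.SYMM_MEM and rely on the constructor's automatic downgrade to AUTO when symmetric memory or MULTIMEM is unavailable. The sketch assumes a tensor-parallel worker (one process per GPU) whose distributed runtime is already initialized, and imports AllReduce directly from the ops.py module touched above; adjust the import if your build re-exports it elsewhere.

    import torch

    from tensorrt_llm._torch.distributed.ops import AllReduce
    from tensorrt_llm.functional import AllReduceStrategy
    from tensorrt_llm.mapping import Mapping


    def build_tp_allreduce(rank: int, world_size: int,
                           dtype: torch.dtype = torch.bfloat16) -> AllReduce:
        # Mirrors the cached construction in trtllm_allreduce: request SYMM_MEM and
        # let AllReduce.__init__ fall back to AUTO if the MULTIMEM path cannot be enabled.
        mapping = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
        return AllReduce(mapping=mapping,
                         strategy=AllReduceStrategy.SYMM_MEM,
                         dtype=dtype)


    # In the worker, once the runtime is up:
    #   op = build_tp_allreduce(rank, world_size)
    #   y = op(x)  # MULTIMEM path for eligible plain allreduces, regular kernels otherwise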
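
Reviewer aid (illustrative sketch, not part of the patch): the routing decision in should_use_symm_mem reduces to simple arithmetic over _MAX_SIZES, so it can be sanity-checked without GPUs. The hypothetical helper symm_mem_eligible below mirrors the alignment and size gates; the real method additionally requires the input dtype to match the dtype the buffer was allocated with.

    # Mirrors SymmetricMemoryAllReduce._MAX_SIZES from the new module (bytes).
    MAX_SIZES = {
        "9.0": {4: 8 * 1024 * 1024, 6: 6 * 1024 * 1024, 8: 4 * 1024 * 1024},
        "10.0": {6: 8 * 1024 * 1024, 8: 6 * 1024 * 1024},
    }


    def symm_mem_eligible(numel: int, elem_size: int, sm: str, world_size: int) -> bool:
        # Eligible only if a cap exists for this SM/world_size, the byte size is a
        # multiple of 4, and the tensor fits strictly under the cap.
        cap = MAX_SIZES.get(sm, {}).get(world_size)
        if cap is None:
            return False
        nbytes = numel * elem_size
        return nbytes % 4 == 0 and nbytes < cap


    # Example on 8x SM 9.0 GPUs (cap = 4 MiB):
    #   a [1024, 1024] bf16 tensor is 1024 * 1024 * 2 = 2 MiB  -> MULTIMEM path
    #   a [8192, 4096] bf16 tensor is 8192 * 4096 * 2 = 64 MiB -> regular allreduce path
    assert symm_mem_eligible(1024 * 1024, 2, "9.0", 8)
    assert not symm_mem_eligible(8192 * 4096, 2, "9.0", 8)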