From 964d1a302d27fd13e49dd11aa817467612981771 Mon Sep 17 00:00:00 2001 From: Eran Geva <19514940+MrGeva@users.noreply.github.com> Date: Tue, 4 Nov 2025 06:43:35 -0800 Subject: [PATCH 1/4] Added symm mem strategy Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com> --- tensorrt_llm/_torch/distributed/ops.py | 53 ++++- .../_torch/distributed/symm_mem_allreduce.py | 221 ++++++++++++++++++ tensorrt_llm/functional.py | 1 + 3 files changed, 270 insertions(+), 5 deletions(-) create mode 100644 tensorrt_llm/_torch/distributed/symm_mem_allreduce.py diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 4d7853794b2..fb1916e8755 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -7,6 +7,8 @@ import torch from torch import nn +from tensorrt_llm._torch.distributed.symm_mem_allreduce import \ + SymmetricMemoryAllReduce from tensorrt_llm._utils import mpi_comm, mpi_disabled from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams, @@ -566,13 +568,19 @@ def __init__(self, strategy (AllReduceStrategy): The following all-reduce strategies are supported: + - SYMM_MEM: Uses PyTorch's symmetric memory with MULTIMEM hardware instructions (H100+). + Provides 3x faster performance on supported configurations (4/6/8 GPUs on H100). + Currently only supports plain allreduce (NONE fusion op). Falls back automatically + if not supported. + - UB: AllReduce uses user-buffer based all-reduce kernel. - NCCL: Use NCCL allreduce. - MIN_LATENCY: AllReduce uses MIN_LATENCY mode kernel. - - AUTO: AUTO chooses between NCCL and MIN_LATENCY mode based on a heuristic policy. + - AUTO: AUTO chooses the best available strategy. Will try SYMM_MEM first (if available), + then MNNVL, then choose between NCCL and MIN_LATENCY based on a heuristic policy. - LOWPRECISION: AllReduce quantizes data to lower precision for transmission. Should only be used on topologies with PCIe switches and without NVLink. @@ -601,6 +609,7 @@ def __init__(self, self.workspace = None self.strategy = strategy self.mnnvl_allreduce = None + self.symm_mem_allreduce = None self._disable_mpi = mpi_disabled() self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce @@ -612,6 +621,29 @@ def __init__(self, allocate_low_presicion_allreduce_workspace(self.mapping) self.workspace = get_allreduce_workspace(self.mapping) + # Initialize Symmetric Memory AllReduce if needed (H100+ hardware acceleration) + if self.strategy in (AllReduceStrategy.AUTO, + AllReduceStrategy.SYMM_MEM): + try: + symm_mem = SymmetricMemoryAllReduce( + self.mapping, + dtype=dtype if dtype else torch.bfloat16, + ) + if not symm_mem.disabled: + self.symm_mem_allreduce = symm_mem + logger.info( + f"SymmetricMemoryAllReduce (MULTIMEM) is enabled for world_size={self.mapping.tp_size}" + ) + else: + logger.debug( + f"SymmetricMemoryAllReduce is disabled (not supported or unavailable)" + ) + except Exception as e: + logger.debug( + f"Symmetric Memory AllReduce can't be enabled due to {e}." 
+ ) + self.symm_mem_allreduce = None + # Initialize MNNVL AllReduce if needed if self.strategy in (AllReduceStrategy.AUTO, AllReduceStrategy.MNNVL): @@ -670,16 +702,27 @@ def forward( if all_reduce_params is None: all_reduce_params = AllReduceParams() - # Try MNNVL AllReduce first if available + # Try Symmetric Memory AllReduce first if available (H100+ hardware acceleration) + # Note: Currently only supports NONE fusion op (plain allreduce) + if self.symm_mem_allreduce and all_reduce_params.fusion_op == AllReduceFusionOp.NONE: + symm_mem_output = self.symm_mem_allreduce(input) + if symm_mem_output is not None: + logger.debug( + f"Using SymmetricMemoryAllReduce (MULTIMEM) for input shape {input.shape}" + ) + return symm_mem_output + + # Try MNNVL AllReduce if symm_mem didn't handle it if self.mnnvl_allreduce: mnnvl_output = self.mnnvl_allreduce( input, all_reduce_params=all_reduce_params) if mnnvl_output is not None: return mnnvl_output - # Fall back to regular AllReduce if MNNVL is not available or not applicable - # Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL - if allreduce_strategy == AllReduceStrategy.MNNVL: + # Fall back to regular AllReduce if specialized methods are not available or not applicable + # Make sure the strategy is AUTO since allreduceOp does not have the branch for MNNVL/SYMM_MEM + if allreduce_strategy in (AllReduceStrategy.MNNVL, + AllReduceStrategy.SYMM_MEM): allreduce_strategy = AllReduceStrategy.AUTO additional_args = {} diff --git a/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py b/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py new file mode 100644 index 00000000000..97d604e2f25 --- /dev/null +++ b/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +""" +Symmetric Memory AllReduce for H100+ GPUs + +This module provides PyTorch Symmetric Memory-based allreduce operations, +leveraging H100's MULTIMEM hardware instructions for 3x faster performance +compared to custom CUDA kernels on supported configurations. +""" + +from typing import Optional + +import torch +import torch.distributed as dist +from torch import nn + +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping + +try: + import torch.distributed._symmetric_memory as torch_symm_mem + SYMM_MEM_AVAILABLE = True +except ImportError: + SYMM_MEM_AVAILABLE = False + logger.warning( + "PyTorch symmetric memory not available. Install PyTorch >= 2.8 for MULTIMEM support." + ) + + +class SymmetricMemoryAllReduce(nn.Module): + """ + AllReduce implementation using PyTorch's symmetric memory operations. + + This leverages H100's MULTIMEM hardware instructions for significantly faster + allreduce operations compared to software implementations. + + Supported configurations (world_size): + - SM 9.0 (H100): 4, 6, 8 GPUs + - SM 10.0 (future): 6, 8 GPUs + + Based on vLLM's implementation but integrated into TensorRT-LLM. 
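+
+    A minimal usage sketch (an illustration only; assumes torch.distributed is
+    initialized and `mapping` describes a supported TP topology):
+
+        symm_mem = SymmetricMemoryAllReduce(mapping, dtype=torch.bfloat16)
+        if not symm_mem.disabled:
+            reduced = symm_mem(tensor)  # returns None if the tensor is unsupported
+            if reduced is None:
+                ...  # caller falls back to another allreduce strategy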
+ """ + + # World sizes that support MULTIMEM instructions + _WORLD_SIZES_MULTIMEM = { + "9.0": [4, 6, 8], # H100 + "10.0": [6, 8], # Future architectures + } + + # Maximum buffer sizes for symmetric memory (bytes) + _MAX_SIZES = { + "9.0": { + 4: 8 * 1024 * 1024, # 8MB for 4 GPUs + 6: 6 * 1024 * 1024, # 6MB for 6 GPUs + 8: 4 * 1024 * 1024, # 4MB for 8 GPUs + }, + "10.0": { + 6: 8 * 1024 * 1024, + 8: 6 * 1024 * 1024, + } + } + + def __init__( + self, + mapping: Mapping, + dtype: torch.dtype = torch.bfloat16, + group: Optional[dist.ProcessGroup] = None, + ): + super().__init__() + + self.disabled = True + self.mapping = mapping + self.dtype = dtype + self.world_size = mapping.tp_size + + if not SYMM_MEM_AVAILABLE: + logger.warning( + "SymmetricMemoryAllReduce: PyTorch symm_mem not available") + return + + if not torch.cuda.is_available(): + logger.warning("SymmetricMemoryAllReduce: CUDA not available") + return + + # Get device capability + device = torch.device(f"cuda:{mapping.tp_rank}") + capability = torch.cuda.get_device_capability(device) + self.device_capability = f"{capability[0]}.{capability[1]}" + + # Check if this configuration is supported + if self.device_capability not in self._MAX_SIZES: + logger.warning( + f"SymmetricMemoryAllReduce: Device capability {self.device_capability} not supported" + ) + return + + if self.world_size not in self._MAX_SIZES[self.device_capability]: + logger.info( + f"SymmetricMemoryAllReduce: World size {self.world_size} not supported " + f"for SM {self.device_capability}") + return + + # Get max buffer size for this configuration + self.max_size = self._MAX_SIZES[self.device_capability][self.world_size] + + # Set up process group + if group is None: + # Get or create TP group with correct ranks + # For TP parallelism, we need ranks [0, 1, 2, ..., tp_size-1] globally + # NOT starting from tp_rank! 
+ if not dist.is_initialized(): + logger.warning( + "SymmetricMemoryAllReduce: torch.distributed not initialized" + ) + self.disabled = True + return + + # Assume contiguous TP ranks for now + # TODO: Get actual TP group from mapping if available + tp_group_ranks = list(range(mapping.tp_size)) + self.group = dist.new_group(tp_group_ranks) if len( + tp_group_ranks) > 1 else None + else: + self.group = group + + if self.group is None: + logger.warning("SymmetricMemoryAllReduce: No valid process group") + self.disabled = True + return + + # Allocate symmetric memory buffer + try: + self.buffer = torch_symm_mem.empty( + self.max_size // self.dtype.itemsize, + device=device, + dtype=self.dtype, + ) + # Pass group_name (string) not the group object + handle = torch_symm_mem.rendezvous(self.buffer, + self.group.group_name) + + if handle.multicast_ptr == 0: + logger.warning( + "SymmetricMemoryAllReduce: MULTIMEM operations not supported (multicast_ptr is 0)" + ) + return + + # Determine which algorithm to use + self.use_multimem = (self.world_size + in self._WORLD_SIZES_MULTIMEM.get( + self.device_capability, [])) + + self.disabled = False + logger.info(f"SymmetricMemoryAllReduce initialized: " + f"world_size={self.world_size}, " + f"max_size={self.max_size}, " + f"SM={self.device_capability}, " + f"use_multimem={self.use_multimem}") + + except Exception as e: + logger.warning( + f"SymmetricMemoryAllReduce initialization failed: {e}") + return + + def should_use_symm_mem(self, inp: torch.Tensor) -> bool: + """Check if symmetric memory can be used for this tensor.""" + if self.disabled: + return False + if inp.dtype != self.dtype: + return False + inp_size = inp.numel() * inp.element_size() + if inp_size % 4 != 0: + return False + if inp_size >= self.max_size: + return False + return True + + def forward( + self, + inp: torch.Tensor, + out: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Perform allreduce using symmetric memory operations. 
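+
+        Note: returns None (instead of a reduced tensor) when
+        should_use_symm_mem() rejects the input, so the caller is expected to
+        fall back to another allreduce strategy.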
+ + Args: + inp: Input tensor to reduce + out: Optional output tensor (if None, will be allocated) + + Returns: + Reduced tensor + """ + if not self.should_use_symm_mem(inp): + return None # Caller should fall back to other strategy + + if out is None: + out = torch.empty_like(inp) + + # Copy input to symmetric memory buffer + self.buffer[:inp.numel()].copy_(inp.view(-1)) + + # Perform allreduce using appropriate algorithm + if self.use_multimem: + # Use MULTIMEM hardware instructions (faster) + torch.ops.symm_mem.multimem_all_reduce_( + self.buffer[:inp.numel()], + "sum", + self.group.group_name, + ) + else: + # Use two-shot algorithm (fallback) + torch.ops.symm_mem.two_shot_all_reduce_( + self.buffer[:inp.numel()], + "sum", + self.group.group_name, + ) + + # Copy result back + out.copy_(self.buffer[:inp.numel()].view(out.shape)) + + return out diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 282febd262e..8e05194d5a3 100755 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -3883,6 +3883,7 @@ class AllReduceStrategy(IntEnum): LOWPRECISION = 6 MNNVL = 7 NCCL_SYMMETRIC = 8 + SYMM_MEM = 9 # PyTorch symmetric memory with MULTIMEM (H100+) class AllReduceFusionOp(IntEnum): From 668b514ee4c8717671913c7e2375d84eabee2441 Mon Sep 17 00:00:00 2001 From: Eran Geva <19514940+MrGeva@users.noreply.github.com> Date: Tue, 4 Nov 2025 12:18:05 -0800 Subject: [PATCH 2/4] CR fixes Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com> --- tensorrt_llm/_torch/distributed/ops.py | 10 +- .../_torch/distributed/symm_mem_allreduce.py | 100 ++++++++---------- tensorrt_llm/functional.py | 2 +- 3 files changed, 51 insertions(+), 61 deletions(-) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index fb1916e8755..867fac9df63 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -568,10 +568,8 @@ def __init__(self, strategy (AllReduceStrategy): The following all-reduce strategies are supported: - - SYMM_MEM: Uses PyTorch's symmetric memory with MULTIMEM hardware instructions (H100+). - Provides 3x faster performance on supported configurations (4/6/8 GPUs on H100). - Currently only supports plain allreduce (NONE fusion op). Falls back automatically - if not supported. + - SYMM_MEM: Uses PyTorch's symmetric memory with MULTIMEM hardware instructions. + Falls back automatically if not supported. - UB: AllReduce uses user-buffer based all-reduce kernel. 
@@ -621,7 +619,7 @@ def __init__(self, allocate_low_presicion_allreduce_workspace(self.mapping) self.workspace = get_allreduce_workspace(self.mapping) - # Initialize Symmetric Memory AllReduce if needed (H100+ hardware acceleration) + # Initialize Symmetric Memory AllReduce if needed if self.strategy in (AllReduceStrategy.AUTO, AllReduceStrategy.SYMM_MEM): try: @@ -702,7 +700,7 @@ def forward( if all_reduce_params is None: all_reduce_params = AllReduceParams() - # Try Symmetric Memory AllReduce first if available (H100+ hardware acceleration) + # Try Symmetric Memory AllReduce first if available # Note: Currently only supports NONE fusion op (plain allreduce) if self.symm_mem_allreduce and all_reduce_params.fusion_op == AllReduceFusionOp.NONE: symm_mem_output = self.symm_mem_allreduce(input) diff --git a/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py b/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py index 97d604e2f25..41fb28a0685 100644 --- a/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py +++ b/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. """ -Symmetric Memory AllReduce for H100+ GPUs +Symmetric Memory AllReduce This module provides PyTorch Symmetric Memory-based allreduce operations, -leveraging H100's MULTIMEM hardware instructions for 3x faster performance -compared to custom CUDA kernels on supported configurations. +leveraging MULTIMEM hardware instructions. """ from typing import Optional @@ -19,6 +18,7 @@ try: import torch.distributed._symmetric_memory as torch_symm_mem + SYMM_MEM_AVAILABLE = True except ImportError: SYMM_MEM_AVAILABLE = False @@ -30,21 +30,18 @@ class SymmetricMemoryAllReduce(nn.Module): """ AllReduce implementation using PyTorch's symmetric memory operations. - - This leverages H100's MULTIMEM hardware instructions for significantly faster - allreduce operations compared to software implementations. + This leverages MULTIMEM hardware instructions for faster allreduce operations. Supported configurations (world_size): - - SM 9.0 (H100): 4, 6, 8 GPUs - - SM 10.0 (future): 6, 8 GPUs + - SM 9.0: 4, 6, 8 GPUs + - SM 10.0: 6, 8 GPUs - Based on vLLM's implementation but integrated into TensorRT-LLM. """ # World sizes that support MULTIMEM instructions _WORLD_SIZES_MULTIMEM = { - "9.0": [4, 6, 8], # H100 - "10.0": [6, 8], # Future architectures + "9.0": [4, 6, 8], + "10.0": [6, 8], } # Maximum buffer sizes for symmetric memory (bytes) @@ -57,7 +54,7 @@ class SymmetricMemoryAllReduce(nn.Module): "10.0": { 6: 8 * 1024 * 1024, 8: 6 * 1024 * 1024, - } + }, } def __init__( @@ -74,8 +71,7 @@ def __init__( self.world_size = mapping.tp_size if not SYMM_MEM_AVAILABLE: - logger.warning( - "SymmetricMemoryAllReduce: PyTorch symm_mem not available") + logger.warning("SymmetricMemoryAllReduce: PyTorch symm_mem not available") return if not torch.cuda.is_available(): @@ -97,7 +93,8 @@ def __init__( if self.world_size not in self._MAX_SIZES[self.device_capability]: logger.info( f"SymmetricMemoryAllReduce: World size {self.world_size} not supported " - f"for SM {self.device_capability}") + f"for SM {self.device_capability}" + ) return # Get max buffer size for this configuration @@ -109,17 +106,13 @@ def __init__( # For TP parallelism, we need ranks [0, 1, 2, ..., tp_size-1] globally # NOT starting from tp_rank! 
if not dist.is_initialized(): - logger.warning( - "SymmetricMemoryAllReduce: torch.distributed not initialized" - ) + logger.warning("SymmetricMemoryAllReduce: torch.distributed not initialized") self.disabled = True return - # Assume contiguous TP ranks for now - # TODO: Get actual TP group from mapping if available - tp_group_ranks = list(range(mapping.tp_size)) - self.group = dist.new_group(tp_group_ranks) if len( - tp_group_ranks) > 1 else None + # Get actual TP group ranks from mapping + tp_group_ranks = mapping.tp_group() + self.group = dist.new_group(tp_group_ranks) if len(tp_group_ranks) > 1 else None else: self.group = group @@ -136,8 +129,7 @@ def __init__( dtype=self.dtype, ) # Pass group_name (string) not the group object - handle = torch_symm_mem.rendezvous(self.buffer, - self.group.group_name) + handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) if handle.multicast_ptr == 0: logger.warning( @@ -145,21 +137,30 @@ def __init__( ) return - # Determine which algorithm to use - self.use_multimem = (self.world_size - in self._WORLD_SIZES_MULTIMEM.get( - self.device_capability, [])) + # Only enable if MULTIMEM is supported + # Otherwise, no benefit over existing TensorRT-LLM strategies + use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM.get( + self.device_capability, [] + ) + + if not use_multimem: + logger.info( + f"SymmetricMemoryAllReduce: MULTIMEM not supported for " + f"world_size={self.world_size}, SM={self.device_capability}. " + f"Falling back to standard allreduce strategies." + ) + return self.disabled = False - logger.info(f"SymmetricMemoryAllReduce initialized: " - f"world_size={self.world_size}, " - f"max_size={self.max_size}, " - f"SM={self.device_capability}, " - f"use_multimem={self.use_multimem}") + logger.info( + f"SymmetricMemoryAllReduce (MULTIMEM) initialized: " + f"world_size={self.world_size}, " + f"max_size={self.max_size}, " + f"SM={self.device_capability}" + ) except Exception as e: - logger.warning( - f"SymmetricMemoryAllReduce initialization failed: {e}") + logger.warning(f"SymmetricMemoryAllReduce initialization failed: {e}") return def should_use_symm_mem(self, inp: torch.Tensor) -> bool: @@ -197,25 +198,16 @@ def forward( out = torch.empty_like(inp) # Copy input to symmetric memory buffer - self.buffer[:inp.numel()].copy_(inp.view(-1)) - - # Perform allreduce using appropriate algorithm - if self.use_multimem: - # Use MULTIMEM hardware instructions (faster) - torch.ops.symm_mem.multimem_all_reduce_( - self.buffer[:inp.numel()], - "sum", - self.group.group_name, - ) - else: - # Use two-shot algorithm (fallback) - torch.ops.symm_mem.two_shot_all_reduce_( - self.buffer[:inp.numel()], - "sum", - self.group.group_name, - ) + self.buffer[: inp.numel()].copy_(inp.view(-1)) + + # Perform MULTIMEM allreduce + torch.ops.symm_mem.multimem_all_reduce_( + self.buffer[: inp.numel()], + "sum", + self.group.group_name, + ) # Copy result back - out.copy_(self.buffer[:inp.numel()].view(out.shape)) + out.copy_(self.buffer[: inp.numel()].view(out.shape)) return out diff --git a/tensorrt_llm/functional.py b/tensorrt_llm/functional.py index 8e05194d5a3..129cc33b183 100755 --- a/tensorrt_llm/functional.py +++ b/tensorrt_llm/functional.py @@ -3883,7 +3883,7 @@ class AllReduceStrategy(IntEnum): LOWPRECISION = 6 MNNVL = 7 NCCL_SYMMETRIC = 8 - SYMM_MEM = 9 # PyTorch symmetric memory with MULTIMEM (H100+) + SYMM_MEM = 9 # PyTorch symmetric memory with MULTIMEM class AllReduceFusionOp(IntEnum): From ef2924442fd7a59451dafa7f0ecc5a3639869b50 Mon Sep 17 
00:00:00 2001 From: Eran Geva <19514940+MrGeva@users.noreply.github.com> Date: Wed, 5 Nov 2025 01:18:17 -0800 Subject: [PATCH 3/4] removed sym mem from AUTO flow Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com> --- tensorrt_llm/_torch/distributed/ops.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 867fac9df63..dda67705ca4 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -577,8 +577,8 @@ def __init__(self, - MIN_LATENCY: AllReduce uses MIN_LATENCY mode kernel. - - AUTO: AUTO chooses the best available strategy. Will try SYMM_MEM first (if available), - then MNNVL, then choose between NCCL and MIN_LATENCY based on a heuristic policy. + - AUTO: AUTO chooses the best available strategy. Will try MNNVL, + then choose between NCCL and MIN_LATENCY based on a heuristic policy. - LOWPRECISION: AllReduce quantizes data to lower precision for transmission. Should only be used on topologies with PCIe switches and without NVLink. @@ -613,15 +613,14 @@ def __init__(self, self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce if self.mapping.tp_size > 1: - # When Strategy is UB, it is guaranteed that the workspace is not used. - if self.strategy != AllReduceStrategy.UB: + # When Strategy is UB or SYMM_MEM, it is guaranteed that the workspace is not used. + if self.strategy != AllReduceStrategy.UB and self.strategy != AllReduceStrategy.SYMM_MEM: if self.strategy == AllReduceStrategy.LOWPRECISION: allocate_low_presicion_allreduce_workspace(self.mapping) self.workspace = get_allreduce_workspace(self.mapping) # Initialize Symmetric Memory AllReduce if needed - if self.strategy in (AllReduceStrategy.AUTO, - AllReduceStrategy.SYMM_MEM): + if self.strategy == AllReduceStrategy.SYMM_MEM: try: symm_mem = SymmetricMemoryAllReduce( self.mapping, @@ -643,8 +642,8 @@ def __init__(self, self.symm_mem_allreduce = None # Initialize MNNVL AllReduce if needed - if self.strategy in (AllReduceStrategy.AUTO, - AllReduceStrategy.MNNVL): + elif self.strategy in (AllReduceStrategy.AUTO, + AllReduceStrategy.MNNVL): if MNNVLAllReduce.is_mnnvl(self.mapping, dtype): # ALWAYS capture the exception when creating this instance try: From 511139ff0750ae79b1a333ec174e710116dafd43 Mon Sep 17 00:00:00 2001 From: Eran Geva <19514940+MrGeva@users.noreply.github.com> Date: Wed, 5 Nov 2025 08:47:03 -0800 Subject: [PATCH 4/4] fixed process groups handling and fallback Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com> --- .../_torch/auto_deploy/distributed/trtllm.py | 4 +- tensorrt_llm/_torch/distributed/ops.py | 46 +++++++++++++------ .../_torch/distributed/symm_mem_allreduce.py | 43 +++++++++++++++-- 3 files changed, 71 insertions(+), 22 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py index 10b2bf4e251..f7c3343be22 100644 --- a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py +++ b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py @@ -26,9 +26,9 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None): cache_key = (rank, world_size, tensor.dtype) if cache_key not in _allreduce_cache: p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) - # Use Strategy.AUTO for optimal performance + # Use SYMM_MEM strategy (tries symmetric memory first, falls back to AUTO if needed) 
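+        # If symmetric memory cannot be initialized for this topology, AllReduce
+        # downgrades itself to the AUTO strategy internally (see the ops.py
+        # changes later in this patch).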
_allreduce_cache[cache_key] = AllReduce( - mapping=p_config, strategy=AllReduceStrategy.AUTO, dtype=tensor.dtype + mapping=p_config, strategy=AllReduceStrategy.SYMM_MEM, dtype=tensor.dtype ) torch_op = _allreduce_cache[cache_key] diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index dda67705ca4..4406b236f58 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -613,13 +613,7 @@ def __init__(self, self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce if self.mapping.tp_size > 1: - # When Strategy is UB or SYMM_MEM, it is guaranteed that the workspace is not used. - if self.strategy != AllReduceStrategy.UB and self.strategy != AllReduceStrategy.SYMM_MEM: - if self.strategy == AllReduceStrategy.LOWPRECISION: - allocate_low_presicion_allreduce_workspace(self.mapping) - self.workspace = get_allreduce_workspace(self.mapping) - - # Initialize Symmetric Memory AllReduce if needed + # Initialize Symmetric Memory AllReduce if needed (before workspace allocation) if self.strategy == AllReduceStrategy.SYMM_MEM: try: symm_mem = SymmetricMemoryAllReduce( @@ -629,21 +623,35 @@ def __init__(self, if not symm_mem.disabled: self.symm_mem_allreduce = symm_mem logger.info( - f"SymmetricMemoryAllReduce (MULTIMEM) is enabled for world_size={self.mapping.tp_size}" + f"SymmetricMemoryAllReduce (MULTIMEM) is enabled with fallback support for world_size={self.mapping.tp_size}" ) + # Keep SYMM_MEM strategy but allocate workspace for fallback to regular allreduce else: - logger.debug( - f"SymmetricMemoryAllReduce is disabled (not supported or unavailable)" + logger.info( + f"SymmetricMemoryAllReduce is disabled (not supported or unavailable), falling back to AUTO strategy" ) + # Fall back to AUTO if SYMM_MEM can't be enabled + self.strategy = AllReduceStrategy.AUTO except Exception as e: - logger.debug( - f"Symmetric Memory AllReduce can't be enabled due to {e}." + logger.info( + f"Symmetric Memory AllReduce can't be enabled due to {e}, falling back to AUTO strategy" ) self.symm_mem_allreduce = None + # Fall back to AUTO if SYMM_MEM initialization fails + self.strategy = AllReduceStrategy.AUTO - # Initialize MNNVL AllReduce if needed - elif self.strategy in (AllReduceStrategy.AUTO, - AllReduceStrategy.MNNVL): + # Allocate workspace for strategies that need it + # Note: SYMM_MEM now also needs workspace for fallback scenarios (fused ops, etc.) 
+ # Only UB doesn't need workspace + if self.strategy != AllReduceStrategy.UB: + if self.strategy == AllReduceStrategy.LOWPRECISION: + allocate_low_presicion_allreduce_workspace(self.mapping) + self.workspace = get_allreduce_workspace(self.mapping) + + # Initialize MNNVL if using AUTO or MNNVL strategy + if self.strategy in (AllReduceStrategy.AUTO, + AllReduceStrategy.MNNVL): + # Try to initialize MNNVL if MNNVLAllReduce.is_mnnvl(self.mapping, dtype): # ALWAYS capture the exception when creating this instance try: @@ -708,6 +716,13 @@ def forward( f"Using SymmetricMemoryAllReduce (MULTIMEM) for input shape {input.shape}" ) return symm_mem_output + elif self.symm_mem_allreduce and all_reduce_params.fusion_op != AllReduceFusionOp.NONE: + # Log once per rank that we're skipping symm_mem due to fusion + if not hasattr(self, '_logged_fusion_skip'): + logger.debug( + f"Skipping SymmetricMemoryAllReduce for fused operation (fusion_op={all_reduce_params.fusion_op}), using regular allreduce" + ) + self._logged_fusion_skip = True # Try MNNVL AllReduce if symm_mem didn't handle it if self.mnnvl_allreduce: @@ -724,6 +739,7 @@ def forward( additional_args = {} if self._disable_mpi: + # Get ProcessGroup from mapping pg = self.mapping.tp_group_pg assert pg is not None, "TP ProcessGroup not initialised" additional_args = { diff --git a/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py b/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py index 41fb28a0685..f02f20b786e 100644 --- a/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py +++ b/tensorrt_llm/_torch/distributed/symm_mem_allreduce.py @@ -110,8 +110,8 @@ def __init__( self.disabled = True return - # Get actual TP group ranks from mapping - tp_group_ranks = mapping.tp_group() + # Get actual TP group ranks from mapping (tp_group is a property, not a method) + tp_group_ranks = mapping.tp_group self.group = dist.new_group(tp_group_ranks) if len(tp_group_ranks) > 1 else None else: self.group = group @@ -121,6 +121,31 @@ def __init__( self.disabled = True return + # Enable symmetric memory for this group + try: + # Get group_name - this may fail if ProcessGroup doesn't have group_name set + if not hasattr(self.group, "group_name"): + logger.warning( + "SymmetricMemoryAllReduce: ProcessGroup does not have group_name attribute" + ) + self.disabled = True + return + + group_name_str = str(self.group.group_name) + torch_symm_mem.enable_symm_mem_for_group(group_name_str) + logger.debug( + f"SymmetricMemoryAllReduce: Enabled symmetric memory for group {group_name_str}" + ) + except Exception as e: + logger.warning( + f"SymmetricMemoryAllReduce: Failed to enable symmetric memory for group: {e}" + ) + import traceback + + logger.debug(traceback.format_exc()) + self.disabled = True + return + # Allocate symmetric memory buffer try: self.buffer = torch_symm_mem.empty( @@ -128,8 +153,9 @@ def __init__( device=device, dtype=self.dtype, ) - # Pass group_name (string) not the group object - handle = torch_symm_mem.rendezvous(self.buffer, self.group.group_name) + # Pass group name string + group_name_str = str(self.group.group_name) + handle = torch_symm_mem.rendezvous(self.buffer, group_name_str) if handle.multicast_ptr == 0: logger.warning( @@ -163,6 +189,11 @@ def __init__( logger.warning(f"SymmetricMemoryAllReduce initialization failed: {e}") return + @property + def process_group(self) -> Optional[dist.ProcessGroup]: + """Expose the ProcessGroup for use in fallback scenarios.""" + return self.group if not self.disabled else None + def 
should_use_symm_mem(self, inp: torch.Tensor) -> bool: """Check if symmetric memory can be used for this tensor.""" if self.disabled: @@ -201,10 +232,12 @@ def forward( self.buffer[: inp.numel()].copy_(inp.view(-1)) # Perform MULTIMEM allreduce + # Pass group name string (matching vLLM's implementation) + group_name_str = str(self.group.group_name) torch.ops.symm_mem.multimem_all_reduce_( self.buffer[: inp.numel()], "sum", - self.group.group_name, + group_name_str, ) # Copy result back
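Taken together, the series exposes SYMM_MEM as an ordinary strategy choice. A
minimal end-to-end sketch, mirroring the auto_deploy wiring in PATCH 4/4 (the
import paths are inferred from the files touched in this series, and the
snippet assumes one torch.distributed rank per GPU):

    import torch

    from tensorrt_llm._torch.distributed.ops import AllReduce  # path assumed from this series
    from tensorrt_llm.functional import AllReduceStrategy
    from tensorrt_llm.mapping import Mapping

    def build_symm_mem_allreduce(rank: int, world_size: int,
                                 dtype: torch.dtype = torch.bfloat16) -> AllReduce:
        # SYMM_MEM tries the MULTIMEM path first; unsupported topologies or
        # fused ops fall back to the regular allreduce kernels, as shown above.
        mapping = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
        return AllReduce(mapping=mapping,
                         strategy=AllReduceStrategy.SYMM_MEM,
                         dtype=dtype)

    # On each rank (plain, NONE-fusion allreduce):
    #   allreduce = build_symm_mem_allreduce(rank, world_size)
    #   reduced = allreduce(torch.randn(1024, dtype=torch.bfloat16, device="cuda"))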