-
Notifications
You must be signed in to change notification settings - Fork 1.8k
[#8921][feat] Added symetric memory AllReduce strategy #8919
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
MrGeva
wants to merge
4
commits into
NVIDIA:main
Choose a base branch
from
nv-auto-deploy:egeva/sym_mem
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+312
−9
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,246 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| """ | ||
| Symmetric Memory AllReduce | ||
|
|
||
| This module provides PyTorch Symmetric Memory-based allreduce operations, | ||
| leveraging MULTIMEM hardware instructions. | ||
| """ | ||
|
|
||
| from typing import Optional | ||
|
|
||
| import torch | ||
| import torch.distributed as dist | ||
| from torch import nn | ||
|
|
||
| from tensorrt_llm.logger import logger | ||
| from tensorrt_llm.mapping import Mapping | ||
|
|
||
| try: | ||
| import torch.distributed._symmetric_memory as torch_symm_mem | ||
|
|
||
| SYMM_MEM_AVAILABLE = True | ||
| except ImportError: | ||
| SYMM_MEM_AVAILABLE = False | ||
| logger.warning( | ||
| "PyTorch symmetric memory not available. Install PyTorch >= 2.8 for MULTIMEM support." | ||
| ) | ||
|
|
||
|
|
||
| class SymmetricMemoryAllReduce(nn.Module): | ||
| """ | ||
| AllReduce implementation using PyTorch's symmetric memory operations. | ||
| This leverages MULTIMEM hardware instructions for faster allreduce operations. | ||
|
|
||
| Supported configurations (world_size): | ||
| - SM 9.0: 4, 6, 8 GPUs | ||
| - SM 10.0: 6, 8 GPUs | ||
|
|
||
| """ | ||
|
|
||
| # World sizes that support MULTIMEM instructions | ||
| _WORLD_SIZES_MULTIMEM = { | ||
| "9.0": [4, 6, 8], | ||
| "10.0": [6, 8], | ||
| } | ||
|
|
||
| # Maximum buffer sizes for symmetric memory (bytes) | ||
| _MAX_SIZES = { | ||
| "9.0": { | ||
| 4: 8 * 1024 * 1024, # 8MB for 4 GPUs | ||
| 6: 6 * 1024 * 1024, # 6MB for 6 GPUs | ||
| 8: 4 * 1024 * 1024, # 4MB for 8 GPUs | ||
| }, | ||
| "10.0": { | ||
| 6: 8 * 1024 * 1024, | ||
| 8: 6 * 1024 * 1024, | ||
| }, | ||
| } | ||
|
|
||
| def __init__( | ||
| self, | ||
| mapping: Mapping, | ||
| dtype: torch.dtype = torch.bfloat16, | ||
| group: Optional[dist.ProcessGroup] = None, | ||
| ): | ||
| super().__init__() | ||
|
|
||
| self.disabled = True | ||
| self.mapping = mapping | ||
| self.dtype = dtype | ||
| self.world_size = mapping.tp_size | ||
|
|
||
| if not SYMM_MEM_AVAILABLE: | ||
| logger.warning("SymmetricMemoryAllReduce: PyTorch symm_mem not available") | ||
| return | ||
|
|
||
| if not torch.cuda.is_available(): | ||
| logger.warning("SymmetricMemoryAllReduce: CUDA not available") | ||
| return | ||
|
|
||
| # Get device capability | ||
| device = torch.device(f"cuda:{mapping.tp_rank}") | ||
| capability = torch.cuda.get_device_capability(device) | ||
| self.device_capability = f"{capability[0]}.{capability[1]}" | ||
|
|
||
| # Check if this configuration is supported | ||
| if self.device_capability not in self._MAX_SIZES: | ||
| logger.warning( | ||
| f"SymmetricMemoryAllReduce: Device capability {self.device_capability} not supported" | ||
| ) | ||
| return | ||
|
|
||
| if self.world_size not in self._MAX_SIZES[self.device_capability]: | ||
| logger.info( | ||
| f"SymmetricMemoryAllReduce: World size {self.world_size} not supported " | ||
| f"for SM {self.device_capability}" | ||
| ) | ||
| return | ||
|
|
||
| # Get max buffer size for this configuration | ||
| self.max_size = self._MAX_SIZES[self.device_capability][self.world_size] | ||
|
|
||
| # Set up process group | ||
| if group is None: | ||
| # Get or create TP group with correct ranks | ||
| # For TP parallelism, we need ranks [0, 1, 2, ..., tp_size-1] globally | ||
| # NOT starting from tp_rank! | ||
| if not dist.is_initialized(): | ||
| logger.warning("SymmetricMemoryAllReduce: torch.distributed not initialized") | ||
| self.disabled = True | ||
| return | ||
|
|
||
| # Get actual TP group ranks from mapping (tp_group is a property, not a method) | ||
| tp_group_ranks = mapping.tp_group | ||
| self.group = dist.new_group(tp_group_ranks) if len(tp_group_ranks) > 1 else None | ||
| else: | ||
| self.group = group | ||
|
|
||
| if self.group is None: | ||
| logger.warning("SymmetricMemoryAllReduce: No valid process group") | ||
| self.disabled = True | ||
| return | ||
|
|
||
| # Enable symmetric memory for this group | ||
| try: | ||
| # Get group_name - this may fail if ProcessGroup doesn't have group_name set | ||
| if not hasattr(self.group, "group_name"): | ||
| logger.warning( | ||
| "SymmetricMemoryAllReduce: ProcessGroup does not have group_name attribute" | ||
| ) | ||
| self.disabled = True | ||
| return | ||
|
|
||
| group_name_str = str(self.group.group_name) | ||
| torch_symm_mem.enable_symm_mem_for_group(group_name_str) | ||
| logger.debug( | ||
| f"SymmetricMemoryAllReduce: Enabled symmetric memory for group {group_name_str}" | ||
| ) | ||
| except Exception as e: | ||
| logger.warning( | ||
| f"SymmetricMemoryAllReduce: Failed to enable symmetric memory for group: {e}" | ||
| ) | ||
| import traceback | ||
|
|
||
| logger.debug(traceback.format_exc()) | ||
| self.disabled = True | ||
| return | ||
|
|
||
| # Allocate symmetric memory buffer | ||
| try: | ||
| self.buffer = torch_symm_mem.empty( | ||
| self.max_size // self.dtype.itemsize, | ||
| device=device, | ||
| dtype=self.dtype, | ||
| ) | ||
| # Pass group name string | ||
| group_name_str = str(self.group.group_name) | ||
| handle = torch_symm_mem.rendezvous(self.buffer, group_name_str) | ||
|
|
||
| if handle.multicast_ptr == 0: | ||
| logger.warning( | ||
| "SymmetricMemoryAllReduce: MULTIMEM operations not supported (multicast_ptr is 0)" | ||
| ) | ||
| return | ||
|
|
||
| # Only enable if MULTIMEM is supported | ||
| # Otherwise, no benefit over existing TensorRT-LLM strategies | ||
| use_multimem = self.world_size in self._WORLD_SIZES_MULTIMEM.get( | ||
| self.device_capability, [] | ||
| ) | ||
|
|
||
| if not use_multimem: | ||
| logger.info( | ||
| f"SymmetricMemoryAllReduce: MULTIMEM not supported for " | ||
| f"world_size={self.world_size}, SM={self.device_capability}. " | ||
| f"Falling back to standard allreduce strategies." | ||
| ) | ||
| return | ||
|
|
||
| self.disabled = False | ||
| logger.info( | ||
| f"SymmetricMemoryAllReduce (MULTIMEM) initialized: " | ||
| f"world_size={self.world_size}, " | ||
| f"max_size={self.max_size}, " | ||
| f"SM={self.device_capability}" | ||
| ) | ||
|
|
||
| except Exception as e: | ||
| logger.warning(f"SymmetricMemoryAllReduce initialization failed: {e}") | ||
| return | ||
|
|
||
| @property | ||
| def process_group(self) -> Optional[dist.ProcessGroup]: | ||
| """Expose the ProcessGroup for use in fallback scenarios.""" | ||
| return self.group if not self.disabled else None | ||
|
|
||
| def should_use_symm_mem(self, inp: torch.Tensor) -> bool: | ||
| """Check if symmetric memory can be used for this tensor.""" | ||
| if self.disabled: | ||
| return False | ||
| if inp.dtype != self.dtype: | ||
| return False | ||
| inp_size = inp.numel() * inp.element_size() | ||
| if inp_size % 4 != 0: | ||
| return False | ||
| if inp_size >= self.max_size: | ||
| return False | ||
| return True | ||
|
|
||
| def forward( | ||
| self, | ||
| inp: torch.Tensor, | ||
| out: Optional[torch.Tensor] = None, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Perform allreduce using symmetric memory operations. | ||
|
|
||
| Args: | ||
| inp: Input tensor to reduce | ||
| out: Optional output tensor (if None, will be allocated) | ||
|
|
||
| Returns: | ||
| Reduced tensor | ||
| """ | ||
| if not self.should_use_symm_mem(inp): | ||
| return None # Caller should fall back to other strategy | ||
|
|
||
| if out is None: | ||
| out = torch.empty_like(inp) | ||
|
|
||
| # Copy input to symmetric memory buffer | ||
| self.buffer[: inp.numel()].copy_(inp.view(-1)) | ||
|
|
||
| # Perform MULTIMEM allreduce | ||
| # Pass group name string (matching vLLM's implementation) | ||
| group_name_str = str(self.group.group_name) | ||
| torch.ops.symm_mem.multimem_all_reduce_( | ||
| self.buffer[: inp.numel()], | ||
| "sum", | ||
| group_name_str, | ||
| ) | ||
|
|
||
| # Copy result back | ||
| out.copy_(self.buffer[: inp.numel()].view(out.shape)) | ||
|
|
||
| return out |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@MrGeva : we should make this configurable and users should be able to specify the choice through inference optimizer config. Would you be able to take on the task of plumbing the AllReduceStrategy from the top level default.yaml -> sharding transformation -> inserting the all reduce op with the specified strategy?