From d94716751770e5cf3e9fc4ece068cb55f8e77f05 Mon Sep 17 00:00:00 2001 From: pranavm Date: Tue, 29 Jul 2025 11:12:03 -0700 Subject: [PATCH 1/4] Removes MPI dependency from MNNVL AllReduce implementation --- flashinfer/comm/mnnvl.py | 44 +++++++++++++----------- flashinfer/comm/trtllm_mnnvl_ar.py | 7 ++-- flashinfer/jit/env.py | 22 ++++++------ tests/test_trtllm_mnnvl_allreduce.py | 50 +++++++++++++++++++--------- 4 files changed, 73 insertions(+), 50 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 75a060662..c8d975a2d 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -17,13 +17,12 @@ import logging import platform import sys -from dataclasses import dataclass from typing import List, Optional import pynvml import torch +import torch.distributed as dist from cuda import cuda -from mpi4py import MPI from ..cuda_utils import checkCudaErrors from .dlpack_utils import create_dlpack_capsule, pack_strided_memory @@ -130,17 +129,6 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]: return device_ptr -class MpiComm: - _comm: MPI.Intracomm = MPI.COMM_WORLD - - @classmethod - def set_mpi_comm(cls, new_comm: MPI.Intracomm): - cls._comm = new_comm - - def __getattr__(self, name): - return getattr(self._comm, name) - - class MnnvlMemory: initialized: bool = False @@ -198,6 +186,19 @@ def initialize(): def get_comm(mapping: Mapping): if MnnvlMemory.comm is not None: return MnnvlMemory.comm + + from mpi4py import MPI + + class MpiComm: + _comm: MPI.Intracomm = MPI.COMM_WORLD + + @classmethod + def set_mpi_comm(cls, new_comm: MPI.Intracomm): + cls._comm = new_comm + + def __getattr__(self, name): + return getattr(self._comm, name) + comm = MpiComm().Split( mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank ) @@ -628,9 +629,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int): except Exception as e: print(f"Error checking CUDA context: {e}") - # Get MPI communicator - comm = MpiComm() - # Set up allocation properties handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC @@ -692,7 +690,12 @@ def _alloc_mn_mcast_mem(self, buf_size: int): ) # All-gather fabric handles - all_fabric_handles = comm.allgather(my_fabric_handle.data) + if not dist.is_initialized(): + raise RuntimeError("torch.distributed must be initialized before use.") + + # Use all_gather_object to collect fabric handles from all ranks + all_fabric_handles = [None for _ in range(self.group_size)] + dist.all_gather_object(all_fabric_handles, my_fabric_handle.data) cuda.cuCtxSynchronize() # Import remote handles @@ -722,9 +725,10 @@ def _alloc_mn_mcast_mem(self, buf_size: int): mc_fabric_handle = None # Broadcast multicast handle - mc_fabric_handle_data = comm.bcast( - mc_fabric_handle.data if mc_fabric_handle else None, root=0 - ) + mc_fabric_handle_list = [mc_fabric_handle.data] if mc_fabric_handle else [None] + dist.broadcast_object_list(mc_fabric_handle_list, src=0) + mc_fabric_handle_data = mc_fabric_handle_list[0] + # Sync device to ensure broadcast is complete cuda.cuCtxSynchronize() # Import multicast handle for non-root ranks diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index d6dd0506e..ca300a2c5 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -10,7 +10,7 @@ from typing import List, Optional, Tuple import torch -from mpi4py import MPI +import torch.distributed as dist from flashinfer.comm.mapping import Mapping @@ -21,9 +21,6 @@ from .mnnvl 
import McastGPUBuffer -def mpi_barrier(): - """MPI barrier - could potentially be replaced with dist.barrier()""" - MPI.COMM_WORLD.Barrier() def gen_trtllm_mnnvl_comm_module() -> JitSpec: @@ -181,7 +178,7 @@ def get_allreduce_mnnvl_workspace( # CPU barrier since we assume this should not be called in cuda graph torch.cuda.synchronize() - mpi_barrier() + dist.barrier() # This is a buffer to maintain the state of this allreduce Op # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter] diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py index f8f3db553..b5aea8081 100644 --- a/flashinfer/jit/env.py +++ b/flashinfer/jit/env.py @@ -51,18 +51,20 @@ def _get_workspace_dir_name() -> pathlib.Path: FLASHINFER_WORKSPACE_DIR = _get_workspace_dir_name() FLASHINFER_JIT_DIR = FLASHINFER_WORKSPACE_DIR / "cached_ops" FLASHINFER_GEN_SRC_DIR = FLASHINFER_WORKSPACE_DIR / "generated" -_package_root = pathlib.Path(__file__).resolve().parents[1] -FLASHINFER_DATA = _package_root / "data" -FLASHINFER_INCLUDE_DIR = _package_root / "data" / "include" -FLASHINFER_CSRC_DIR = _package_root / "data" / "csrc" -# FLASHINFER_SRC_DIR = _package_root / "data" / "src" -FLASHINFER_TVM_BINDING_DIR = _package_root / "data" / "tvm_binding" -FLASHINFER_AOT_DIR = _package_root / "data" / "aot" +# TODO (pranavm): Check if this is right? +# Why were these pointing to non-existent directories? Must be a copy missing somewhere? +_package_root = pathlib.Path(__file__).resolve().parents[2] +FLASHINFER_DATA = _package_root +FLASHINFER_INCLUDE_DIR = _package_root / "include" +FLASHINFER_CSRC_DIR = _package_root / "csrc" +# FLASHINFER_SRC_DIR = _package_root / "src" +FLASHINFER_TVM_BINDING_DIR = _package_root / "tvm_binding" +FLASHINFER_AOT_DIR = _package_root / "aot" CUTLASS_INCLUDE_DIRS = [ - _package_root / "data" / "cutlass" / "include", - _package_root / "data" / "cutlass" / "tools" / "util" / "include", + _package_root / "3rdparty" / "cutlass" / "include", + _package_root / "3rdparty" / "cutlass" / "tools" / "util" / "include", ] -SPDLOG_INCLUDE_DIR = _package_root / "data" / "spdlog" / "include" +SPDLOG_INCLUDE_DIR = _package_root / "3rdparty" / "spdlog" / "include" def get_nvshmem_include_dirs(): diff --git a/tests/test_trtllm_mnnvl_allreduce.py b/tests/test_trtllm_mnnvl_allreduce.py index 860da8535..a31a9bb3e 100644 --- a/tests/test_trtllm_mnnvl_allreduce.py +++ b/tests/test_trtllm_mnnvl_allreduce.py @@ -1,16 +1,13 @@ # Check torch version: import os import sys -import traceback import pytest import torch -from mpi4py import MPI # Added MPI import +import torch.distributed as dist -import flashinfer.comm as comm import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar from flashinfer.comm.mapping import Mapping -from flashinfer.comm.mnnvl import McastDeviceMemory, McastGPUBuffer # Use flashinfer.norm.rmsnorm as reference implementation. 
from flashinfer.norm import rmsnorm @@ -42,7 +39,7 @@ def row_linear_residual_norm_fusion_forward( tensor_parallel_size = mapping.tp_size tensor_parallel_rank = mapping.tp_rank - MPI.COMM_WORLD.barrier() + dist.barrier() def func( input, @@ -70,7 +67,7 @@ def func( prenorm_output = torch.empty_like(residual) normed_output = torch.empty_like(residual) - trtllm_mnnvl_ar.mpi_barrier() + dist.barrier() trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm( prenorm_output, @@ -152,7 +149,7 @@ def func( ) -"""Main test function that runs on each MPI rank""" +"""Main test function that runs on each distributed rank""" @pytest.mark.parametrize( @@ -173,9 +170,29 @@ def test_mnnvl_allreduce_full( ): monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce. - # Get MPI info - rank = MPI.COMM_WORLD.Get_rank() - world_size = MPI.COMM_WORLD.Get_size() + # Hack to get an address that all nodes in the slurm job can reach + master_addr = os.environ.get("MASTER_ADDR", None) + if master_addr is None: + # Use SLURM's first node as master + master_addr = os.environ.get("SLURM_NODELIST", "localhost").split(",")[0] + # Remove brackets if present (e.g., "node[01-02]" -> "node01") + if "[" in master_addr: + import re + + master_addr = re.sub(r"\[.*\]", "0", master_addr) + master_port = os.environ.get("MASTER_PORT", "12345") + + if not dist.is_initialized(): + dist.init_process_group( + backend="gloo", + init_method=f"tcp://{master_addr}:{master_port}", + world_size=2, + rank=int(os.environ.get("SLURM_PROCID")), + ) + + # Get distributed info + rank = dist.get_rank() + world_size = dist.get_world_size() gpus_per_node = torch.cuda.device_count() if gpus_per_node == 0: @@ -184,7 +201,9 @@ def test_mnnvl_allreduce_full( # Ensure we have exactly 2 ranks for this test if world_size < 2: if rank == 0: - print(f"ERROR: This test requires at least 2 MPI ranks, got {world_size}") + print( + f"ERROR: This test requires at least 2 distributed ranks, got {world_size}" + ) sys.exit(1) mapping = Mapping( @@ -287,7 +306,7 @@ def test_mnnvl_allreduce_full( ) # Synchronize before next test - trtllm_mnnvl_ar.mpi_barrier() + dist.barrier() print( f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}" @@ -298,7 +317,8 @@ def test_mnnvl_allreduce_full( failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}" print(failure_message) # Gather failure status from all ranks - all_failures = MPI.COMM_WORLD.allgather(rank_failed) + all_failures = [None for _ in range(world_size)] + dist.all_gather_object(all_failures, rank_failed) # If any rank failed, fail the test if any(all_failures): @@ -308,7 +328,7 @@ def test_mnnvl_allreduce_full( # Fail the test on all ranks pytest.fail(f"Test failed on ranks {failed_ranks}") - trtllm_mnnvl_ar.mpi_barrier() + dist.barrier() finally: # Ensure cleanup happens for this list's workspace @@ -316,4 +336,4 @@ def test_mnnvl_allreduce_full( del mcast_buffer_mnnvl # Final synchronization and check for failures across all ranks - trtllm_mnnvl_ar.mpi_barrier() + dist.barrier() From 694af3a63ef565b9cb4704758330e45d101d40f3 Mon Sep 17 00:00:00 2001 From: pranavm Date: Tue, 29 Jul 2025 11:16:07 -0700 Subject: [PATCH 2/4] Adds groups argument to use custom process groups --- flashinfer/comm/mnnvl.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index c8d975a2d..0dc7cfe0a 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ 
-420,6 +420,7 @@ def __init__( group_rank: int, device_idx: int, is_multi_node: bool = True, + group: dist.ProcessGroup = None, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) @@ -437,6 +438,7 @@ def __init__( self.device_idx = device_idx self.group_size = group_size self.group_rank = group_rank + self.group = group self.buf_size = buf_size self.signal_pad_offset = 0 self.allocation_size = 0 @@ -695,7 +697,9 @@ def _alloc_mn_mcast_mem(self, buf_size: int): # Use all_gather_object to collect fabric handles from all ranks all_fabric_handles = [None for _ in range(self.group_size)] - dist.all_gather_object(all_fabric_handles, my_fabric_handle.data) + dist.all_gather_object( + all_fabric_handles, my_fabric_handle.data, group=self.group + ) cuda.cuCtxSynchronize() # Import remote handles @@ -726,7 +730,12 @@ def _alloc_mn_mcast_mem(self, buf_size: int): # Broadcast multicast handle mc_fabric_handle_list = [mc_fabric_handle.data] if mc_fabric_handle else [None] - dist.broadcast_object_list(mc_fabric_handle_list, src=0) + if self.group: + dist.broadcast_object_list( + mc_fabric_handle_list, group_src=0, group=self.group + ) + else: + dist.broadcast_object_list(mc_fabric_handle_list, src=0) mc_fabric_handle_data = mc_fabric_handle_list[0] # Sync device to ensure broadcast is complete From 5e5410e87ec8819c3ca69a5a25c00d970c305372 Mon Sep 17 00:00:00 2001 From: Pranav Marathe Date: Wed, 30 Jul 2025 15:58:52 +0000 Subject: [PATCH 3/4] Pipes changes to pass group up through API --- flashinfer/comm/mnnvl.py | 3 ++- flashinfer/comm/trtllm_mnnvl_ar.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py index 0dc7cfe0a..286ece22f 100644 --- a/flashinfer/comm/mnnvl.py +++ b/flashinfer/comm/mnnvl.py @@ -843,6 +843,7 @@ def __init__( group_rank: int, device: torch.device, mn_nvlink: bool = True, + group: dist.ProcessGroup = None, ): """ Constructor for McastGpuBuffer. @@ -855,7 +856,7 @@ def __init__( mn_nvlink: Flag indicating if multi-node NVLink is used """ self.mcast_device_memory = McastDeviceMemory( - buf_size, group_size, group_rank, device.index, mn_nvlink + buf_size, group_size, group_rank, device.index, mn_nvlink, group ) self.buf_size = buf_size self.local_device = device diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index ca300a2c5..5cf260e77 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -129,7 +129,7 @@ def trtllm_mnnvl_rmsnorm( def get_allreduce_mnnvl_workspace( - mapping: Mapping, dtype: torch.dtype + mapping: Mapping, dtype: torch.dtype, group: Optional[dist.ProcessGroup] = None ) -> Tuple[McastGPUBuffer, torch.Tensor, int]: """Get workspace buffers needed for multi-node NVLink all-reduce operation. 
@@ -171,6 +171,7 @@ def get_allreduce_mnnvl_workspace( mapping.tp_rank, torch.device("cuda", mapping.local_rank), mapping.is_multi_node() or force_mn, + group=group, ) # Initialize the unicast buffer with -0.0 From 481e9bcdac1467d9a08ab7eab8ec648d16039f8a Mon Sep 17 00:00:00 2001 From: pranavm Date: Mon, 4 Aug 2025 11:11:55 -0700 Subject: [PATCH 4/4] Applies pre-commit hook --- flashinfer/comm/trtllm_mnnvl_ar.py | 2 -- flashinfer/jit/env.py | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py index 5cf260e77..8490b1b6b 100644 --- a/flashinfer/comm/trtllm_mnnvl_ar.py +++ b/flashinfer/comm/trtllm_mnnvl_ar.py @@ -21,8 +21,6 @@ from .mnnvl import McastGPUBuffer - - def gen_trtllm_mnnvl_comm_module() -> JitSpec: return gen_jit_spec( "trtllm_mnnvl_comm", diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py index b5aea8081..ce80e99a3 100644 --- a/flashinfer/jit/env.py +++ b/flashinfer/jit/env.py @@ -55,11 +55,11 @@ def _get_workspace_dir_name() -> pathlib.Path: # Why were these pointing to non-existent directories? Must be a copy missing somewhere? _package_root = pathlib.Path(__file__).resolve().parents[2] FLASHINFER_DATA = _package_root -FLASHINFER_INCLUDE_DIR = _package_root / "include" -FLASHINFER_CSRC_DIR = _package_root / "csrc" +FLASHINFER_INCLUDE_DIR = _package_root / "include" +FLASHINFER_CSRC_DIR = _package_root / "csrc" # FLASHINFER_SRC_DIR = _package_root / "src" -FLASHINFER_TVM_BINDING_DIR = _package_root / "tvm_binding" -FLASHINFER_AOT_DIR = _package_root / "aot" +FLASHINFER_TVM_BINDING_DIR = _package_root / "tvm_binding" +FLASHINFER_AOT_DIR = _package_root / "aot" CUTLASS_INCLUDE_DIRS = [ _package_root / "3rdparty" / "cutlass" / "include", _package_root / "3rdparty" / "cutlass" / "tools" / "util" / "include",
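
Usage sketch (not part of the patches): with the MPI dependency removed, callers are expected to initialize torch.distributed themselves before requesting the MNNVL all-reduce workspace, and may optionally thread a custom process group through the API added in PATCH 2/4 and 3/4. The snippet below is a minimal, illustrative sketch of that flow; the backend choice, dtype, and the exact Mapping constructor arguments are assumptions for the example rather than anything this series mandates.

# Illustrative sketch only (not part of this series): exercises the reworked
# torch.distributed-based workspace API from patches 1-3.
import torch
import torch.distributed as dist

import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
from flashinfer.comm.mapping import Mapping


def main() -> None:
    # The series assumes the caller initializes torch.distributed up front;
    # env:// expects MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE from the launcher.
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo", init_method="env://")

    rank = dist.get_rank()
    world_size = dist.get_world_size()
    gpus_per_node = torch.cuda.device_count()
    torch.cuda.set_device(rank % gpus_per_node)

    # The exact Mapping arguments here are assumptions for the example.
    mapping = Mapping(
        world_size=world_size,
        rank=rank,
        gpus_per_node=gpus_per_node,
        tp_size=world_size,
    )

    # Optionally pass a custom ProcessGroup (group=None falls back to the
    # default group, matching the pre-series behavior).
    mcast_buffer, buffer_flags, max_num_elements = (
        trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(
            mapping, torch.bfloat16, group=None
        )
    )

    # ... launch the fused all-reduce / rmsnorm kernels here ...

    dist.barrier()
    del mcast_buffer
    dist.destroy_process_group()


if __name__ == "__main__":
    main()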