diff --git a/flashinfer/comm/mnnvl.py b/flashinfer/comm/mnnvl.py
index 75a060662..286ece22f 100644
--- a/flashinfer/comm/mnnvl.py
+++ b/flashinfer/comm/mnnvl.py
@@ -17,13 +17,12 @@
 import logging
 import platform
 import sys
-from dataclasses import dataclass
 from typing import List, Optional
 
 import pynvml
 import torch
+import torch.distributed as dist
 from cuda import cuda
-from mpi4py import MPI
 
 from ..cuda_utils import checkCudaErrors
 from .dlpack_utils import create_dlpack_capsule, pack_strided_memory
@@ -130,17 +129,6 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:
     return device_ptr
 
 
-class MpiComm:
-    _comm: MPI.Intracomm = MPI.COMM_WORLD
-
-    @classmethod
-    def set_mpi_comm(cls, new_comm: MPI.Intracomm):
-        cls._comm = new_comm
-
-    def __getattr__(self, name):
-        return getattr(self._comm, name)
-
-
 class MnnvlMemory:
     initialized: bool = False
 
@@ -198,6 +186,19 @@ def initialize():
     def get_comm(mapping: Mapping):
         if MnnvlMemory.comm is not None:
             return MnnvlMemory.comm
+
+        from mpi4py import MPI
+
+        class MpiComm:
+            _comm: MPI.Intracomm = MPI.COMM_WORLD
+
+            @classmethod
+            def set_mpi_comm(cls, new_comm: MPI.Intracomm):
+                cls._comm = new_comm
+
+            def __getattr__(self, name):
+                return getattr(self._comm, name)
+
         comm = MpiComm().Split(
             mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
         )
@@ -419,6 +420,7 @@ def __init__(
         group_rank: int,
         device_idx: int,
         is_multi_node: bool = True,
+        group: dist.ProcessGroup = None,
     ):
 
         cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx))
@@ -436,6 +438,7 @@ def __init__(
         self.device_idx = device_idx
         self.group_size = group_size
         self.group_rank = group_rank
+        self.group = group
         self.buf_size = buf_size
         self.signal_pad_offset = 0
         self.allocation_size = 0
@@ -628,9 +631,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
         except Exception as e:
             print(f"Error checking CUDA context: {e}")
 
-        # Get MPI communicator
-        comm = MpiComm()
-
         # Set up allocation properties
         handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
 
@@ -692,7 +692,14 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
         )
 
         # All-gather fabric handles
-        all_fabric_handles = comm.allgather(my_fabric_handle.data)
+        if not dist.is_initialized():
+            raise RuntimeError("torch.distributed must be initialized before use.")
+
+        # Use all_gather_object to collect fabric handles from all ranks
+        all_fabric_handles = [None for _ in range(self.group_size)]
+        dist.all_gather_object(
+            all_fabric_handles, my_fabric_handle.data, group=self.group
+        )
         cuda.cuCtxSynchronize()
 
         # Import remote handles
@@ -722,9 +729,15 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
             mc_fabric_handle = None
 
         # Broadcast multicast handle
-        mc_fabric_handle_data = comm.bcast(
-            mc_fabric_handle.data if mc_fabric_handle else None, root=0
-        )
+        mc_fabric_handle_list = [mc_fabric_handle.data] if mc_fabric_handle else [None]
+        if self.group:
+            dist.broadcast_object_list(
+                mc_fabric_handle_list, group_src=0, group=self.group
+            )
+        else:
+            dist.broadcast_object_list(mc_fabric_handle_list, src=0)
+        mc_fabric_handle_data = mc_fabric_handle_list[0]
+        # Sync device to ensure broadcast is complete
         cuda.cuCtxSynchronize()
 
         # Import multicast handle for non-root ranks
@@ -830,6 +843,7 @@ def __init__(
         group_rank: int,
         device: torch.device,
         mn_nvlink: bool = True,
+        group: dist.ProcessGroup = None,
     ):
         """
         Constructor for McastGpuBuffer.
@@ -842,7 +856,7 @@ def __init__(
             mn_nvlink: Flag indicating if multi-node NVLink is used
         """
         self.mcast_device_memory = McastDeviceMemory(
-            buf_size, group_size, group_rank, device.index, mn_nvlink
+            buf_size, group_size, group_rank, device.index, mn_nvlink, group
         )
         self.buf_size = buf_size
         self.local_device = device
diff --git a/flashinfer/comm/trtllm_mnnvl_ar.py b/flashinfer/comm/trtllm_mnnvl_ar.py
index d6dd0506e..8490b1b6b 100644
--- a/flashinfer/comm/trtllm_mnnvl_ar.py
+++ b/flashinfer/comm/trtllm_mnnvl_ar.py
@@ -10,7 +10,7 @@
 from typing import List, Optional, Tuple
 
 import torch
-from mpi4py import MPI
+import torch.distributed as dist
 
 from flashinfer.comm.mapping import Mapping
 
@@ -21,11 +21,6 @@
 from .mnnvl import McastGPUBuffer
 
 
-def mpi_barrier():
-    """MPI barrier - could potentially be replaced with dist.barrier()"""
-    MPI.COMM_WORLD.Barrier()
-
-
 def gen_trtllm_mnnvl_comm_module() -> JitSpec:
     return gen_jit_spec(
         "trtllm_mnnvl_comm",
@@ -132,7 +127,7 @@ def trtllm_mnnvl_rmsnorm(
 
 
 def get_allreduce_mnnvl_workspace(
-    mapping: Mapping, dtype: torch.dtype
+    mapping: Mapping, dtype: torch.dtype, group: Optional[dist.ProcessGroup] = None
 ) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
     """Get workspace buffers needed for multi-node NVLink all-reduce operation.
 
@@ -174,6 +169,7 @@ def get_allreduce_mnnvl_workspace(
         mapping.tp_rank,
         torch.device("cuda", mapping.local_rank),
         mapping.is_multi_node() or force_mn,
+        group=group,
     )
 
     # Initialize the unicast buffer with -0.0
@@ -181,7 +177,7 @@
 
     # CPU barrier since we assume this should not be called in cuda graph
     torch.cuda.synchronize()
-    mpi_barrier()
+    dist.barrier()
 
     # This is a buffer to maintain the state of this allreduce Op
     # [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter]
diff --git a/flashinfer/jit/env.py b/flashinfer/jit/env.py
index f8f3db553..ce80e99a3 100644
--- a/flashinfer/jit/env.py
+++ b/flashinfer/jit/env.py
@@ -51,18 +51,20 @@ def _get_workspace_dir_name() -> pathlib.Path:
 FLASHINFER_WORKSPACE_DIR = _get_workspace_dir_name()
 FLASHINFER_JIT_DIR = FLASHINFER_WORKSPACE_DIR / "cached_ops"
 FLASHINFER_GEN_SRC_DIR = FLASHINFER_WORKSPACE_DIR / "generated"
-_package_root = pathlib.Path(__file__).resolve().parents[1]
-FLASHINFER_DATA = _package_root / "data"
-FLASHINFER_INCLUDE_DIR = _package_root / "data" / "include"
-FLASHINFER_CSRC_DIR = _package_root / "data" / "csrc"
-# FLASHINFER_SRC_DIR = _package_root / "data" / "src"
-FLASHINFER_TVM_BINDING_DIR = _package_root / "data" / "tvm_binding"
-FLASHINFER_AOT_DIR = _package_root / "data" / "aot"
+# TODO (pranavm): Check if this is right?
+# Why were these pointing to non-existent directories? Must be a copy missing somewhere?
+_package_root = pathlib.Path(__file__).resolve().parents[2]
+FLASHINFER_DATA = _package_root
+FLASHINFER_INCLUDE_DIR = _package_root / "include"
+FLASHINFER_CSRC_DIR = _package_root / "csrc"
+# FLASHINFER_SRC_DIR = _package_root / "src"
+FLASHINFER_TVM_BINDING_DIR = _package_root / "tvm_binding"
+FLASHINFER_AOT_DIR = _package_root / "aot"
 CUTLASS_INCLUDE_DIRS = [
-    _package_root / "data" / "cutlass" / "include",
-    _package_root / "data" / "cutlass" / "tools" / "util" / "include",
+    _package_root / "3rdparty" / "cutlass" / "include",
+    _package_root / "3rdparty" / "cutlass" / "tools" / "util" / "include",
 ]
-SPDLOG_INCLUDE_DIR = _package_root / "data" / "spdlog" / "include"
+SPDLOG_INCLUDE_DIR = _package_root / "3rdparty" / "spdlog" / "include"
 
 
 def get_nvshmem_include_dirs():
diff --git a/tests/test_trtllm_mnnvl_allreduce.py b/tests/test_trtllm_mnnvl_allreduce.py
index 860da8535..a31a9bb3e 100644
--- a/tests/test_trtllm_mnnvl_allreduce.py
+++ b/tests/test_trtllm_mnnvl_allreduce.py
@@ -1,16 +1,13 @@
 # Check torch version:
 import os
 import sys
-import traceback
 
 import pytest
 import torch
-from mpi4py import MPI  # Added MPI import
+import torch.distributed as dist
 
-import flashinfer.comm as comm
 import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
 from flashinfer.comm.mapping import Mapping
-from flashinfer.comm.mnnvl import McastDeviceMemory, McastGPUBuffer
 
 # Use flashinfer.norm.rmsnorm as reference implementation.
 from flashinfer.norm import rmsnorm
@@ -42,7 +39,7 @@ def row_linear_residual_norm_fusion_forward(
     tensor_parallel_size = mapping.tp_size
     tensor_parallel_rank = mapping.tp_rank
 
-    MPI.COMM_WORLD.barrier()
+    dist.barrier()
 
     def func(
         input,
@@ -70,7 +67,7 @@ def func(
         prenorm_output = torch.empty_like(residual)
         normed_output = torch.empty_like(residual)
 
-        trtllm_mnnvl_ar.mpi_barrier()
+        dist.barrier()
 
         trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm(
             prenorm_output,
@@ -152,7 +149,7 @@ def func(
     )
 
 
-"""Main test function that runs on each MPI rank"""
+"""Main test function that runs on each distributed rank"""
 
 
 @pytest.mark.parametrize(
@@ -173,9 +170,29 @@ def test_mnnvl_allreduce_full(
 ):
     monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1")  # force multi-node allreduce.
 
-    # Get MPI info
-    rank = MPI.COMM_WORLD.Get_rank()
-    world_size = MPI.COMM_WORLD.Get_size()
+    # Hack to get an address that all nodes in the slurm job can reach
+    master_addr = os.environ.get("MASTER_ADDR", None)
+    if master_addr is None:
+        # Use SLURM's first node as master
+        master_addr = os.environ.get("SLURM_NODELIST", "localhost").split(",")[0]
+        # Remove brackets if present (e.g., "node[01-02]" -> "node01")
+        if "[" in master_addr:
+            import re
+
+            master_addr = re.sub(r"\[.*\]", "0", master_addr)
+    master_port = os.environ.get("MASTER_PORT", "12345")
+
+    if not dist.is_initialized():
+        dist.init_process_group(
+            backend="gloo",
+            init_method=f"tcp://{master_addr}:{master_port}",
+            world_size=2,
+            rank=int(os.environ.get("SLURM_PROCID")),
+        )
+
+    # Get distributed info
+    rank = dist.get_rank()
+    world_size = dist.get_world_size()
 
     gpus_per_node = torch.cuda.device_count()
     if gpus_per_node == 0:
@@ -184,7 +201,9 @@ def test_mnnvl_allreduce_full(
     # Ensure we have exactly 2 ranks for this test
     if world_size < 2:
         if rank == 0:
-            print(f"ERROR: This test requires at least 2 MPI ranks, got {world_size}")
+            print(
+                f"ERROR: This test requires at least 2 distributed ranks, got {world_size}"
+            )
         sys.exit(1)
 
     mapping = Mapping(
@@ -287,7 +306,7 @@ def test_mnnvl_allreduce_full(
             )
 
             # Synchronize before next test
-            trtllm_mnnvl_ar.mpi_barrier()
+            dist.barrier()
 
             print(
                 f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}"
@@ -298,7 +317,8 @@ def test_mnnvl_allreduce_full(
        failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}"
        print(failure_message)
        # Gather failure status from all ranks
-        all_failures = MPI.COMM_WORLD.allgather(rank_failed)
+        all_failures = [None for _ in range(world_size)]
+        dist.all_gather_object(all_failures, rank_failed)
 
         # If any rank failed, fail the test
         if any(all_failures):
@@ -308,7 +328,7 @@ def test_mnnvl_allreduce_full(
             # Fail the test on all ranks
             pytest.fail(f"Test failed on ranks {failed_ranks}")
 
-        trtllm_mnnvl_ar.mpi_barrier()
+        dist.barrier()
 
     finally:
         # Ensure cleanup happens for this list's workspace
@@ -316,4 +336,4 @@ def test_mnnvl_allreduce_full(
         del mcast_buffer_mnnvl
 
     # Final synchronization and check for failures across all ranks
-    trtllm_mnnvl_ar.mpi_barrier()
+    dist.barrier()
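# ---------------------------------------------------------------------------
# Reviewer sketch (not part of the patch): a minimal, hypothetical standalone
# script showing the MPI -> torch.distributed mapping this change relies on.
# Assumes a `torchrun --nproc-per-node=2 sketch.py` launch, which exports
# MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE for init_process_group; the
# script name and the string payloads are illustrative only.
import torch.distributed as dist


def main() -> None:
    # gloo suffices for the CPU-side object collectives used in mnnvl.py.
    dist.init_process_group(backend="gloo")

    rank = dist.get_rank()              # was MPI.COMM_WORLD.Get_rank()
    world_size = dist.get_world_size()  # was MPI.COMM_WORLD.Get_size()

    # comm.allgather(obj)  ->  dist.all_gather_object(output_list, obj)
    handles = [None for _ in range(world_size)]
    dist.all_gather_object(handles, f"fabric-handle-from-rank-{rank}")

    # comm.bcast(obj, root=0)  ->  dist.broadcast_object_list([obj], src=0)
    payload = [handles[0]] if rank == 0 else [None]
    dist.broadcast_object_list(payload, src=0)

    # MPI.COMM_WORLD.Barrier()  ->  dist.barrier()
    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()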