
Removes MPI dependency from MNNVL AllReduce #1379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 4 commits into main
56 changes: 35 additions & 21 deletions flashinfer/comm/mnnvl.py
@@ -17,13 +17,12 @@
import logging
import platform
import sys
from dataclasses import dataclass
from typing import List, Optional

import pynvml
import torch
import torch.distributed as dist
from cuda import cuda
from mpi4py import MPI

from ..cuda_utils import checkCudaErrors
from .dlpack_utils import create_dlpack_capsule, pack_strided_memory
@@ -130,17 +129,6 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:
return device_ptr


class MpiComm:
_comm: MPI.Intracomm = MPI.COMM_WORLD

@classmethod
def set_mpi_comm(cls, new_comm: MPI.Intracomm):
cls._comm = new_comm

def __getattr__(self, name):
return getattr(self._comm, name)


class MnnvlMemory:
initialized: bool = False

@@ -198,6 +186,19 @@ def initialize():
def get_comm(mapping: Mapping):
if MnnvlMemory.comm is not None:
return MnnvlMemory.comm

from mpi4py import MPI

class MpiComm:
_comm: MPI.Intracomm = MPI.COMM_WORLD

@classmethod
def set_mpi_comm(cls, new_comm: MPI.Intracomm):
cls._comm = new_comm

def __getattr__(self, name):
return getattr(self._comm, name)

comm = MpiComm().Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
)
@@ -419,6 +420,7 @@ def __init__(
group_rank: int,
device_idx: int,
is_multi_node: bool = True,
group: Optional[dist.ProcessGroup] = None,
):
cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx))

@@ -436,6 +438,7 @@ def __init__(
self.device_idx = device_idx
self.group_size = group_size
self.group_rank = group_rank
self.group = group
self.buf_size = buf_size
self.signal_pad_offset = 0
self.allocation_size = 0
@@ -628,9 +631,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
except Exception as e:
print(f"Error checking CUDA context: {e}")

# Get MPI communicator
comm = MpiComm()

# Set up allocation properties
handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC

@@ -692,7 +692,14 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
)

# All-gather fabric handles
all_fabric_handles = comm.allgather(my_fabric_handle.data)
if not dist.is_initialized():
raise RuntimeError("torch.distributed must be initialized before use.")

# Use all_gather_object to collect fabric handles from all ranks
all_fabric_handles = [None for _ in range(self.group_size)]
dist.all_gather_object(
all_fabric_handles, my_fabric_handle.data, group=self.group
)
cuda.cuCtxSynchronize()

# Import remote handles
@@ -722,9 +729,15 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
mc_fabric_handle = None

# Broadcast multicast handle
mc_fabric_handle_data = comm.bcast(
mc_fabric_handle.data if mc_fabric_handle else None, root=0
)
mc_fabric_handle_list = [mc_fabric_handle.data] if mc_fabric_handle else [None]
if self.group:
dist.broadcast_object_list(
mc_fabric_handle_list, group_src=0, group=self.group
)
else:
dist.broadcast_object_list(mc_fabric_handle_list, src=0)
mc_fabric_handle_data = mc_fabric_handle_list[0]

# Sync device to ensure broadcast is complete
cuda.cuCtxSynchronize()
# Import multicast handle for non-root ranks
@@ -830,6 +843,7 @@ def __init__(
group_rank: int,
device: torch.device,
mn_nvlink: bool = True,
group: Optional[dist.ProcessGroup] = None,
):
"""
Constructor for McastGpuBuffer.
@@ -842,7 +856,7 @@ def __init__(
mn_nvlink: Flag indicating if multi-node NVLink is used
"""
self.mcast_device_memory = McastDeviceMemory(
buf_size, group_size, group_rank, device.index, mn_nvlink
buf_size, group_size, group_rank, device.index, mn_nvlink, group
)
self.buf_size = buf_size
self.local_device = device
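For context on the hunks above: a minimal sketch (not part of the diff) of how `torch.distributed` object collectives take over from the old MPI `allgather`/`bcast` calls. The names `exchange_handles` and `payload` are hypothetical; the real code exchanges CUDA fabric handles inside `_alloc_mn_mcast_mem`, forwards the optional `group` argument, and uses `group_src=0` in the broadcast when a sub-group is supplied.

```python
import torch.distributed as dist


def exchange_handles(payload):
    """Sketch only: assumes the default (world) process group is already initialized."""
    world_size = dist.get_world_size()

    # MPI allgather equivalent: every rank contributes one picklable object.
    gathered = [None] * world_size
    dist.all_gather_object(gathered, payload)

    # MPI bcast(root=0) equivalent: rank 0's object reaches every rank.
    box = [payload if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(box, src=0)
    return gathered, box[0]
```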
12 changes: 4 additions & 8 deletions flashinfer/comm/trtllm_mnnvl_ar.py
@@ -10,7 +10,7 @@
from typing import List, Optional, Tuple

import torch
from mpi4py import MPI
import torch.distributed as dist

from flashinfer.comm.mapping import Mapping

@@ -21,11 +21,6 @@
from .mnnvl import McastGPUBuffer


def mpi_barrier():
"""MPI barrier - could potentially be replaced with dist.barrier()"""
MPI.COMM_WORLD.Barrier()


def gen_trtllm_mnnvl_comm_module() -> JitSpec:
return gen_jit_spec(
"trtllm_mnnvl_comm",
@@ -132,7 +127,7 @@ def trtllm_mnnvl_rmsnorm(


def get_allreduce_mnnvl_workspace(
mapping: Mapping, dtype: torch.dtype
mapping: Mapping, dtype: torch.dtype, group: Optional[dist.ProcessGroup] = None
) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
"""Get workspace buffers needed for multi-node NVLink all-reduce operation.

@@ -174,14 +169,15 @@
mapping.tp_rank,
torch.device("cuda", mapping.local_rank),
mapping.is_multi_node() or force_mn,
group=group,
)

# Initialize the unicast buffer with -0.0
mcast_buffer.lamport_initialize(mapping.tp_rank, dtype)

# CPU barrier since we assume this should not be called in cuda graph
torch.cuda.synchronize()
mpi_barrier()
dist.barrier()

# This is a buffer to maintain the state of this allreduce Op
# [Buffer_ptr, Clear_ptr, Buffer_size, num_tokens_prev, atomic access counter]
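A hedged usage sketch (not from the PR) of the new `group` parameter on `get_allreduce_mnnvl_workspace`. Here `mapping` is assumed to be a `flashinfer.comm.mapping.Mapping` already configured for the job, and the assumption that the TP group spans ranks `0..tp_size-1` is for illustration only; variable names for the returned tuple are likewise illustrative.

```python
import torch
import torch.distributed as dist

import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar


def make_mnnvl_workspace(mapping, dtype=torch.bfloat16):
    # torch.distributed must already be initialized, since workspace setup
    # now ends with dist.barrier() instead of an MPI barrier.
    assert dist.is_initialized()

    # Assumption for illustration: the TP group is ranks 0..tp_size-1.
    tp_group = dist.new_group(ranks=list(range(mapping.tp_size)))

    # Returns (McastGPUBuffer, state tensor, int) per the signature above.
    mcast_buffer, buffer_flags, max_size = trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(
        mapping, dtype, group=tp_group
    )
    return mcast_buffer, buffer_flags, max_size
```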
22 changes: 12 additions & 10 deletions flashinfer/jit/env.py
@@ -51,18 +51,20 @@ def _get_workspace_dir_name() -> pathlib.Path:
FLASHINFER_WORKSPACE_DIR = _get_workspace_dir_name()
FLASHINFER_JIT_DIR = FLASHINFER_WORKSPACE_DIR / "cached_ops"
FLASHINFER_GEN_SRC_DIR = FLASHINFER_WORKSPACE_DIR / "generated"
_package_root = pathlib.Path(__file__).resolve().parents[1]
FLASHINFER_DATA = _package_root / "data"
FLASHINFER_INCLUDE_DIR = _package_root / "data" / "include"
FLASHINFER_CSRC_DIR = _package_root / "data" / "csrc"
# FLASHINFER_SRC_DIR = _package_root / "data" / "src"
FLASHINFER_TVM_BINDING_DIR = _package_root / "data" / "tvm_binding"
FLASHINFER_AOT_DIR = _package_root / "data" / "aot"
# TODO (pranavm): Check if this is right?
Collaborator:
Hi @pranavm-nvidia, we shouldn't change the logic here.

When you install flashinfer (using pip install), the data folder will be created:

[tool.setuptools.package-dir]
"flashinfer.data" = "."
"flashinfer.data.aot" = "build/aot-ops-package-dir"
"flashinfer.data.cutlass" = "3rdparty/cutlass"
"flashinfer.data.spdlog" = "3rdparty/spdlog"
[tool.setuptools.package-data]
"flashinfer.data" = [
"csrc/**",
"include/**",
"tvm_binding/**",
"version.txt"
]
"flashinfer.data.aot" = [
"**"
]
"flashinfer.data.cutlass" = [
"include/**",
"tools/util/include/**"
]
"flashinfer.data.spdlog" = [
"include/**",
]

All of the original paths such as _package_root / "include" or _package_root / "3rdparty" will be gone when you build the package.

If you develop locally in editable mode, you will not see the difference because your package root is the repository path. But if you install the flashinfer package from PyPI into a new environment, you will find there is no include/ or 3rdparty/ path in the flashinfer package directory, which is why the data path is designed to handle these source-code dependencies.

Author:
Thanks, I had a feeling it was a setup issue on my end. What's the correct way to test local changes then? Do I always need to reinstall the package when I make a change? So far I was just setting PYTHONPATH.

Collaborator:
What's the correct way to test local changes then? Do I always need to reinstall the package when I make a change?

You can do an editable installation:

pip install -e . -v

and you don't have to reinstall the package for source code changes (because it's editable).

# Why were these pointing to non-existent directories? Must be a copy missing somewhere?
_package_root = pathlib.Path(__file__).resolve().parents[2]
FLASHINFER_DATA = _package_root
Collaborator: cc: @yzh119

FLASHINFER_INCLUDE_DIR = _package_root / "include"
FLASHINFER_CSRC_DIR = _package_root / "csrc"
# FLASHINFER_SRC_DIR = _package_root / "src"
FLASHINFER_TVM_BINDING_DIR = _package_root / "tvm_binding"
FLASHINFER_AOT_DIR = _package_root / "aot"
CUTLASS_INCLUDE_DIRS = [
_package_root / "data" / "cutlass" / "include",
_package_root / "data" / "cutlass" / "tools" / "util" / "include",
_package_root / "3rdparty" / "cutlass" / "include",
_package_root / "3rdparty" / "cutlass" / "tools" / "util" / "include",
]
SPDLOG_INCLUDE_DIR = _package_root / "data" / "spdlog" / "include"
SPDLOG_INCLUDE_DIR = _package_root / "3rdparty" / "spdlog" / "include"


def get_nvshmem_include_dirs():
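To make the packaging discussion above easy to check locally, a small snippet (not part of the diff) that prints where a given installation resolves these paths; the attribute names all come from `flashinfer/jit/env.py`.

```python
import flashinfer.jit.env as jit_env

# Print which layout (wheel "data" dir vs. editable repo checkout) this
# environment actually resolves, and whether the directories exist on disk.
for name in ("FLASHINFER_DATA", "FLASHINFER_INCLUDE_DIR", "FLASHINFER_CSRC_DIR"):
    path = getattr(jit_env, name)
    print(f"{name}: {path} (exists={path.exists()})")
```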
50 changes: 35 additions & 15 deletions tests/test_trtllm_mnnvl_allreduce.py
@@ -1,16 +1,13 @@
# Check torch version:
import os
import sys
import traceback

import pytest
import torch
from mpi4py import MPI # Added MPI import
import torch.distributed as dist

import flashinfer.comm as comm
import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import McastDeviceMemory, McastGPUBuffer

# Use flashinfer.norm.rmsnorm as reference implementation.
from flashinfer.norm import rmsnorm
@@ -42,7 +39,7 @@ def row_linear_residual_norm_fusion_forward(
tensor_parallel_size = mapping.tp_size
tensor_parallel_rank = mapping.tp_rank

MPI.COMM_WORLD.barrier()
dist.barrier()

def func(
input,
@@ -70,7 +67,7 @@ def func(
prenorm_output = torch.empty_like(residual)
normed_output = torch.empty_like(residual)

trtllm_mnnvl_ar.mpi_barrier()
dist.barrier()

trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm(
prenorm_output,
@@ -152,7 +149,7 @@ def func(
)


"""Main test function that runs on each MPI rank"""
"""Main test function that runs on each distributed rank"""


@pytest.mark.parametrize(
@@ -173,9 +170,29 @@ def test_mnnvl_allreduce_full(
):
monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1") # force multi-node allreduce.

# Get MPI info
rank = MPI.COMM_WORLD.Get_rank()
world_size = MPI.COMM_WORLD.Get_size()
# Hack to get an address that all nodes in the slurm job can reach
master_addr = os.environ.get("MASTER_ADDR", None)
if master_addr is None:
# Use SLURM's first node as master
master_addr = os.environ.get("SLURM_NODELIST", "localhost").split(",")[0]
# Collapse any bracketed range (e.g., "node[01-02]" -> "node0")
if "[" in master_addr:
import re

master_addr = re.sub(r"\[.*\]", "0", master_addr)
master_port = os.environ.get("MASTER_PORT", "12345")

if not dist.is_initialized():
dist.init_process_group(
backend="gloo",
init_method=f"tcp://{master_addr}:{master_port}",
world_size=2,
rank=int(os.environ.get("SLURM_PROCID")),
)

# Get distributed info
rank = dist.get_rank()
world_size = dist.get_world_size()
gpus_per_node = torch.cuda.device_count()

if gpus_per_node == 0:
@@ -184,7 +201,9 @@
# Ensure we have exactly 2 ranks for this test
if world_size < 2:
if rank == 0:
print(f"ERROR: This test requires at least 2 MPI ranks, got {world_size}")
print(
f"ERROR: This test requires at least 2 distributed ranks, got {world_size}"
)
sys.exit(1)

mapping = Mapping(
@@ -287,7 +306,7 @@ test_mnnvl_allreduce_full(
)

# Synchronize before next test
trtllm_mnnvl_ar.mpi_barrier()
dist.barrier()

print(
f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}"
@@ -298,7 +317,8 @@
failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}"
print(failure_message)
# Gather failure status from all ranks
all_failures = MPI.COMM_WORLD.allgather(rank_failed)
all_failures = [None for _ in range(world_size)]
dist.all_gather_object(all_failures, rank_failed)

# If any rank failed, fail the test
if any(all_failures):
@@ -308,12 +328,12 @@

# Fail the test on all ranks
pytest.fail(f"Test failed on ranks {failed_ranks}")
trtllm_mnnvl_ar.mpi_barrier()
dist.barrier()

finally:
# Ensure cleanup happens for this list's workspace
if "mcast_buffer_mnnvl" in locals():
del mcast_buffer_mnnvl

# Final synchronization and check for failures across all ranks
trtllm_mnnvl_ar.mpi_barrier()
dist.barrier()
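Finally, a consolidated sketch (not part of the diff) of the bootstrap the rewritten test relies on: either `MASTER_ADDR`/`MASTER_PORT` are set explicitly, or SLURM variables (`SLURM_NODELIST`, `SLURM_PROCID`) are provided by the scheduler. The use of `SLURM_NTASKS` here is an assumption for the sketch; the test itself hardcodes `world_size=2`.

```python
import os

import torch.distributed as dist


def bootstrap_process_group():
    # Rendezvous address: explicit MASTER_ADDR wins, otherwise fall back to localhost.
    addr = os.environ.get("MASTER_ADDR", "localhost")
    port = os.environ.get("MASTER_PORT", "12345")
    rank = int(os.environ.get("SLURM_PROCID", "0"))
    world_size = int(os.environ.get("SLURM_NTASKS", "2"))  # assumption; the test hardcodes 2

    if not dist.is_initialized():
        # gloo only handles the CPU-side rendezvous and barriers; the
        # all-reduce data path itself still runs over NVLink via the MNNVL kernels.
        dist.init_process_group(
            backend="gloo",
            init_method=f"tcp://{addr}:{port}",
            rank=rank,
            world_size=world_size,
        )
    dist.barrier()
```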