86 changes: 81 additions & 5 deletions flashinfer/comm/mnnvl.py
@@ -16,9 +16,11 @@
import ctypes
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
import platform
import sys
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING

import torch
from cuda import cuda
@@ -129,6 +131,22 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:
    return device_ptr


class CommBackend(ABC):
    """Abstract communication backend interface"""

    @abstractmethod
    def Get_rank(self) -> int: ...

    @abstractmethod
    def Get_size(self) -> int: ...

    @abstractmethod
    def allgather(self, data: int) -> List[int]: ...

    @abstractmethod
    def Split(self, color: int, key: int) -> "CommBackend": ...


if IS_BUILDING_DOCS:
    # Mock classes for building docs

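The four-method contract above is deliberately small. As an illustration only (not part of this PR), a trivial single-process backend would satisfy it:

# Hypothetical sketch, not in the diff: the smallest CommBackend implementation.
from typing import List

from flashinfer.comm.mnnvl import CommBackend


class SingleProcessBackend(CommBackend):
    def Get_rank(self) -> int:
        return 0  # the only rank

    def Get_size(self) -> int:
        return 1  # world of one

    def allgather(self, data: int) -> List[int]:
        return [data]  # gathering across a single rank is the identity

    def Split(self, color: int, key: int) -> CommBackend:
        return self  # splitting a size-1 communicator is a no-op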
@@ -208,18 +226,66 @@ def supports_mnnvl() -> bool:

else:
    import pynvml
    from mpi4py import MPI

    if TYPE_CHECKING:
        from mpi4py import MPI  # noqa: F401

    def lazy_import_mpi():
        """Lazy import for mpi4py"""
        try:
            from mpi4py import MPI

            return MPI
        except ImportError as err:
            raise ImportError("mpi4py is not installed") from err  # type: ignore[no-redef]

    class MpiComm:  # type: ignore[no-redef]
        _comm: MPI.Intracomm = MPI.COMM_WORLD
        _comm: Any = None
        _MPI: Any = None

        @classmethod
        def set_mpi_comm(cls, new_comm: MPI.Intracomm):
        def _get_mpi(cls):
            if cls._MPI is None:
                cls._MPI = lazy_import_mpi()
                cls._comm = cls._MPI.COMM_WORLD
            return cls._MPI

        @classmethod
        def set_mpi_comm(cls, new_comm: Any):
            cls._get_mpi()
            # Optional: add type checking here
            cls._comm = new_comm

        def __getattr__(self, name):
            if self._comm is None:
                self._get_mpi()
            return getattr(self._comm, name)

    class MPIBackend(CommBackend):
        def __init__(self):
            self._mpicomm = MpiComm()

        def Get_rank(self) -> int:
            return self._mpicomm.Get_rank()

        def Get_size(self) -> int:
            return self._mpicomm.Get_size()

        def allgather(self, data: int) -> List[int]:
            return self._mpicomm.allgather(data)

        def Split(self, color: int, key: int) -> CommBackend:
            self._mpicomm = self._mpicomm.Split(color, key)
            return MPIBackend()  # Returns new adapter

    @dataclass
    class MnnvlConfig:
        """Configuration for MNNVL memory management"""

        comm_backend: Optional[CommBackend] = None
        allocation_granularity: int = 0
        fabric_page_size: int = 1 << 29  # 512MB

    class MnnvlMemory:  # type: ignore[no-redef]
        initialized: bool = False

@@ -234,13 +300,15 @@ class MnnvlMemory: # type: ignore[no-redef]
        fabric_page_size: int = 1 << 29

        # MPI communicator
        comm = None
        comm: Optional[CommBackend] = None

        dev_id: int = None

        allocated_map: Dict[int, Any] = {}
        address_refcnt: Dict[int, Any] = {}

        config: Optional[MnnvlConfig] = None

        def __init__(self, mapping: Mapping, size: int):
            self.mapping = mapping
            self.segment_size = size
@@ -275,6 +343,14 @@ def initialize():
            pynvml.nvmlInit()
            MnnvlMemory.initialized = True

        @staticmethod
        def set_comm_from_config(mapping: Mapping, config: Optional[MnnvlConfig] = None):
            MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend())  # type: ignore[attr-defined]
            # Read the backend from the stored config so the MPIBackend default
            # also applies when config is None.
            comm = MnnvlMemory.config.comm_backend.Split(
                mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
            )
            MnnvlMemory.comm = comm  # type: ignore[assignment]

        @staticmethod
        def get_comm(mapping: Mapping):
            if MnnvlMemory.comm is not None:
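Taken together, the changes above make the communicator pluggable. A minimal wiring sketch, assuming this PR's names (MnnvlConfig, MPIBackend, MnnvlMemory.set_comm_from_config) and the positional Mapping arguments used by the test below; the default path still requires mpi4py and an MPI launcher:

# Hedged sketch of the new wiring; not part of the diff itself.
from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig, MnnvlMemory, MPIBackend

mapping = Mapping(2, 0, 2, tp_size=2)  # world_size=2, rank=0, gpus_per_node=2

# Default path: omitting the config falls back to MPIBackend, with the
# mpi4py import deferred until first use.
MnnvlMemory.set_comm_from_config(mapping)

# Custom path: inject any CommBackend implementation.
config = MnnvlConfig(comm_backend=MPIBackend(), fabric_page_size=1 << 29)
MnnvlMemory.set_comm_from_config(mapping, config)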
6 changes: 4 additions & 2 deletions flashinfer/comm/trtllm_alltoall.py
@@ -26,7 +26,7 @@
from ..jit import gen_jit_spec
from ..utils import register_custom_op
from .mapping import Mapping
from .mnnvl import MnnvlMemory
from .mnnvl import MnnvlMemory, MnnvlConfig


def gen_comm_alltoall_module() -> JitSpec:
@@ -296,13 +296,15 @@ class MnnvlMoe:
    moe_mapping: Mapping = None

    @staticmethod
    def get_moe_workspaces(mapping: Mapping):
    def get_moe_workspaces(mapping: Mapping, config: Optional[MnnvlConfig] = None):
        if MnnvlMoe.moe_workspace is not None:
            assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
            return MnnvlMoe.moe_workspace_tensor

        MnnvlMoe.moe_mapping = mapping
        workspace_size_per_rank = get_moe_commworkspace_size_per_rank(mapping.tp_size)
        if config:
            MnnvlMemory.set_comm_from_config(mapping, config)  # type: ignore[attr-defined]
        MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
        MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(
            torch.uint64
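For orientation, a usage sketch of the extended entry point; MPIBackend stands in for any CommBackend implementation, and the Mapping arguments follow the test below:

# Hedged sketch: routing a custom communicator into the MoE workspace helper.
from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig, MPIBackend
from flashinfer.comm.trtllm_alltoall import MnnvlMoe

mapping = Mapping(2, 0, 2, tp_size=2)
config = MnnvlConfig(comm_backend=MPIBackend())

# config is optional; when given, it installs the communicator via
# MnnvlMemory.set_comm_from_config before the workspace is allocated.
workspace_tensor = MnnvlMoe.get_moe_workspaces(mapping, config=config)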
183 changes: 183 additions & 0 deletions tests/test_mnnvl_custom_comm.py
@@ -0,0 +1,183 @@
import multiprocessing as mp
import socket
from typing import Any

import pytest
import torch
import torch.distributed as dist

import pynvml

from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig, MnnvlMemory
from flashinfer.comm.mnnvl import CommBackend


pynvml.nvmlInit()


class CustomCommunicator(CommBackend):
    def __init__(self, group):
        self._group = group

    def Get_rank(self) -> int:
        return dist.get_rank(self._group)

    def Get_size(self) -> int:
        return dist.get_world_size(self._group)

    def allgather(self, data: int | bytes):
        device = f"cuda:{torch.cuda.current_device()}"
        if isinstance(data, int):
            local_tensor = torch.tensor([data], device=device, dtype=torch.int32)
            world_size = self.Get_size()
            gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]

            dist.all_gather(gathered, local_tensor, group=self._group)
            return [int(x.item()) for x in gathered]

        elif isinstance(data, bytes):
            # all_gather_object pickles arbitrary payloads, so bytes need no
            # manual tensor round-trip.
            gathered = [None] * self.Get_size()
            dist.all_gather_object(gathered, data, group=self._group)
            return gathered
        else:
            raise TypeError(f"Unsupported type for allgather: {type(data)}")

    def Split(self, color: int, key: int) -> "CustomCommunicator":
        # The test only ever uses one split group, so reuse this communicator.
        return self


def get_open_port() -> int:
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = ()
) -> None:
    mp.set_start_method("spawn", force=True)

    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, dtype, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)

    for i in range(world_size):
        procs[i].join()
        assert procs[i].exitcode == 0, (
            f"Process {i} failed with exit code {procs[i].exitcode}"
        )


def align_memory(size: int):
    # Round up to the 2 MiB granularity used by the allocator.
    align_size = 2 * 1024 * 1024
    return (size + align_size - 1) // align_size * align_size


def _init_mnnvl_memory(world_size, rank, dtype, distributed_init_port):
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD

    torch.cuda.set_device(rank)
    MnnvlMemory.initialize()
    mapping = Mapping(world_size, rank, world_size, tp_size=world_size)

    allocate0_size = 4 * 1024 * 1024 - 3 * 1024
    mnnvl_config = MnnvlConfig(
        comm_backend=CustomCommunicator(group),
        fabric_page_size=1 << 29,  # 512MB
        allocation_granularity=0,  # Auto-detect
    )
    MnnvlMemory.set_comm_from_config(mapping, mnnvl_config)
    mnnvl_memory0 = MnnvlMemory(mapping, allocate0_size)
    allocate0_size_aligned = align_memory(allocate0_size)

    assert MnnvlMemory.current_mem_offset == allocate0_size_aligned
    tensor0 = mnnvl_memory0.as_torch_strided_tensor(torch.int32)
    numel_per_rank = allocate0_size // 4
    tensor0[(rank + 1) % world_size] = torch.arange(
        start=rank, end=rank + numel_per_rank, device="cuda"
    )
    dist.barrier(group=group)
    for r in range(world_size):
        assert torch.equal(
            tensor0[(r + 1) % world_size],
            torch.arange(start=r, end=r + numel_per_rank, device="cuda"),
        )

    allocate1_size = 30 * 1024 * 1024 - 2 * 1024
    mnnvl_memory1 = MnnvlMemory(mapping, allocate1_size)
    allocate1_size_aligned = align_memory(allocate1_size)
    assert (
        MnnvlMemory.current_mem_offset
        == allocate0_size_aligned + allocate1_size_aligned
    )
    tensor1 = mnnvl_memory1.as_torch_strided_tensor(torch.float32)
    numel_per_rank = allocate1_size // 4
    tensor1[(rank + 5) % world_size] = torch.arange(
        start=rank,
        end=rank + numel_per_rank,
        dtype=torch.float32,
        device="cuda",
    )
    dist.barrier(group=group)
    for r in range(world_size):
        assert torch.equal(
            tensor1[(r + 5) % world_size],
            torch.arange(
                start=r, end=r + numel_per_rank, dtype=torch.float32, device="cuda"
            ),
        )
    dist.barrier(group=group)
    del tensor0, mnnvl_memory0
    dist.barrier(group=group)

    large_allocation2_size = 768 * 1024 * 1024
    large_mnnvl_memory2 = MnnvlMemory(mapping, large_allocation2_size)
    allocate2_size_aligned = align_memory(large_allocation2_size)
    assert MnnvlMemory.current_mem_offset == allocate2_size_aligned
    assert large_mnnvl_memory2.rank_stride == (1 << 30)

    del tensor1


@pytest.mark.skipif(
    not MnnvlMemory.supports_mnnvl(),
    reason="Mnnvl memory is not supported on this platform",
)
@pytest.mark.parametrize("world_size", [2, 4])
def test_mnnvl_custom_communicator(world_size):
    dtype = torch.float16
    available_gpus = torch.cuda.device_count()
    if world_size > available_gpus:
        raise ValueError(
            f"world_size {world_size} is greater than available_gpus {available_gpus}"
        )
    print(f"Running test for world_size={world_size}")

    multi_process_parallel(
        world_size,
        dtype,
        _init_mnnvl_memory,
        target_args=(),
    )
    print(f"custom mnnvl communicator world_size = {world_size}: OK")
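Because multi_process_parallel spawns and joins its own workers, the test can also be driven without pytest; a hypothetical direct invocation:

# Hypothetical entry point mirroring what pytest does.
if __name__ == "__main__":
    test_mnnvl_custom_communicator(world_size=2)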