88 changes: 83 additions & 5 deletions flashinfer/comm/mnnvl.py
@@ -16,9 +16,11 @@
import ctypes
import logging
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
import platform
import sys
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, TYPE_CHECKING

import torch
from cuda import cuda
@@ -208,18 +210,86 @@ def supports_mnnvl() -> bool:

else:
import pynvml
from mpi4py import MPI
if TYPE_CHECKING:
from mpi4py import MPI

def lazy_import_mpi():
"""Lazy import for mpi4py"""
try:
from mpi4py import MPI
return MPI
except ImportError as e:
    raise ImportError("mpi4py is required for the MNNVL MPI backend; install it with `pip install mpi4py`") from e

class MpiComm: # type: ignore[no-redef]
_comm: MPI.Intracomm = MPI.COMM_WORLD
# _comm: MPI.Intracomm = MPI.COMM_WORLD
_comm: Any = None
_MPI: Any = None

@classmethod
def set_mpi_comm(cls, new_comm: MPI.Intracomm):
cls._comm = new_comm
def _get_mpi(cls):
if cls._MPI is None:
cls._MPI = lazy_import_mpi()
cls._comm = cls._MPI.COMM_WORLD
return cls._MPI

@classmethod
def set_mpi_comm(cls, new_comm: Any):
cls._get_mpi()
# new_comm is expected to be an MPI.Intracomm; typed as Any to avoid a hard mpi4py dependency at import time
cls._comm = new_comm

def __getattr__(self, name):
if self._comm is None:
self._get_mpi()
return getattr(self._comm, name)

class CommBackend(ABC):
"""Abstract communication backend interface"""
@abstractmethod
def Get_rank(self) -> int: ...

@abstractmethod
def Get_size(self) -> int: ...

@abstractmethod
def allgather(self, data: int) -> List[int]: ...

@abstractmethod
def allgather_bytes(self, data): ...

@abstractmethod
def Split(self, color: int, key: int) -> 'CommBackend': ...

class LegacyMPIBackend(CommBackend):
def __init__(self):
self._mpicomm = MpiComm()

def Get_rank(self) -> int:
return self._mpicomm.Get_rank()

def Get_size(self) -> int:
return self._mpicomm.Get_size()

def allgather(self, data: int) -> List[int]:
return self._mpicomm.allgather(data)

def allgather_bytes(self, data):
return self._mpicomm.allgather(data)

def Split(self, color: int, key: int) -> CommBackend:
# Split the underlying communicator and wrap it in a new adapter so the
# split communicator is actually used (rather than falling back to COMM_WORLD)
new_comm = self._mpicomm.Split(color, key)
new_backend = LegacyMPIBackend()
new_backend._mpicomm = new_comm
return new_backend

@dataclass
class MnnvlConfig:
"""Configuration for MNNVL memory management"""
comm_backend: Optional[CommBackend] = None
allocation_granularity: int = 0
fabric_page_size: int = 1 << 29 # 512MB


class MnnvlMemory: # type: ignore[no-redef]
initialized: bool = False

@@ -275,6 +345,14 @@ def initialize():
pynvml.nvmlInit()
MnnvlMemory.initialized = True

@staticmethod
def set_comm(mapping: Mapping, config: Optional[MnnvlConfig] = None):
MnnvlMemory._config = config or MnnvlConfig(comm_backend=LegacyMPIBackend())
comm = MnnvlMemory._config.comm_backend.Split(
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
)
MnnvlMemory.comm = comm

@staticmethod
def get_comm(mapping: Mapping):
if MnnvlMemory.comm is not None:
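Note: to make the new interface concrete, below is a minimal sketch of a custom backend, a trivial single-process implementation of the `CommBackend` ABC introduced above. It is illustrative only, not part of the diff; the class name `SingleProcessBackend` is hypothetical, and the import path assumes this change has landed in `flashinfer.comm.mnnvl`.

```python
# Illustrative only: a trivial single-rank CommBackend implementation.
from typing import List

from flashinfer.comm.mnnvl import CommBackend  # assumes this PR is merged


class SingleProcessBackend(CommBackend):
    """Backend for a world of size 1; shows which methods a custom backend must provide."""

    def Get_rank(self) -> int:
        return 0

    def Get_size(self) -> int:
        return 1

    def allgather(self, data: int) -> List[int]:
        return [data]

    def allgather_bytes(self, data):
        return [data]

    def Split(self, color: int, key: int) -> "CommBackend":
        # Splitting a single-rank communicator is a no-op.
        return self
```

Any such backend can then be handed to the memory layer via `MnnvlConfig(comm_backend=...)`.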
6 changes: 4 additions & 2 deletions flashinfer/comm/trtllm_alltoall.py
@@ -26,7 +26,7 @@
from ..jit import gen_jit_spec
from ..utils import register_custom_op
from .mapping import Mapping
from .mnnvl import MnnvlMemory
from .mnnvl import (MnnvlMemory, MnnvlConfig)


def gen_comm_alltoall_module() -> JitSpec:
@@ -296,13 +296,15 @@ class MnnvlMoe:
moe_mapping: Mapping = None

@staticmethod
def get_moe_workspaces(mapping: Mapping):
def get_moe_workspaces(mapping: Mapping, config: Optional[MnnvlConfig] = None):
if MnnvlMoe.moe_workspace is not None:
assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
return MnnvlMoe.moe_workspace_tensor

MnnvlMoe.moe_mapping = mapping
workspace_size_per_rank = get_moe_commworkspace_size_per_rank(mapping.tp_size)
if config:
MnnvlMemory.set_comm(mapping, config)
MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(
torch.uint64
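Note: for completeness, a sketch of how the new `config` parameter could be threaded through at a call site. The helper `make_moe_workspace` is hypothetical and assumes an existing `Mapping` instance supplied by the caller, plus the imports introduced by this change.

```python
# Hypothetical call-site sketch; not part of the diff.
import torch

from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import LegacyMPIBackend, MnnvlConfig
from flashinfer.comm.trtllm_alltoall import MnnvlMoe


def make_moe_workspace(mapping: Mapping) -> torch.Tensor:
    # Explicitly select the legacy MPI backend (also the default when config is omitted).
    config = MnnvlConfig(comm_backend=LegacyMPIBackend())
    return MnnvlMoe.get_moe_workspaces(mapping, config=config)
```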