Commit 0189214

Add alternative comm backend for mnnvl

1 parent 111e3d4 commit 0189214

2 files changed: +67, -3 lines

flashinfer/comm/mnnvl.py

Lines changed: 58 additions & 0 deletions
@@ -16,6 +16,8 @@
 import ctypes
 import logging
 import os
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
 import platform
 import sys
 from typing import Any, Dict, List, Optional
@@ -220,6 +222,51 @@ def set_mpi_comm(cls, new_comm: MPI.Intracomm):
     def __getattr__(self, name):
         return getattr(self._comm, name)
 
+
+class CommBackend(ABC):
+    """Abstract communication backend interface."""
+
+    @abstractmethod
+    def Get_rank(self) -> int: ...
+
+    @abstractmethod
+    def Get_size(self) -> int: ...
+
+    @abstractmethod
+    def allgather(self, data: int) -> List[int]: ...
+
+    @abstractmethod
+    def allgather_bytes(self, data): ...
+
+    @abstractmethod
+    def Split(self, color: int, key: int) -> "CommBackend": ...
+
+
+class LegacyMPIBackend(CommBackend):
+    """Adapter exposing the original MpiComm singleton through CommBackend."""
+
+    def __init__(self):
+        self._mpicomm = MpiComm()
+
+    def Get_rank(self) -> int:
+        return self._mpicomm.Get_rank()
+
+    def Get_size(self) -> int:
+        return self._mpicomm.Get_size()
+
+    def allgather(self, data: int) -> List[int]:
+        return self._mpicomm.allgather(data)
+
+    def allgather_bytes(self, data):
+        # MPI allgather handles arbitrary picklable payloads, so bytes reuse it.
+        return self._mpicomm.allgather(data)
+
+    def Split(self, color: int, key: int) -> "CommBackend":
+        # Delegate the split to the underlying MpiComm, then return a fresh adapter.
+        self._mpicomm.Split(color, key)
+        return LegacyMPIBackend()
+
+
+@dataclass
+class MnnvlConfig:
+    """Configuration for MNNVL memory management."""
+
+    comm_backend: Optional[CommBackend] = None
+    allocation_granularity: int = 0
+    fabric_page_size: int = 1 << 29  # 512 MB
+
+
 class MnnvlMemory:  # type: ignore[no-redef]
     initialized: bool = False
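Because the memory setup now talks to the abstract CommBackend rather than MPI directly, other transports can be plugged in. The sketch below is illustrative only and not part of this commit: it assumes the CommBackend ABC from the hunk above, assumes torch.distributed has already been initialized by the caller, and the class name TorchDistBackend and its Split strategy (which only splits the default process group) are made up for the example.

# Illustrative sketch, not part of this commit: a torch.distributed-based
# CommBackend. Assumes dist.init_process_group() has already been called and
# that Split is only ever invoked on the default (world) process group.
from typing import Any, List, Optional

import torch.distributed as dist


class TorchDistBackend(CommBackend):  # CommBackend is the ABC added above
    def __init__(self, group: Optional[dist.ProcessGroup] = None):
        self._group = group  # None means the default (world) group

    def Get_rank(self) -> int:
        return dist.get_rank(group=self._group)

    def Get_size(self) -> int:
        return dist.get_world_size(group=self._group)

    def allgather(self, data: int) -> List[int]:
        out: List[Any] = [None] * self.Get_size()
        dist.all_gather_object(out, data, group=self._group)
        return out

    def allgather_bytes(self, data):
        # all_gather_object pickles arbitrary payloads, so bytes work as-is.
        out: List[Any] = [None] * self.Get_size()
        dist.all_gather_object(out, data, group=self._group)
        return out

    def Split(self, color: int, key: int) -> "CommBackend":
        # Mimic MPI_Comm_split: every rank contributes (color, key, rank); ranks
        # sharing a color form one group, ordered by key. torch.distributed
        # requires all ranks to call new_group() for every group, in the same order.
        world = dist.get_world_size()
        contribs: List[Any] = [None] * world
        dist.all_gather_object(contribs, (color, key, dist.get_rank()))
        my_group = None
        for c in sorted({c for c, _, _ in contribs}):
            ranks = [r for cc, k, r in sorted(contribs) if cc == c]
            group = dist.new_group(ranks=ranks)
            if c == color:
                my_group = group
        return TorchDistBackend(my_group)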

@@ -275,6 +322,17 @@ def initialize():
         pynvml.nvmlInit()
         MnnvlMemory.initialized = True
 
+    @staticmethod
+    def set_comm(mapping: Mapping, config: Optional[MnnvlConfig] = None):
+        # Fall back to the MPI adapter when no explicit config is supplied.
+        MnnvlMemory._config = config or MnnvlConfig(comm_backend=LegacyMPIBackend())
+        comm0 = MnnvlMemory._config.comm_backend
+        # Group ranks by (pp_rank, cp_rank); order each group by tp_rank.
+        comm = comm0.Split(
+            mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
+        )
+        MnnvlMemory.comm = comm
+
     @staticmethod
     def get_comm(mapping: Mapping):
         if MnnvlMemory.comm is not None:
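set_comm groups ranks by their pipeline- and context-parallel coordinates (the Split color) and orders each group by tensor-parallel rank (the key), so every resulting communicator spans exactly one TP group. A small self-contained illustration of that arithmetic, with parallelism sizes chosen only for the example:

# Illustration of set_comm's Split arguments; sizes are made up for the example.
pp_size, cp_size, tp_size = 2, 2, 4

groups = {}
for pp_rank in range(pp_size):
    for cp_rank in range(cp_size):
        for tp_rank in range(tp_size):
            color = pp_rank * cp_size + cp_rank  # same color -> same communicator
            key = tp_rank                        # ordering within the communicator
            groups.setdefault(color, []).append(key)

print(groups)
# {0: [0, 1, 2, 3], 1: [0, 1, 2, 3], 2: [0, 1, 2, 3], 3: [0, 1, 2, 3]}
# Four TP communicators, one per (pp_rank, cp_rank) pair, each ordered by tp_rank.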

flashinfer/comm/trtllm_alltoall.py

Lines changed: 9 additions & 3 deletions
@@ -26,7 +26,7 @@
 from ..jit import gen_jit_spec
 from ..utils import register_custom_op
 from .mapping import Mapping
-from .mnnvl import MnnvlMemory
+from .mnnvl import MnnvlConfig, MnnvlMemory
 
 
 def gen_comm_alltoall_module() -> JitSpec:
@@ -389,27 +389,33 @@ class MnnvlMoe:
     moe_mapping: Mapping = None
 
     @staticmethod
-    def get_moe_workspaces(mapping: Mapping):
+    def get_moe_workspaces(mapping: Mapping, config: Optional[MnnvlConfig] = None):
         if MnnvlMoe.moe_workspace is not None:
             assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
             return MnnvlMoe.moe_workspace_tensor
 
         MnnvlMoe.moe_mapping = mapping
         workspace_size_per_rank = get_moe_commworkspace_size_per_rank(mapping.tp_size)
+        if config:
+            MnnvlMemory.set_comm(mapping, config)
         MnnvlMemory.initialize()
         MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
         MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(
             torch.uint64
         )
         return MnnvlMoe.moe_workspace_tensor
 
     @staticmethod
-    def get_moe_prepare_workspace(mapping: Mapping):
+    def get_moe_prepare_workspace(mapping: Mapping, config: Optional[MnnvlConfig] = None):
         if MnnvlMoe.moe_prepare_workspace_tensor is not None:
             assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
             return MnnvlMoe.moe_prepare_workspace_tensor
         workspace_size_per_rank = get_moe_prepare_workspace_size_per_rank(
             mapping.tp_size
         )
+        if config:
+            MnnvlMemory.set_comm(mapping, config)
         MnnvlMemory.initialize()
         MnnvlMoe.moe_prepare_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
         MnnvlMoe.moe_prepare_workspace_tensor = (
             MnnvlMoe.moe_prepare_workspace.as_torch_strided_tensor(torch.uint64)
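The optional config threads a custom backend into workspace construction: when present, MnnvlMemory.set_comm runs before MnnvlMemory.initialize(); when absent, the legacy MPI path is untouched. A hypothetical call site, reusing the TorchDistBackend sketch from the mnnvl.py section (that class and the pre-built mapping object are assumptions for the example):

# Hypothetical call site; TorchDistBackend is the earlier sketch, not part of this commit.
from flashinfer.comm.mnnvl import MnnvlConfig
from flashinfer.comm.trtllm_alltoall import MnnvlMoe

cfg = MnnvlConfig(
    comm_backend=TorchDistBackend(),   # any CommBackend implementation
    allocation_granularity=0,
    fabric_page_size=1 << 29,          # 512 MB, the dataclass default
)

# `mapping` is an existing flashinfer.comm.mapping.Mapping for this job.
workspace_tensor = MnnvlMoe.get_moe_workspaces(mapping, cfg)
prepare_tensor = MnnvlMoe.get_moe_prepare_workspace(mapping, cfg)

# Omitting cfg keeps the original behavior: MnnvlMemory falls back to the MPI path.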
