Skip to content

Commit 76f64e0

Browse files
authored
Mnnvl memory with custom communicator (#1245)
<!-- .github/pull_request_template.md --> ## 📌 Description Make Mnnvl memory initialization independent of mpi4py communicator. User/Framework can define any suitable wrapper of communicator(mostly torch.distributed) to fit in. This PR also removes the dependency of mpi4py when the communicator is not MPI(default). ## 🔍 Related Issues Depends by #1550 ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests pytest tests/test_mnnvl_custom_comm.py ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
1 parent 984e95d commit 76f64e0

File tree

3 files changed

+268
-7
lines changed

3 files changed

+268
-7
lines changed

flashinfer/comm/mnnvl.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
import ctypes
1717
import logging
1818
import os
19+
from abc import ABC, abstractmethod
20+
from dataclasses import dataclass
1921
import platform
2022
import sys
21-
from typing import Any, Dict, List, Optional
23+
from typing import Any, Dict, List, Optional, TYPE_CHECKING
2224

2325
import torch
2426
from cuda import cuda
@@ -129,6 +131,22 @@ def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:
129131
return device_ptr
130132

131133

134+
class CommBackend(ABC):
135+
"""Abstract communication backend interface"""
136+
137+
@abstractmethod
138+
def Get_rank(self) -> int: ...
139+
140+
@abstractmethod
141+
def Get_size(self) -> int: ...
142+
143+
@abstractmethod
144+
def allgather(self, data: int) -> List[int]: ...
145+
146+
@abstractmethod
147+
def Split(self, color: int, key: int) -> "CommBackend": ...
148+
149+
132150
if IS_BUILDING_DOCS:
133151
# Mock classes for building docs
134152

@@ -208,18 +226,66 @@ def supports_mnnvl() -> bool:
208226

209227
else:
210228
import pynvml
211-
from mpi4py import MPI
229+
230+
if TYPE_CHECKING:
231+
from mpi4py import MPI # noqa: F401
232+
233+
def lazy_import_mpi():
234+
"""Lazy import for mpi4py"""
235+
try:
236+
from mpi4py import MPI
237+
238+
return MPI
239+
except ImportError as err:
240+
raise ImportError("mpi4py is not installed") from err # type: ignore[no-redef]
212241

213242
class MpiComm: # type: ignore[no-redef]
214-
_comm: MPI.Intracomm = MPI.COMM_WORLD
243+
_comm: Any = None
244+
_MPI: Any = None
215245

216246
@classmethod
217-
def set_mpi_comm(cls, new_comm: MPI.Intracomm):
247+
def _get_mpi(cls):
248+
if cls._MPI is None:
249+
cls._MPI = lazy_import_mpi()
250+
cls._comm = cls._MPI.COMM_WORLD
251+
return cls._MPI
252+
253+
@classmethod
254+
def set_mpi_comm(cls, new_comm: Any):
255+
cls._get_mpi()
256+
# Optional: add type checking here
218257
cls._comm = new_comm
219258

220259
def __getattr__(self, name):
260+
if self._comm is None:
261+
self._get_mpi()
221262
return getattr(self._comm, name)
222263

264+
class MPIBackend(CommBackend):
265+
def __init__(self):
266+
self._mpicomm = MpiComm()
267+
268+
def Get_rank(self) -> int:
269+
return self._mpicomm.Get_rank()
270+
271+
def Get_size(self) -> int:
272+
return self._mpicomm.Get_size()
273+
274+
def allgather(self, data: int) -> List[int]:
275+
return self._mpicomm.allgather(data)
276+
277+
def Split(self, color: int, key: int) -> CommBackend:
278+
self._mpicomm = self._mpicomm.Split(color, key)
279+
return MPIBackend() # Returns new adapter
280+
281+
@dataclass
282+
class MnnvlConfig:
283+
"""Configuration for MNNVL memory management"""
284+
285+
comm_backend: Optional[CommBackend] = None
286+
allocation_granularity: int = 0
287+
fabric_page_size: int = 1 << 29 # 512MB
288+
223289
class MnnvlMemory: # type: ignore[no-redef]
224290
initialized: bool = False
225291

@@ -234,13 +300,15 @@ class MnnvlMemory: # type: ignore[no-redef]
234300
fabric_page_size: int = 1 << 29
235301

236302
# MPI communicator
237-
comm = None
303+
comm: Optional[CommBackend] = None
238304

239305
dev_id: int = None
240306

241307
allocated_map: Dict[int, Any] = {}
242308
address_refcnt: Dict[int, Any] = {}
243309

310+
config: Optional[MnnvlConfig] = None
311+
244312
def __init__(self, mapping: Mapping, size: int):
245313
self.mapping = mapping
246314
self.segment_size = size
@@ -275,6 +343,14 @@ def initialize():
275343
pynvml.nvmlInit()
276344
MnnvlMemory.initialized = True
277345

346+
@staticmethod
347+
def set_comm_from_config(mapping: Mapping, config: MnnvlConfig = None):
348+
MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend()) # type: ignore[attr-defined]
349+
comm = config.comm_backend.Split(
350+
mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
351+
)
352+
MnnvlMemory.comm = comm # type: ignore[assignment]
353+
278354
@staticmethod
279355
def get_comm(mapping: Mapping):
280356
if MnnvlMemory.comm is not None:

flashinfer/comm/trtllm_alltoall.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ..jit import gen_jit_spec
2727
from ..utils import register_custom_op
2828
from .mapping import Mapping
29-
from .mnnvl import MnnvlMemory
29+
from .mnnvl import MnnvlMemory, MnnvlConfig
3030

3131

3232
def gen_comm_alltoall_module() -> JitSpec:
@@ -296,13 +296,15 @@ class MnnvlMoe:
296296
moe_mapping: Mapping = None
297297

298298
@staticmethod
299-
def get_moe_workspaces(mapping: Mapping):
299+
def get_moe_workspaces(mapping: Mapping, config: Optional[MnnvlConfig] = None):
300300
if MnnvlMoe.moe_workspace is not None:
301301
assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
302302
return MnnvlMoe.moe_workspace_tensor
303303

304304
MnnvlMoe.moe_mapping = mapping
305305
workspace_size_per_rank = get_moe_commworkspace_size_per_rank(mapping.tp_size)
306+
if config:
307+
MnnvlMemory.set_comm_from_config(mapping, config) # type: ignore[attr-defined]
306308
MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
307309
MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(
308310
torch.uint64

tests/test_mnnvl_custom_comm.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
import multiprocessing as mp
2+
import socket
3+
from typing import Any
4+
5+
import pytest
6+
import torch
7+
import torch.distributed as dist
8+
9+
import pynvml
10+
11+
from flashinfer.comm.mapping import Mapping
12+
from flashinfer.comm.mnnvl import MnnvlConfig, MnnvlMemory
13+
from flashinfer.comm.mnnvl import CommBackend as CommBackend
14+
15+
16+
pynvml.nvmlInit()
17+
18+
19+
class CustomCommunicator(CommBackend):
20+
def __init__(self, group):
21+
self._group = group
22+
23+
def Get_rank(self) -> int:
24+
return dist.get_rank(self._group)
25+
26+
def Get_size(self) -> int:
27+
return dist.get_world_size(self._group)
28+
29+
def allgather(self, data: int | bytes):
30+
device = f"cuda:{torch.cuda.current_device()}"
31+
if isinstance(data, int):
32+
local_tensor = torch.tensor([data], device=device, dtype=torch.int32)
33+
world_size = self.Get_size()
34+
gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]
35+
36+
dist.all_gather(gathered, local_tensor, group=self._group)
37+
return [int(x.item()) for x in gathered]
38+
39+
elif isinstance(data, bytes):
40+
local_tensor = torch.ByteTensor(list(data)).unsqueeze(0).to(device)
41+
world_size = self.Get_size()
42+
gathered = [data] * self.Get_size()
43+
dist.all_gather_object(gathered, data, group=self._group)
44+
return gathered
45+
else:
46+
raise TypeError(f"Unsupported type for allgather: {type(data)}")
47+
48+
def Split(self, color: int, key: int) -> "CustomCommunicator":
49+
return self
50+
51+
52+
def get_open_port() -> int:
53+
try:
54+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
55+
s.bind(("127.0.0.1", 0))
56+
return s.getsockname()[1]
57+
except OSError:
58+
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
59+
s.bind(("::1", 0))
60+
return s.getsockname()[1]
61+
62+
63+
def multi_process_parallel(
64+
world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = ()
65+
) -> None:
66+
mp.set_start_method("spawn", force=True)
67+
68+
procs = []
69+
distributed_init_port = get_open_port()
70+
for i in range(world_size):
71+
proc_args = (world_size, i, dtype, distributed_init_port) + target_args
72+
proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
73+
proc.start()
74+
procs.append(proc)
75+
76+
for i in range(world_size):
77+
procs[i].join()
78+
assert procs[i].exitcode == 0, (
79+
f"Process {i} failed with exit code {procs[i].exitcode}"
80+
)
81+
82+
83+
def align_memory(size: int):
84+
align_size = 2 * 1024 * 1024
85+
return (size + align_size - 1) // align_size * align_size
86+
87+
88+
def _init_mnnvl_memory(world_size, rank, dtype, distributed_init_port):
89+
device = torch.device(f"cuda:{rank}")
90+
torch.cuda.set_device(device)
91+
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
92+
dist.init_process_group(
93+
backend="nccl",
94+
init_method=distributed_init_method,
95+
rank=rank,
96+
world_size=world_size,
97+
)
98+
group = dist.group.WORLD
99+
100+
torch.cuda.set_device(rank)
101+
MnnvlMemory.initialize()
102+
mapping = Mapping(world_size, rank, world_size, tp_size=world_size)
103+
104+
allocate0_size = 4 * 1024 * 1024 - 3 * 1024
105+
mnnvl_config = MnnvlConfig(
106+
comm_backend=CustomCommunicator(group),
107+
fabric_page_size=1 << 29, # 512MB
108+
allocation_granularity=0, # Auto-detect
109+
)
110+
MnnvlMemory.set_comm_from_config(mapping, mnnvl_config)
111+
mnnvl_memory0 = MnnvlMemory(mapping, allocate0_size)
112+
allocate0_size_aligned = align_memory(allocate0_size)
113+
114+
assert MnnvlMemory.current_mem_offset == allocate0_size_aligned
115+
tensor0 = mnnvl_memory0.as_torch_strided_tensor(torch.int32)
116+
numel_per_rank = allocate0_size // 4
117+
tensor0[(rank + 1) % world_size] = torch.arange(
118+
start=rank, end=rank + numel_per_rank, device="cuda"
119+
)
120+
dist.barrier(group=group)
121+
for r in range(world_size):
122+
torch.equal(
123+
tensor0[(r + 1) % world_size],
124+
torch.arange(start=r, end=r + numel_per_rank, device="cuda"),
125+
)
126+
127+
allocate1_size = 30 * 1024 * 1024 - 2 * 1024
128+
mnnvl_memory1 = MnnvlMemory(mapping, allocate1_size)
129+
allocate1_size_aligned = align_memory(allocate1_size)
130+
assert (
131+
MnnvlMemory.current_mem_offset
132+
== allocate0_size_aligned + allocate1_size_aligned
133+
)
134+
tensor1 = mnnvl_memory1.as_torch_strided_tensor(torch.float32)
135+
numel_per_rank = allocate1_size // 4
136+
tensor1[(rank + 5) % world_size] = torch.arange(
137+
start=rank,
138+
end=rank + numel_per_rank,
139+
dtype=torch.float32,
140+
device="cuda",
141+
)
142+
dist.barrier(group=group)
143+
for r in range(world_size):
144+
torch.equal(
145+
tensor1[(r + 5) % world_size],
146+
torch.arange(
147+
start=r, end=r + numel_per_rank, dtype=torch.float32, device="cuda"
148+
),
149+
)
150+
dist.barrier(group=group)
151+
del tensor0, mnnvl_memory0
152+
dist.barrier(group=group)
153+
154+
large_allocation2_size = 768 * 1024 * 1024
155+
large_mnnvl_memory2 = MnnvlMemory(mapping, large_allocation2_size)
156+
allocate2_size_aligned = align_memory(large_allocation2_size)
157+
assert MnnvlMemory.current_mem_offset == allocate2_size_aligned
158+
assert large_mnnvl_memory2.rank_stride == (1 << 30)
159+
160+
del tensor1
161+
162+
163+
@pytest.mark.skipif(
164+
not MnnvlMemory.supports_mnnvl(),
165+
reason="Mnnvl memory is not supported on this platform",
166+
)
167+
@pytest.mark.parametrize("world_size", [2, 4])
168+
def test_mnnvl_custom_communicator(world_size):
169+
dtype = torch.float16
170+
available_gpus = torch.cuda.device_count()
171+
if world_size > available_gpus:
172+
raise ValueError(
173+
f"world_size {world_size} is greater than available_gpus {available_gpus}"
174+
)
175+
print(f"Running test for world_size={world_size}")
176+
177+
multi_process_parallel(
178+
world_size,
179+
dtype,
180+
_init_mnnvl_memory,
181+
target_args=(),
182+
)
183+
print(f"custom mnnvl communicator world_size = {world_size}: OK")

0 commit comments

Comments
 (0)