|
16 | 16 | import ctypes
|
17 | 17 | import logging
|
18 | 18 | import os
|
| 19 | +from abc import ABC, abstractmethod |
| 20 | +from dataclasses import dataclass |
19 | 21 | import platform
|
20 | 22 | import sys
|
21 | 23 | from typing import Any, Dict, List, Optional
|
@@ -220,6 +222,51 @@ def set_mpi_comm(cls, new_comm: MPI.Intracomm):
|
220 | 222 | def __getattr__(self, name):
|
221 | 223 | return getattr(self._comm, name)
|
222 | 224 |
|
| 225 | + class CommBackend(ABC): |
| 226 | + """Abstract communication backend interface""" |
| 227 | + @abstractmethod |
| 228 | + def Get_rank(self) -> int: ... |
| 229 | + |
| 230 | + @abstractmethod |
| 231 | + def Get_size(self) -> int: ... |
| 232 | + |
| 233 | + @abstractmethod |
| 234 | + def allgather(self, data: int) -> List[int]: ... |
| 235 | + |
| 236 | + @abstractmethod |
| 237 | + def allgather_bytes(self, data): ... |
| 238 | + |
| 239 | + @abstractmethod |
| 240 | + def Split(self, color: int, key: int) -> 'CommBackend': ... |
| 241 | + class LegacyMPIBackend(CommBackend): |
| 242 | + """Adapter for the original MpiComm singleton pattern""" |
| 243 | + def __init__(self): |
| 244 | + self._mpicomm = MpiComm() |
| 245 | + |
| 246 | + def Get_rank(self) -> int: |
| 247 | + return self._mpicomm.Get_rank() |
| 248 | + |
| 249 | + def Get_size(self) -> int: |
| 250 | + return self._mpicomm.Get_size() |
| 251 | + |
| 252 | + def allgather(self, data: int) -> List[int]: |
| 253 | + return self._mpicomm.allgather(data) |
| 254 | + |
| 255 | + def allgather_bytes(self, data): |
| 256 | + return self._mpicomm.allgather(data) |
| 257 | + |
| 258 | + def Split(self, color: int, key: int) -> CommBackend: |
| 259 | + # Original split logic |
| 260 | + new_comm = self._mpicomm.Split(color, key) |
| 261 | + return LegacyMPIBackend() # Returns new adapter |
| 262 | + @dataclass |
| 263 | + class MnnvlConfig: |
| 264 | + """Configuration for MNNVL memory management""" |
| 265 | + comm_backend: Optional[CommBackend] = None |
| 266 | + allocation_granularity: int = 0 |
| 267 | + fabric_page_size: int = 1 << 29 # 512MB |
| 268 | + |
| 269 | + |
223 | 270 | class MnnvlMemory: # type: ignore[no-redef]
|
224 | 271 | initialized: bool = False
|
225 | 272 |
|
@@ -275,6 +322,17 @@ def initialize():
|
275 | 322 | pynvml.nvmlInit()
|
276 | 323 | MnnvlMemory.initialized = True
|
277 | 324 |
|
| 325 | + @staticmethod |
| 326 | + def set_comm(mapping: Mapping, config: MnnvlConfig = None): |
| 327 | + # print("set_comm"*10) |
| 328 | + # print(f"config:{config}, tp_rank:{mapping.tp_rank}") |
| 329 | + MnnvlMemory._config = config or MnnvlConfig(comm_backend=LegacyMPIBackend()) |
| 330 | + comm0 = config.comm_backend |
| 331 | + comm = comm0.Split( |
| 332 | + mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank |
| 333 | + ) |
| 334 | + MnnvlMemory.comm = comm |
| 335 | + |
278 | 336 | @staticmethod
|
279 | 337 | def get_comm(mapping: Mapping):
|
280 | 338 | if MnnvlMemory.comm is not None:
|
|
0 commit comments