
Commit 603e527

ZhangLirong-amd and root authored
Support mori in aiter (ROCm#1453)
* support mori in aiter
* format with copilot
* add prepare comm
* add gpu_per_node
* format some code

Co-authored-by: root <[email protected]>
1 parent 9691959 commit 603e527

File tree

5 files changed: +504 -10 lines changed
Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
import torch
import importlib.util
from .base_device_communicator import All2AllManagerBase, Cache
from functools import cache
from aiter import logger


@cache
def _has_module(module_name: str) -> bool:
    """Return True if *module_name* can be found in the current environment.
    The result is cached so that subsequent queries for the same module incur
    no additional overhead.
    """
    return importlib.util.find_spec(module_name) is not None


def has_mori() -> bool:
    """Whether the optional `mori` package is available."""
    return _has_module("mori")


class MoriAll2AllManager(All2AllManagerBase):
    def __init__(self, cpu_group):
        assert has_mori(), (
            "MoRI kernels not found. Please follow https://github.com/ROCm/mori/blob/main/README.md"
            " to install MoRI kernels."
        )  # noqa
        import mori

        super().__init__(cpu_group)
        self.handle_cache = Cache()

        torch._C._distributed_c10d._register_process_group("mori", cpu_group)
        mori.shmem.shmem_torch_process_group_init("mori")

    def _make_all2all_kwargs(
        self,
        rank: int,
        num_ep_ranks: int,
        input_dtype: torch.dtype,
        quant_dtype: torch.dtype,
        token_hidden_size: int,
        scale_dim: int,
        scale_type_size: int,
        max_num_tokens_per_dp_rank: int,
        num_local_experts: int,
        num_experts_per_token: int,
        gpu_per_node: int,
    ):
        import mori  # type: ignore[import-not-found]

        if not self.internode:
            # single node
            kernel_type = mori.ops.EpDispatchCombineKernelType.IntraNode
            warp_num_per_block = 16
            block_num = 80
            rdma_block_num = 0
        else:
            # multi node
            kernel_type = mori.ops.EpDispatchCombineKernelType.InterNodeV1
            warp_num_per_block = 16
            block_num = 32
            rdma_block_num = 16

        return dict(
            rank=rank,
            world_size=num_ep_ranks,
            data_type=quant_dtype,
            hidden_dim=token_hidden_size,
            scale_dim=scale_dim,
            scale_type_size=scale_type_size,
            max_token_type_size=input_dtype.itemsize,
            max_num_inp_token_per_rank=max_num_tokens_per_dp_rank,
            num_experts_per_rank=num_local_experts,
            num_experts_per_token=num_experts_per_token,
            warp_num_per_block=warp_num_per_block,
            block_num=block_num,
            kernel_type=kernel_type,
            rdma_block_num=rdma_block_num,
            gpu_per_node=gpu_per_node,
        )

    def _make_handle(self, **kwargs):
        import mori  # type: ignore[import-not-found]

        mori_config = mori.ops.EpDispatchCombineConfig(**kwargs)
        handle = mori.ops.EpDispatchCombineOp(mori_config)
        return handle

    def get_handle(self, kwargs):
        import mori  # type: ignore[import-not-found]

        mori_kwargs = self._make_all2all_kwargs(**kwargs)
        logger.debug("MoRI all2all args %s", mori_kwargs)
        handle: mori.ops.EpDispatchCombineOp = self.handle_cache.get_or_create(
            mori_kwargs, self._make_handle
        )
        return handle
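For orientation, a hedged usage sketch of the new manager: a caller passes its expert-parallel shape parameters as a dict to get_handle, which expands them through _make_all2all_kwargs and caches one EpDispatchCombineOp per unique configuration. The helper name build_mori_handle, the parameter values, and the dtypes are illustrative assumptions, not part of this commit; the import path is inferred from the relative import used in communicator_cuda.py below.

import torch

# Import path assumed from this commit's file layout
# (communicator_cuda.py does `from .all2all import MoriAll2AllManager`).
from aiter.dist.device_communicators.all2all import MoriAll2AllManager


def build_mori_handle(cpu_group):
    """Hypothetical helper: cpu_group is the CPU process group of the EP ranks."""
    manager = MoriAll2AllManager(cpu_group)  # registers the "mori" shmem process group
    return manager.get_handle(
        dict(
            rank=torch.distributed.get_rank(cpu_group),
            num_ep_ranks=torch.distributed.get_world_size(cpu_group),
            input_dtype=torch.bfloat16,         # activation dtype (illustrative)
            quant_dtype=torch.float8_e4m3fnuz,  # assumed quantized wire dtype
            token_hidden_size=7168,             # model hidden size (illustrative)
            scale_dim=128,                      # per-128-element scales (illustrative)
            scale_type_size=4,                  # fp32 scales
            max_num_tokens_per_dp_rank=4096,
            num_local_experts=32,
            num_experts_per_token=8,
            gpu_per_node=8,
        )
    )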

aiter/dist/device_communicators/base_device_communicator.py

Lines changed: 3 additions & 2 deletions
@@ -123,8 +123,9 @@ def __init__(
         # all2all_backend = config.parallel_config.all2all_backend

         self.is_ep_communicator = "ep" in unique_name
-        self.use_all2all = self.is_ep_communicator and use_ep
-        self.all2all_backend = all2all_backend
+        self.use_all2all = self.is_ep_communicator
+        # self.all2all_backend = all2all_backend
+        self.all2all_backend = "mori"
         self.all2all_manager: All2AllManagerBase | None = None

     def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:

aiter/dist/device_communicators/communicator_cuda.py

Lines changed: 29 additions & 8 deletions
@@ -19,6 +19,9 @@ def __init__(
         device_group: ProcessGroup | None = None,
         unique_name: str = "",
     ):
+        self._all2all_manager = None
+        self._all2all_manager_created = False
+
         super().__init__(cpu_group, device, device_group, unique_name)
         if "tp" not in unique_name:
             # custom allreduce or torch symm mem can be used only by tp
@@ -84,39 +87,57 @@ def __init__(
             # # currently be an MI300 series.
             self.qr_comm = QuickAllReduce(group=self.cpu_group, device=self.device)

-        if self.use_all2all:
+    @property
+    def all2all_manager(self):
+        # Lazily create all2all_manager to avoid tp/dp/ep group haven't been created yet
+        if not self._all2all_manager_created and self.use_all2all:
+            self._all2all_manager_created = True
+
             if self.all2all_backend == "naive":
                 from .all2all import NaiveAll2AllManager

-                self.all2all_manager = NaiveAll2AllManager(self.cpu_group)
+                self._all2all_manager = NaiveAll2AllManager(self.cpu_group)
             elif self.all2all_backend == "allgather_reducescatter":
                 from .all2all import AgRsAll2AllManager

-                self.all2all_manager = AgRsAll2AllManager(self.cpu_group)
+                self._all2all_manager = AgRsAll2AllManager(self.cpu_group)
             elif self.all2all_backend == "pplx":
                 from .all2all import PPLXAll2AllManager

-                self.all2all_manager = PPLXAll2AllManager(self.cpu_group)
+                self._all2all_manager = PPLXAll2AllManager(self.cpu_group)
             elif self.all2all_backend == "deepep_high_throughput":
                 from .all2all import DeepEPHTAll2AllManager

-                self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
+                self._all2all_manager = DeepEPHTAll2AllManager(self.cpu_group)
             elif self.all2all_backend == "deepep_low_latency":
                 from .all2all import DeepEPLLAll2AllManager

-                self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
+                self._all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
+            elif self.all2all_backend == "mori":
+                from .all2all import MoriAll2AllManager
+
+                self._all2all_manager = MoriAll2AllManager(self.cpu_group)
             elif self.all2all_backend == "flashinfer_all2allv":
                 from .all2all import FlashInferAllToAllManager

-                self.all2all_manager = FlashInferAllToAllManager(self.cpu_group)
+                self._all2all_manager = FlashInferAllToAllManager(self.cpu_group)
             else:
                 raise ValueError(f"Unknown all2all backend: {self.all2all_backend}")

             if is_global_first_rank():
                 logger.info(
                     "Using %s all2all manager.",
-                    self.all2all_manager.__class__.__name__,
+                    self._all2all_manager.__class__.__name__,
                 )
+        # if self._all2all_manager is None:
+        #     raise ValueError(f"all2all_manager is None for {self.unique_name}")
+        return self._all2all_manager
+
+    @all2all_manager.setter
+    def all2all_manager(self, value):
+        self._all2all_manager = value
+        if value is not None:
+            self._all2all_manager_created = True

     def all_reduce(
         self, input_, use_new: bool = False, ca_fp8_quant: bool = False
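The property/setter pair above turns all2all_manager into a create-on-first-access attribute, so the manager is only built once the tp/dp/ep groups exist rather than eagerly in __init__. Below is a minimal, self-contained sketch of the same pattern; the class and attribute names are illustrative, not taken from the diff.

class LazyManagerHolder:
    """Minimal sketch of the lazy-initialization pattern used above."""

    def __init__(self, enabled: bool):
        self._manager = None
        self._created = False
        self.enabled = enabled

    @property
    def manager(self):
        # Build the manager on first access instead of in __init__, so any
        # process groups it depends on can be created after this object exists.
        if not self._created and self.enabled:
            self._created = True
            self._manager = object()  # stand-in for the real manager
        return self._manager

    @manager.setter
    def manager(self, value):
        self._manager = value
        if value is not None:
            self._created = True


holder = LazyManagerHolder(enabled=True)
assert holder.manager is holder.manager  # created once, then reused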

aiter/dist/parallel_state.py

Lines changed: 22 additions & 0 deletions
@@ -814,6 +814,10 @@ def recv(
         torch.distributed.recv(tensor, self.ranks[src], self.device_group)
         return tensor

+    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
+        if self.device_communicator is not None:
+            self.device_communicator.prepare_communication_buffer_for_model(model)
+
     def destroy(self):
         if hasattr(self, "device_group"):
             torch.distributed.destroy_process_group(self.device_group)
@@ -1359,3 +1363,21 @@ def _node_count(pg: ProcessGroup) -> int:
                 node_assignment[other_rank] = next_node_id

     return next_node_id
+
+
+def prepare_communication_buffer_for_model(model: torch.nn.Module):
+    """Prepare the communication buffer for the model.
+    Traditional communication libraries like NCCL are almost
+    model agnostic. However, emerging new communication libraries like
+    MoE all2all (DeepEP) usually allocate the communication buffer
+    based on the model shape for optimal performance.
+    """
+    logger.debug(f"prepare_communication_buffer_for_model: {_TP} {_PP} {_DP} {_EP}")
+    if _TP is not None:
+        _TP.prepare_communication_buffer_for_model(model)
+    if _PP is not None:
+        _PP.prepare_communication_buffer_for_model(model)
+    if _DP is not None:
+        _DP.prepare_communication_buffer_for_model(model)
+    if _EP is not None:
+        _EP.prepare_communication_buffer_for_model(model)
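A hedged sketch of how the new module-level hook would typically be used: once the parallel groups are initialized and the model has been constructed, the hook lets shape-aware backends such as mori or DeepEP size their communication buffers from the model. The call site, the helper name load_model_and_prepare_comm, and the import path are assumptions for illustration, not part of this commit.

import torch

# Import path assumed from the file location (aiter/dist/parallel_state.py).
from aiter.dist import parallel_state


def load_model_and_prepare_comm(build_model) -> torch.nn.Module:
    """Hypothetical call site: build the model, then size the all2all buffers."""
    model = build_model()
    # Fans out to _TP/_PP/_DP/_EP; each group coordinator forwards to its device
    # communicator, which in turn reaches the active all2all manager.
    parallel_state.prepare_communication_buffer_for_model(model)
    return model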
