
Commit 7204685

sunfish2010 authored and facebook-github-bot committed

Routed mp=1 fix (#4841)

Summary:
Pull Request resolved: #4841
X-link: facebookresearch/FBGEMM#1866

ep_first_order (the input offset) is not the same for the padded and non-padded cases. Fix it by computing the correct offset, and add the corresponding test cases to the unit test to capture the various scenarios.

Reviewed By: jasonjk-park

Differential Revision: D81801293

fbshipit-source-id: 5c6067a4a058752667f4f1c5a531c95ff5c53eb6
1 parent 984c98b commit 7204685

File tree

2 files changed: +168 −132 lines changed


fbgemm_gpu/experimental/gen_ai/gen_ai/moe/shuffling.py

Lines changed: 59 additions & 59 deletions
@@ -19,15 +19,15 @@ def combine_shuffling(
     token_counts: torch.Tensor,
     expert_start: Optional[int] = None,
     expert_end: Optional[int] = None,
-    is_balanced: bool = False,
+    is_padded: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     # pyre-ignore
     return _combine_or_split_shuffling(
         tokens=tokens,
         token_counts=token_counts,
         expert_start=expert_start,
         expert_end=expert_end,
-        is_balanced=is_balanced,
+        is_padded=is_padded,
         is_combine=True,
     )
 
@@ -37,7 +37,7 @@ def split_shuffling(
     token_counts: torch.Tensor,
     expert_start: Optional[int] = None,
     expert_end: Optional[int] = None,
-    is_balanced: bool = False,
+    is_padded: bool = False,
     init_with_zeros: bool = False,
 ) -> torch.Tensor:
     # pyre-ignore
@@ -46,7 +46,7 @@ def split_shuffling(
         token_counts=token_counts,
         expert_start=expert_start,
         expert_end=expert_end,
-        is_balanced=is_balanced,
+        is_padded=is_padded,
         is_combine=False,
         init_with_zeros=init_with_zeros,
     )
@@ -57,7 +57,7 @@ def _combine_or_split_shuffling(
     token_counts: torch.Tensor,
     expert_start: Optional[int],
     expert_end: Optional[int],
-    is_balanced: bool,
+    is_padded: bool,
     is_combine: bool,
     init_with_zeros: bool = False,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
@@ -67,6 +67,10 @@ def _combine_or_split_shuffling(
 
     T, D = tokens.shape
     EP, E = token_counts.shape
+    B_T = -1
+    if is_padded:
+        assert T % EP == 0
+        B_T = T // EP
 
     if expert_start is None:
         expert_start = 0
@@ -95,8 +99,6 @@ def _combine_or_split_shuffling(
         )
     else:
         output_token_counts = None
-    T_BUCKET_CAP = 16384
-    T_BUCKET = min(triton.next_power_of_2(T), T_BUCKET_CAP)
 
     BLOCK_E = max(triton.next_power_of_2(E), 8)
     BLOCK_EG = max(triton.next_power_of_2(EG), 8)
@@ -108,9 +110,9 @@ def _combine_or_split_shuffling(
         output_tokens,
         output_token_counts,
         is_combine,
-        is_balanced,
-        T_BUCKET,
         expert_start,
+        is_padded,
+        B_T,
         EG,
         EP,
         E,
@@ -133,7 +135,7 @@ def _combine_or_split_shuffling(
 
 torch.library.define(
     "fbgemm::combine_shuffling",
-    "(Tensor tokens, Tensor token_counts, int? expert_start = None, int? expert_end = None, bool? is_balanced = False) -> (Tensor, Tensor)",
+    "(Tensor tokens, Tensor token_counts, int? expert_start = None, int? expert_end = None, bool? is_padded = False) -> (Tensor, Tensor)",
 )
 
 
@@ -143,7 +145,7 @@ def combine_shuffling_meta(
     token_counts,
     expert_start,
     expert_end,
-    is_balanced,
+    is_padded,
 ):
     _, E = token_counts.shape
     if expert_start is None:
@@ -165,22 +167,22 @@ def combine_shuffling_cuda(
     token_counts,
     expert_start=None,
     expert_end=None,
-    is_balanced=False,
+    is_padded=False,
 ):
     return combine_shuffling(
         tokens,
         token_counts,
         expert_start,
         expert_end,
-        is_balanced,
+        is_padded,
     )
 
 
 _SPLIT_SHUFFLING_OP_NAME = "fbgemm::split_shuffling"
 
 torch.library.define(
     "fbgemm::split_shuffling",
-    "(Tensor tokens, Tensor token_counts, int? expert_start = None, int? expert_end = None, bool? is_balanced = False, bool? init_with_zeros = False) -> Tensor",
+    "(Tensor tokens, Tensor token_counts, int? expert_start = None, int? expert_end = None, bool? is_padded = False, bool? init_with_zeros = False) -> Tensor",
 )
 
 
@@ -190,7 +192,7 @@ def split_shuffling_meta(
     token_counts,
     expert_start,
     expert_end,
-    is_balanced,
+    is_padded,
 ):
     output_tokens = torch.empty_like(tokens)
     return output_tokens
@@ -202,14 +204,14 @@ def split_shuffling_cuda(
     token_counts,
     expert_start=None,
     expert_end=None,
-    is_balanced=False,
+    is_padded=False,
 ):
     return split_shuffling(
         tokens,
         token_counts,
         expert_start,
         expert_end,
-        is_balanced,
+        is_padded,
     )
 
 
@@ -252,8 +254,6 @@ def split_shuffling_cuda(
     configs=_AMD_CONFIGS if torch.version.hip else _NV_CONFIGS,
     key=[
         "COMBINE",
-        "BALANCED",
-        "T_BUCKET",
         "EG",
         "EP",
         "E",
@@ -267,9 +267,9 @@ def _fbgemm_combine_or_split_shuffling(
     output_tokens_ptr,
     output_token_counts_ptr,
     COMBINE: tl.constexpr,
-    BALANCED,
-    T_BUCKET,
     EG_START,
+    PADDED,
+    B_T: tl.constexpr,
     EG: tl.constexpr,
     EP: tl.constexpr,
     E: tl.constexpr,
@@ -294,8 +294,11 @@ def _fbgemm_combine_or_split_shuffling(
     rank = tidx // (EG * SPLIT_D)
     local_expert = (tidx % (EG * SPLIT_D)) // SPLIT_D
     didx = tidx % SPLIT_D
+    # All experts in communication group
     offs_e = tl.arange(0, BLOCK_E)
+    # Local experts
     offs_eg = tl.arange(0, BLOCK_EG)
+    # Ranks
     offs_ep = tl.arange(0, BLOCK_EP)
 
     global_expert = local_expert + EG_START
@@ -307,59 +310,56 @@ def _fbgemm_combine_or_split_shuffling(
         other=0,
     )  # [EP, E]
 
-    input_token_counts_eg = tl.load(
-        input_token_counts_ptr + offs_ep[:, None] * E + EG_START + offs_eg[None, :],
-        eviction_policy="evict_last",
-        mask=((offs_ep[:, None] < EP) & (offs_eg[None, :] < EG)),
-        other=0,
-    )  # [EP, EG]
+    if E == EG:
+        input_token_counts_eg = input_token_counts
+    else:
+        input_token_counts_eg = tl.load(
+            input_token_counts_ptr + offs_ep[:, None] * E + EG_START + offs_eg[None, :],
+            eviction_policy="evict_last",
+            mask=((offs_ep[:, None] < EP) & (offs_eg[None, :] < EG)),
+            other=0,
+        )  # [EP, EG]
 
     if COMBINE:
         LAST_TILE: tl.constexpr = EP * EG * SPLIT_D
 
         if tidx == LAST_TILE:
-            if EG == E:
-                output_token_counts = tl.sum(input_token_counts, axis=0)
-                tl.store(
-                    output_token_counts_ptr + offs_e,
-                    output_token_counts,
-                    mask=(offs_e < E),
-                )
-                output_token_counts = tl.sum(output_token_counts)
-                tl.store(output_token_counts_ptr + E, output_token_counts)
-            else:
-                output_token_counts_eg = tl.sum(input_token_counts_eg, axis=0)
-                tl.store(
-                    output_token_counts_ptr + offs_eg,
-                    output_token_counts_eg,
-                    mask=(offs_eg < EG),
-                )
-                output_token_counts_eg = tl.sum(output_token_counts_eg)
-                tl.store(output_token_counts_ptr + EG, output_token_counts_eg)
+            output_token_counts_eg = tl.sum(input_token_counts_eg, axis=0)
+            tl.store(
+                output_token_counts_ptr + offs_eg,
+                output_token_counts_eg,
+                mask=(offs_eg < EG),
+            )
+            output_token_counts_eg = tl.sum(output_token_counts_eg)
+            tl.store(output_token_counts_ptr + EG, output_token_counts_eg)
         return
 
     cond0 = offs_ep[:, None] < rank
     cond1 = offs_ep[:, None] == rank
 
     cond2 = offs_e[None, :] < global_expert
-    cond3 = offs_e[None, :] == global_expert
-
-    # r < rank || (r == rank && e < expert)
-    ep_first_order = tl.sum(tl.where(cond0 or (cond1 and cond2), input_token_counts, 0))
-    if EG == E:
-        # e < expert || (e == expert && r < rank)
-        expert_first_order = tl.sum(
-            tl.where(cond2 or (cond3 and cond0), input_token_counts, 0)
+
+    if PADDED:
+        tl.device_assert(B_T >= 0)
+        # Only need information from previous experts in the same rank.
+        ep_first_order = (
+            tl.sum(tl.where(cond1 and cond2, input_token_counts, 0)) + B_T * rank
         )
     else:
-        # e < expert || (e == expert && r < rank)
-        cond4 = offs_eg[None, :] < local_expert
-        cond5 = offs_eg[None, :] == local_expert
-
-        expert_first_order = tl.sum(
-            tl.where(cond4 or (cond5 and cond0), input_token_counts_eg, 0)
+        # r < rank || (r == rank && e < expert)
+        ep_first_order = tl.sum(
+            tl.where(cond0 or (cond1 and cond2), input_token_counts, 0)
        )
 
+    cond4 = offs_eg[None, :] < local_expert
+    cond5 = offs_eg[None, :] == local_expert
+
+    # Expert first only need information from local experts across ranks.
+    # e < expert || (e == expert && r < rank)
+    expert_first_order = tl.sum(
+        tl.where(cond4 or (cond5 and cond0), input_token_counts_eg, 0)
+    )
+
     if COMBINE:
         input_offset = ep_first_order
         output_offset = expert_first_order

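For completeness, a minimal usage sketch of the updated Python entry point. This is hedged: the import path is inferred from the file location fbgemm_gpu/experimental/gen_ai/gen_ai/moe/shuffling.py and may differ in your build, and the shapes, dtypes, and counts are made up for illustration. With is_padded=True, T must be divisible by EP (per the new assert), and each rank's counts may sum to less than B_T = T // EP, the rest of its slot being padding.

```python
import torch
# Import path inferred from the file location; adjust to however the
# package is exposed in your FBGEMM build (requires a GPU build).
from fbgemm_gpu.experimental.gen_ai.moe import combine_shuffling

EP, E, B_T, D = 2, 4, 8, 16  # illustrative sizes
# Padded layout: each rank owns a fixed B_T-row slot, possibly partially filled.
tokens = torch.randn(EP * B_T, D, device="cuda", dtype=torch.bfloat16)
# token_counts[r, e]: rows of rank r's slot routed to expert e (each row sums <= B_T).
token_counts = torch.tensor(
    [[2, 3, 1, 1], [4, 0, 2, 1]], device="cuda", dtype=torch.int32
)

out_tokens, out_token_counts = combine_shuffling(
    tokens, token_counts, is_padded=True
)
```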