@@ -19,15 +19,15 @@ def combine_shuffling(
19
19
token_counts : torch .Tensor ,
20
20
expert_start : Optional [int ] = None ,
21
21
expert_end : Optional [int ] = None ,
22
- is_balanced : bool = False ,
22
+ is_padded : bool = False ,
23
23
) -> Tuple [torch .Tensor , torch .Tensor ]:
24
24
# pyre-ignore
25
25
return _combine_or_split_shuffling (
26
26
tokens = tokens ,
27
27
token_counts = token_counts ,
28
28
expert_start = expert_start ,
29
29
expert_end = expert_end ,
30
- is_balanced = is_balanced ,
30
+ is_padded = is_padded ,
31
31
is_combine = True ,
32
32
)
33
33
@@ -37,7 +37,7 @@ def split_shuffling(
37
37
token_counts : torch .Tensor ,
38
38
expert_start : Optional [int ] = None ,
39
39
expert_end : Optional [int ] = None ,
40
- is_balanced : bool = False ,
40
+ is_padded : bool = False ,
41
41
init_with_zeros : bool = False ,
42
42
) -> torch .Tensor :
43
43
# pyre-ignore
@@ -46,7 +46,7 @@ def split_shuffling(
46
46
token_counts = token_counts ,
47
47
expert_start = expert_start ,
48
48
expert_end = expert_end ,
49
- is_balanced = is_balanced ,
49
+ is_padded = is_padded ,
50
50
is_combine = False ,
51
51
init_with_zeros = init_with_zeros ,
52
52
)
@@ -57,7 +57,7 @@ def _combine_or_split_shuffling(
57
57
token_counts : torch .Tensor ,
58
58
expert_start : Optional [int ],
59
59
expert_end : Optional [int ],
60
- is_balanced : bool ,
60
+ is_padded : bool ,
61
61
is_combine : bool ,
62
62
init_with_zeros : bool = False ,
63
63
) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
@@ -67,6 +67,10 @@ def _combine_or_split_shuffling(
67
67
68
68
T , D = tokens .shape
69
69
EP , E = token_counts .shape
70
+ B_T = - 1
71
+ if is_padded :
72
+ assert T % EP == 0
73
+ B_T = T // EP
70
74
71
75
if expert_start is None :
72
76
expert_start = 0
@@ -95,8 +99,6 @@ def _combine_or_split_shuffling(
95
99
)
96
100
else :
97
101
output_token_counts = None
98
- T_BUCKET_CAP = 16384
99
- T_BUCKET = min (triton .next_power_of_2 (T ), T_BUCKET_CAP )
100
102
101
103
BLOCK_E = max (triton .next_power_of_2 (E ), 8 )
102
104
BLOCK_EG = max (triton .next_power_of_2 (EG ), 8 )
@@ -108,9 +110,9 @@ def _combine_or_split_shuffling(
108
110
output_tokens ,
109
111
output_token_counts ,
110
112
is_combine ,
111
- is_balanced ,
112
- T_BUCKET ,
113
113
expert_start ,
114
+ is_padded ,
115
+ B_T ,
114
116
EG ,
115
117
EP ,
116
118
E ,
@@ -133,7 +135,7 @@ def _combine_or_split_shuffling(
133
135
134
136
torch .library .define (
135
137
"fbgemm::combine_shuffling" ,
136
- "(Tensor tokens, Tensor token_counts, int? expert_start = None, int? expert_end = None, bool? is_balanced = False) -> (Tensor, Tensor)" ,
138
+ "(Tensor tokens, Tensor token_counts, int? expert_start = None, int? expert_end = None, bool? is_padded = False) -> (Tensor, Tensor)" ,
137
139
)
138
140
139
141
@@ -143,7 +145,7 @@ def combine_shuffling_meta(
143
145
token_counts ,
144
146
expert_start ,
145
147
expert_end ,
146
- is_balanced ,
148
+ is_padded ,
147
149
):
148
150
_ , E = token_counts .shape
149
151
if expert_start is None :
@@ -165,22 +167,22 @@ def combine_shuffling_cuda(
165
167
token_counts ,
166
168
expert_start = None ,
167
169
expert_end = None ,
168
- is_balanced = False ,
170
+ is_padded = False ,
169
171
):
170
172
return combine_shuffling (
171
173
tokens ,
172
174
token_counts ,
173
175
expert_start ,
174
176
expert_end ,
175
- is_balanced ,
177
+ is_padded ,
176
178
)
177
179
178
180
179
181
_SPLIT_SHUFFLING_OP_NAME = "fbgemm::split_shuffling"
180
182
181
183
torch .library .define (
182
184
"fbgemm::split_shuffling" ,
183
- "(Tensor tokens, Tensor token_counts, int? expert_start = None, int? expert_end = None, bool? is_balanced = False, bool? init_with_zeros = False) -> Tensor" ,
185
+ "(Tensor tokens, Tensor token_counts, int? expert_start = None, int? expert_end = None, bool? is_padded = False, bool? init_with_zeros = False) -> Tensor" ,
184
186
)
185
187
186
188
@@ -190,7 +192,7 @@ def split_shuffling_meta(
190
192
token_counts ,
191
193
expert_start ,
192
194
expert_end ,
193
- is_balanced ,
195
+ is_padded ,
194
196
):
195
197
output_tokens = torch .empty_like (tokens )
196
198
return output_tokens
@@ -202,14 +204,14 @@ def split_shuffling_cuda(
202
204
token_counts ,
203
205
expert_start = None ,
204
206
expert_end = None ,
205
- is_balanced = False ,
207
+ is_padded = False ,
206
208
):
207
209
return split_shuffling (
208
210
tokens ,
209
211
token_counts ,
210
212
expert_start ,
211
213
expert_end ,
212
- is_balanced ,
214
+ is_padded ,
213
215
)
214
216
215
217
@@ -252,8 +254,6 @@ def split_shuffling_cuda(
252
254
configs = _AMD_CONFIGS if torch .version .hip else _NV_CONFIGS ,
253
255
key = [
254
256
"COMBINE" ,
255
- "BALANCED" ,
256
- "T_BUCKET" ,
257
257
"EG" ,
258
258
"EP" ,
259
259
"E" ,
@@ -267,9 +267,9 @@ def _fbgemm_combine_or_split_shuffling(
267
267
output_tokens_ptr ,
268
268
output_token_counts_ptr ,
269
269
COMBINE : tl .constexpr ,
270
- BALANCED ,
271
- T_BUCKET ,
272
270
EG_START ,
271
+ PADDED ,
272
+ B_T : tl .constexpr ,
273
273
EG : tl .constexpr ,
274
274
EP : tl .constexpr ,
275
275
E : tl .constexpr ,
@@ -294,8 +294,11 @@ def _fbgemm_combine_or_split_shuffling(
294
294
rank = tidx // (EG * SPLIT_D )
295
295
local_expert = (tidx % (EG * SPLIT_D )) // SPLIT_D
296
296
didx = tidx % SPLIT_D
297
+ # All experts in communication group
297
298
offs_e = tl .arange (0 , BLOCK_E )
299
+ # Local experts
298
300
offs_eg = tl .arange (0 , BLOCK_EG )
301
+ # Ranks
299
302
offs_ep = tl .arange (0 , BLOCK_EP )
300
303
301
304
global_expert = local_expert + EG_START
@@ -307,59 +310,56 @@ def _fbgemm_combine_or_split_shuffling(
307
310
other = 0 ,
308
311
) # [EP, E]
309
312
310
- input_token_counts_eg = tl .load (
311
- input_token_counts_ptr + offs_ep [:, None ] * E + EG_START + offs_eg [None , :],
312
- eviction_policy = "evict_last" ,
313
- mask = ((offs_ep [:, None ] < EP ) & (offs_eg [None , :] < EG )),
314
- other = 0 ,
315
- ) # [EP, EG]
313
+ if E == EG :
314
+ input_token_counts_eg = input_token_counts
315
+ else :
316
+ input_token_counts_eg = tl .load (
317
+ input_token_counts_ptr + offs_ep [:, None ] * E + EG_START + offs_eg [None , :],
318
+ eviction_policy = "evict_last" ,
319
+ mask = ((offs_ep [:, None ] < EP ) & (offs_eg [None , :] < EG )),
320
+ other = 0 ,
321
+ ) # [EP, EG]
316
322
317
323
if COMBINE :
318
324
LAST_TILE : tl .constexpr = EP * EG * SPLIT_D
319
325
320
326
if tidx == LAST_TILE :
321
- if EG == E :
322
- output_token_counts = tl .sum (input_token_counts , axis = 0 )
323
- tl .store (
324
- output_token_counts_ptr + offs_e ,
325
- output_token_counts ,
326
- mask = (offs_e < E ),
327
- )
328
- output_token_counts = tl .sum (output_token_counts )
329
- tl .store (output_token_counts_ptr + E , output_token_counts )
330
- else :
331
- output_token_counts_eg = tl .sum (input_token_counts_eg , axis = 0 )
332
- tl .store (
333
- output_token_counts_ptr + offs_eg ,
334
- output_token_counts_eg ,
335
- mask = (offs_eg < EG ),
336
- )
337
- output_token_counts_eg = tl .sum (output_token_counts_eg )
338
- tl .store (output_token_counts_ptr + EG , output_token_counts_eg )
327
+ output_token_counts_eg = tl .sum (input_token_counts_eg , axis = 0 )
328
+ tl .store (
329
+ output_token_counts_ptr + offs_eg ,
330
+ output_token_counts_eg ,
331
+ mask = (offs_eg < EG ),
332
+ )
333
+ output_token_counts_eg = tl .sum (output_token_counts_eg )
334
+ tl .store (output_token_counts_ptr + EG , output_token_counts_eg )
339
335
return
340
336
341
337
cond0 = offs_ep [:, None ] < rank
342
338
cond1 = offs_ep [:, None ] == rank
343
339
344
340
cond2 = offs_e [None , :] < global_expert
345
- cond3 = offs_e [None , :] == global_expert
346
-
347
- # r < rank || (r == rank && e < expert)
348
- ep_first_order = tl .sum (tl .where (cond0 or (cond1 and cond2 ), input_token_counts , 0 ))
349
- if EG == E :
350
- # e < expert || (e == expert && r < rank)
351
- expert_first_order = tl .sum (
352
- tl .where (cond2 or (cond3 and cond0 ), input_token_counts , 0 )
341
+
342
+ if PADDED :
343
+ tl .device_assert (B_T >= 0 )
344
+ # Only need information from previous experts in the same rank.
345
+ ep_first_order = (
346
+ tl .sum (tl .where (cond1 and cond2 , input_token_counts , 0 )) + B_T * rank
353
347
)
354
348
else :
355
- # e < expert || (e == expert && r < rank)
356
- cond4 = offs_eg [None , :] < local_expert
357
- cond5 = offs_eg [None , :] == local_expert
358
-
359
- expert_first_order = tl .sum (
360
- tl .where (cond4 or (cond5 and cond0 ), input_token_counts_eg , 0 )
349
+ # r < rank || (r == rank && e < expert)
350
+ ep_first_order = tl .sum (
351
+ tl .where (cond0 or (cond1 and cond2 ), input_token_counts , 0 )
361
352
)
362
353
354
+ cond4 = offs_eg [None , :] < local_expert
355
+ cond5 = offs_eg [None , :] == local_expert
356
+
357
+ # Expert first only need information from local experts across ranks.
358
+ # e < expert || (e == expert && r < rank)
359
+ expert_first_order = tl .sum (
360
+ tl .where (cond4 or (cond5 and cond0 ), input_token_counts_eg , 0 )
361
+ )
362
+
363
363
if COMBINE :
364
364
input_offset = ep_first_order
365
365
output_offset = expert_first_order
0 commit comments