 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+    BatchedDispatchCombine,
     BatchedExperts)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
@@ -170,7 +171,7 @@ def torch_dispatch(
     assert topk_ids.dim() == 2
     assert topk_ids.shape[0] == a.shape[0]

-    num_tokens = a.shape[0]
+    num_tokens, hidden_dim = a.shape
     topk = topk_ids.shape[1]

     tokens_per_expert = torch.bincount(topk_ids.view(-1),
@@ -181,7 +182,7 @@ def torch_dispatch(
     if max_num_tokens is None:
         max_num_tokens = int(tokens_per_expert.max().item())

-    b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
+    b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
                       dtype=a.dtype,
                       device=a.device)

@@ -198,7 +199,7 @@ def torch_dispatch(


 def torch_combine(b_out, topk_weight, topk_ids):
-    num_tokens, topk = topk_ids.shape
+    num_tokens = topk_ids.shape[0]
     num_experts = b_out.shape[0]
     K = b_out.shape[-1]
     out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
@@ -240,6 +241,22 @@ def torch_batched_moe(a, w1, w2, topk_weight, topk_ids):
     return torch_combine(out, topk_weight, topk_ids)


+def batched_moe(a, w1, w2, topk_weight, topk_ids):
+    num_experts = w1.shape[0]
+
+    fused_experts = FusedMoEModularKernel(
+        BatchedDispatchCombine(a.shape[0], world_size=1, dp_size=1, rank=0),
+        BatchedExperts(a.shape[0])
+    )
+
+    return fused_experts(a,
+                         w1,
+                         w2,
+                         topk_weight,
+                         topk_ids,
+                         num_experts)
+
+
 # TODO: same as torch_moe but with fused_topk factored out.
 def torch_moe2(a, w1, w2, topk_weight, topk_ids):
     M, K = a.shape
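The hunks above reshape torch_dispatch's per-expert buffer to (num_experts, max_num_tokens, hidden_dim) and add a batched_moe helper built on BatchedDispatchCombine/BatchedExperts. As a reading aid, here is a minimal plain-PyTorch sketch of that dispatch/combine round trip. It is illustrative only: the helper names are ours, not the test's, and it assumes topk_ids holds expert indices of shape (num_tokens, topk).

import torch

def reference_dispatch(a, topk_ids, num_experts):
    # Scatter each token into a per-expert slot; the buffer layout matches the
    # (num_experts, max_num_tokens, hidden_dim) shape used in torch_dispatch.
    num_tokens, hidden_dim = a.shape
    tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
    max_num_tokens = int(tokens_per_expert.max().item())
    b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
                      dtype=a.dtype, device=a.device)
    slot = [0] * num_experts
    for t in range(num_tokens):
        for e in topk_ids[t].tolist():
            b_a[e, slot[e]] = a[t]
            slot[e] += 1
    return b_a, tokens_per_expert

def reference_combine(b_out, topk_weight, topk_ids):
    # Gather per-expert outputs back into token order, applying top-k weights.
    # Iteration order matches reference_dispatch, so slot counters line up.
    num_tokens = topk_ids.shape[0]
    num_experts, _, K = b_out.shape
    out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
    slot = [0] * num_experts
    for t in range(num_tokens):
        for k, e in enumerate(topk_ids[t].tolist()):
            out[t] += b_out[e, slot[e]] * topk_weight[t, k].to(b_out.dtype)
            slot[e] += 1
    return out

Padding every expert's slice up to the busiest expert's token count is what gives each expert a fixed-size batch to process, at the cost of some wasted rows.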
@@ -262,7 +279,7 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids):
 @pytest.mark.parametrize("k", [128, 511, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 def test_fused_moe_batched_experts(
     m: int,
     n: int,
@@ -280,10 +297,13 @@ def test_fused_moe_batched_experts(

     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids = fused_topk(a, score, topk, False)
-        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
-        triton_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
+        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
+        torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
+        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)

-    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+    torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0)
+    torch.set_printoptions(profile="full")
+    torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0)


 def rank_chunk(num, r, w):
@@ -473,6 +493,8 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
         experts,
     )

+    # TODO: workers with the same dp_rank must use the exact same inputs.
+
     a_chunk = chunk_by_rank(a, rank, world_size).to(device)
     chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
     chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
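The new TODO matters because each worker slices its inputs by rank before calling the kernel, so data-parallel workers that share a dp_rank must start from identical a/topk tensors. For illustration only, a hypothetical even-split helper (an assumption on our part, not the file's chunk_by_rank implementation) could look like this:

import torch

def split_for_rank(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Hypothetical even split along dim 0; earlier ranks absorb the remainder.
    # Illustrative only -- not the diffed file's chunk_by_rank.
    n = t.shape[0]
    base, rem = divmod(n, world_size)
    start = rank * base + min(rank, rem)
    length = base + (1 if rank < rem else 0)
    return t[start:start + length]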
@@ -528,7 +550,7 @@ def _pplx_moe(
 @pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @requires_pplx
 def test_pplx_moe(
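Both affected tests keep their names, so one way to exercise just the touched parametrizations is to select them by name. A hedged example follows; the test module's path is not shown in this diff, so it assumes pytest collects the vLLM test tree from the current working directory.

import sys
import pytest

# Run only the two tests modified above, in quiet mode.
sys.exit(pytest.main(
    ["-k", "test_fused_moe_batched_experts or test_pplx_moe", "-q"]))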