from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import override_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedDispatchCombine, BatchedExperts, BatchedTritonExperts)
+    BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts)
from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
                                                            get_default_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
-from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import (
-    PplxDispatchCombine)
+from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
+    PplxPrepareAndFinalize)
from vllm.platforms import current_platform

-PPLX_DISPATCH_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
-                        (222, 2048, 1024)]
+PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
+                       (222, 2048, 1024)]

PPLX_MOE_COMBOS = [
    (1, 128, 128),
@@ -175,7 +175,7 @@ def parallel_launch_from_env(
    )


-def torch_dispatch(
+def torch_prepare(
    a: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
@@ -211,7 +211,8 @@ def torch_dispatch(
    return b_a, tokens_per_expert


-def torch_combine(b_out, topk_weight, topk_ids):
+def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor,
+                   topk_ids: torch.Tensor) -> torch.Tensor:
    num_tokens = topk_ids.shape[0]
    num_experts = b_out.shape[0]
    K = b_out.shape[-1]
@@ -231,9 +232,15 @@ def torch_combine(b_out, topk_weight, topk_ids):
    return out


-def torch_batched_moe(a, w1, w2, topk_weight, topk_ids):
+def torch_batched_moe(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+) -> torch.Tensor:
    num_experts = w1.shape[0]
-    b_a, tokens_per_expert = torch_dispatch(a, topk_ids, num_experts)
+    b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts)
    assert b_a.dim() == 3
    num_tokens, topk = topk_ids.shape
    _, max_num_tokens, K = b_a.shape
@@ -251,21 +258,33 @@ def torch_batched_moe(a, w1, w2, topk_weight, topk_ids):
                tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1))
            out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1)

-    return torch_combine(out, topk_weight, topk_ids)
+    return torch_finalize(out, topk_weight, topk_ids)


-def batched_moe(a, w1, w2, topk_weight, topk_ids):
+def batched_moe(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+) -> torch.Tensor:
    num_experts = w1.shape[0]

    fused_experts = FusedMoEModularKernel(
-        BatchedDispatchCombine(a.shape[0], world_size=1, dp_size=1, rank=0),
+        BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0),
        BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1))

    return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts)


-# TODO: same as torch_moe but with fused_topk factored out.
-def torch_moe2(a, w1, w2, topk_weight, topk_ids):
+# Note: same as torch_moe but with fused_topk factored out.
+def torch_moe2(
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+) -> torch.Tensor:
    M, K = a.shape
    topk = topk_ids.shape[1]
    a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K)
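For orientation (not part of this diff): the renamed single-GPU pieces above are compared against each other the same way the existing test_fused_moe_batched_experts test does it. A minimal, hedged sketch of that comparison, assuming the torch_batched_moe and batched_moe helpers defined in this file and vLLM's fused_topk; the helper name check_batched_path and the tolerances are illustrative only:

import torch

from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk


def check_batched_path(a, w1, w2, e: int, topk: int):
    # Route tokens with the same top-k selection used elsewhere in this file,
    # then compare the pure-torch reference against the modular-kernel path.
    score = torch.randn((a.shape[0], e), device=a.device, dtype=a.dtype)
    topk_weight, topk_ids = fused_topk(a, score, topk, False)
    ref = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
    out = batched_moe(a, w1, w2, topk_weight, topk_ids)
    torch.testing.assert_close(out, ref, atol=2e-2, rtol=0)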
@@ -318,17 +337,19 @@ def test_fused_moe_batched_experts(
                               rtol=0)


-def rank_chunk(num, r, w):
+def rank_chunk(num: int, r: int, w: int) -> int:
    rem = num % w
    return (num // w) + (1 if r < rem else 0)


-def chunk_by_rank(t, r, w):
+def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    chunk = rank_chunk(t.shape[0], r, w)
    return t[(r * chunk):(r + 1) * chunk]


-def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
+def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor,
+                          topk_weight: torch.Tensor, topk_ids: torch.Tensor,
+                          num_experts: int) -> torch.Tensor:
    assert torch.cuda.current_device() == pgi.local_rank

    topk = topk_ids.shape[1]
@@ -355,7 +376,7 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):

    topk_ids = topk_ids.to(dtype=torch.uint32)

-    dispatch_combine = PplxDispatchCombine(
+    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens,
        world_size,
@@ -368,7 +389,7 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

-    b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
+    b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare(
        a_chunk,
        None,
        None,
@@ -388,7 +409,7 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
        device=device,
    )

-    dispatch_combine.combine(
+    prepare_finalize.finalize(
        out,
        b_a,
        chunk_topk_weight,
@@ -405,13 +426,13 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
    return out[:num_tokens]


-def _pplx_dispatch_combine(
+def _pplx_prepare_finalize(
    pgi: ProcessGroupInfo,
    dp_size: int,
-    a,
-    score,
-    topk,
-    num_experts,
+    a: torch.Tensor,
+    score: torch.Tensor,
+    topk: torch.Tensor,
+    num_experts: int,
):
    uid = nvshmem_get_unique_id(
    ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
@@ -428,7 +449,7 @@ def _pplx_dispatch_combine(
                    topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(
                        a.dtype)

-    pplx_output = pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids,
+    pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids,
                                        num_experts)

    torch_output = chunk_by_rank(torch_output, pgi.rank,
@@ -439,16 +460,16 @@ def _pplx_dispatch_combine(
    nvshmem_finalize()


-# TODO: this test point does not work for odd M due to how the test is
+# TODO (bnell): this test point does not work for odd M due to how the test is
# written, not due to limitations of the pplx kernels. The pplx_moe
# test below is able to deal with odd M.
-@pytest.mark.parametrize("mnk", PPLX_DISPATCH_COMBOS)
+@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
-def test_pplx_dispatch_combine(
+def test_pplx_prepare_finalize(
    mnk: tuple[int, int, int],
    e: int,
    topk: int,
@@ -462,11 +483,22 @@ def test_pplx_dispatch_combine(
    a = torch.randn((m, k), device=device, dtype=dtype) / 10
    score = torch.randn((m, e), device=device, dtype=dtype)

-    parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, score,
+    parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score,
                    topk, e)


-def pplx_moe(rank, world_size, dp_size, a, w1, w2, topk_weight, topk_ids):
+def pplx_moe(
+    rank: int,
+    world_size: int,
+    dp_size: int,
+    a: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weight: torch.Tensor,
+    topk_ids: torch.Tensor,
+    use_compile: bool = True,
+    use_cudagraphs: bool = True,
+) -> torch.Tensor:
    device = torch.device("cuda", rank)
    hidden_dim = a.shape[1]
    num_experts = w1.shape[0]
@@ -490,7 +522,7 @@ def pplx_moe(rank, world_size, dp_size, a, w1, w2, topk_weight, topk_ids):

    topk_ids = topk_ids.to(dtype=torch.uint32)

-    dispatch_combine = PplxDispatchCombine(
+    prepare_finalize = PplxPrepareAndFinalize(
        ata,
        max_num_tokens,
        world_size,
@@ -503,7 +535,7 @@ def pplx_moe(rank, world_size, dp_size, a, w1, w2, topk_weight, topk_ids):
                                   dp_size=dp_size)

    fused_experts = FusedMoEModularKernel(
-        dispatch_combine,
+        prepare_finalize,
        experts,
    )

@@ -516,14 +548,12 @@ def pplx_moe(rank, world_size, dp_size, a, w1, w2, topk_weight, topk_ids):
    w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
    w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)

-    @torch.compile(backend='inductor', fullgraph=True)
-    def _fused_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts):
-        return fused_experts(a,
-                             w1,
-                             w2,
-                             topk_weight,
-                             topk_ids,
-                             global_num_experts=global_num_experts)
+    if use_compile:
+        _fused_experts = torch.compile(fused_experts,
+                                       backend='inductor',
+                                       fullgraph=True)
+    else:
+        _fused_experts = fused_experts

    out = _fused_experts(a_chunk,
                         w1_chunk,
@@ -532,6 +562,21 @@ def _fused_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts):
                         chunk_topk_ids,
                         global_num_experts=num_experts)

+    if use_cudagraphs:
+        out.fill_(0)
+        stream = torch.cuda.Stream()
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph, stream=stream):
+            out = _fused_experts(a_chunk,
+                                 w1_chunk,
+                                 w2_chunk,
+                                 chunk_topk_weight,
+                                 chunk_topk_ids,
+                                 global_num_experts=num_experts)
+
+        torch.cuda.synchronize()
+        graph.replay()
+
    torch.cuda.synchronize()

    ata.destroy()
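Aside (not part of the diff): the two toggles added above follow the standard PyTorch pattern of compiling a callable with Inductor and then capturing and replaying it with a CUDA graph. A minimal, hedged sketch of that pattern on a generic callable, assuming static CUDA tensor inputs (a CUDA-graph requirement); the helper name run_compiled_with_graph is illustrative only:

import torch


def run_compiled_with_graph(fn, *args, **kwargs):
    # Compile once with Inductor; fullgraph=True fails fast if the callable
    # cannot be traced as a single graph.
    compiled = torch.compile(fn, backend="inductor", fullgraph=True)
    out = compiled(*args, **kwargs)  # warm-up call outside the graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):    # record the kernels issued by this call
        out = compiled(*args, **kwargs)
    torch.cuda.synchronize()
    graph.replay()                   # re-run the recorded kernels
    return out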
@@ -548,7 +593,7 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
    world_size = pgi.world_size
    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)

-    dispatch_combine = BatchedDispatchCombine(
+    prepare_finalize = BatchedPrepareAndFinalize(
        max_num_tokens=max_num_tokens,
        world_size=world_size,
        dp_size=dp_size,
@@ -560,7 +605,7 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
                             dp_size=1)

    fused_experts = FusedMoEModularKernel(
-        dispatch_combine,
+        prepare_finalize,
        experts,
    )

@@ -605,7 +650,7 @@ def _pplx_moe(
    torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
    pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
                           topk_weight, topk_ids)
-    # TODO: fix + re-enable
+    # TODO (bnell): fix + re-enable
    #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
    #                              topk_ids)