    spawn)  # pyright: ignore[reportPrivateImportUsage]
from typing_extensions import Concatenate, ParamSpec

-import vllm.model_executor.layers.fused_moe  # noqa
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.fused_moe import override_config
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedDispatchCombine, BatchedExperts)
-from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
+    BatchedDispatchCombine, BatchedExperts, BatchedTritonExperts)
+from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk,
+                                                            get_default_config)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEModularKernel)
from vllm.model_executor.layers.fused_moe.pplx_dispatch_combine import (
    PplxDispatchCombine)
from vllm.platforms import current_platform

+PPLX_DISPATCH_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512),
+                        (222, 2048, 1024)]
+
+PPLX_MOE_COMBOS = [
+    (1, 128, 128),
+    (2, 128, 512),
+    (3, 1024, 2048),
+    (32, 128, 1024),
+    (45, 512, 2048),
+    (64, 1024, 1024),
+    (222, 1024, 2048),
+]
+
NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
-TOP_KS = [2, 6]
+TOP_KS = [1, 2, 6]

vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
@@ -298,7 +312,6 @@ def test_fused_moe_batched_experts(
                               torch_output,
                               atol=2e-2,
                               rtol=0)
-    torch.set_printoptions(profile="full")
    torch.testing.assert_close(baseline_output,
                               batched_output,
                               atol=2e-2,
@@ -426,25 +439,24 @@ def _pplx_dispatch_combine(
    nvshmem_finalize()


-# TODO: this test point does not work for M == 1
-@pytest.mark.parametrize("m", [4, 32, 64, 222])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 512, 1024])
+# TODO: this test point does not work for odd M due to how the test is
+# written, not due to limitations of the pplx kernels. The pplx_moe
+# test below is able to deal with odd M.
+@pytest.mark.parametrize("mnk", PPLX_DISPATCH_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
def test_pplx_dispatch_combine(
-    m: int,
-    n: int,
-    k: int,
+    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
):
    current_platform.seed_everything(7)
+    m, n, k = mnk
    world_size, dp_size = world_dp_size
    device = "cuda"
    a = torch.randn((m, k), device=device, dtype=dtype) / 10
@@ -454,15 +466,11 @@ def test_pplx_dispatch_combine(
                    topk, e)


-def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
-    assert torch.cuda.current_device() == pgi.local_rank
-
+def pplx_moe(rank, world_size, dp_size, a, w1, w2, topk_weight, topk_ids):
+    device = torch.device("cuda", rank)
    hidden_dim = a.shape[1]
    num_experts = w1.shape[0]
    block_size = 128
-    device = pgi.device
-    rank = pgi.rank
-    world_size = pgi.world_size
    topk = topk_ids.shape[1]
    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)

@@ -490,29 +498,39 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
        dp_size,
    )

-    experts = BatchedExperts(max_num_tokens=a.shape[0],
-                             world_size=world_size,
-                             dp_size=dp_size)
+    experts = BatchedTritonExperts(max_num_tokens=a.shape[0],
+                                   world_size=world_size,
+                                   dp_size=dp_size)

    fused_experts = FusedMoEModularKernel(
        dispatch_combine,
        experts,
    )

-    # TODO: workers with the same dp_rank must use the exact same inputs.
-
+    # Note: workers with the same dp_rank must use the exact same inputs.
    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

-    out = fused_experts(
-        a_chunk,
-        # Chunking weights like this only works for batched format
-        chunk_by_rank(w1, rank, world_size).to(device),
-        chunk_by_rank(w2, rank, world_size).to(device),
-        chunk_topk_weight,
-        chunk_topk_ids,
-        global_num_experts=num_experts)
+    # Chunking weights like this only works for batched format
+    w1_chunk = chunk_by_rank(w1, rank, world_size).to(device)
+    w2_chunk = chunk_by_rank(w2, rank, world_size).to(device)
+
+    @torch.compile(backend='inductor', fullgraph=True)
+    def _fused_experts(a, w1, w2, topk_weight, topk_ids, global_num_experts):
+        return fused_experts(a,
+                             w1,
+                             w2,
+                             topk_weight,
+                             topk_ids,
+                             global_num_experts=global_num_experts)
+
+    out = _fused_experts(a_chunk,
+                         w1_chunk,
+                         w2_chunk,
+                         chunk_topk_weight,
+                         chunk_topk_ids,
+                         global_num_experts=num_experts)

    torch.cuda.synchronize()

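Editor's note on the hunk above: compiling a thin wrapper around the modular-kernel call (rather than the kernel object itself) is a general pattern, not something specific to this file. A minimal, hedged sketch of the same wrap-then-compile idea, with illustrative names and a trivial stand-in computation so it runs anywhere torch is installed:

import torch


def expert_like(a: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Stand-in for the fused-experts call being compiled in the diff above.
    return torch.nn.functional.silu(a @ w)


@torch.compile(backend="inductor", fullgraph=True)
def _compiled(a: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Thin wrapper so the whole call is captured as a single graph.
    return expert_like(a, w)


out = _compiled(torch.randn(4, 8), torch.randn(8, 16))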
@@ -546,8 +564,7 @@ def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
        experts,
    )

-    # TODO: workers with the same dp_rank must use the exact same inputs.
-
+    # Note: workers with the same dp_rank must use the exact same inputs.
    a_chunk = chunk_by_rank(a, rank, world_size).to(device)
    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
@@ -581,10 +598,14 @@ def _pplx_moe(
    m, k = a.shape
    e, _, n = w2.shape

-    with set_current_vllm_config(vllm_config):
+    moe_config = get_default_config(m, e, n, k, topk, a.dtype, False)
+
+    with set_current_vllm_config(vllm_config), override_config(moe_config):
        topk_weight, topk_ids, _ = fused_topk(a, score, topk, False)
        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
-        pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
+        pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2,
+                               topk_weight, topk_ids)
+        # TODO: fix + re-enable
        #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight,
        #                              topk_ids)

@@ -597,24 +618,21 @@ def _pplx_moe(
    nvshmem_finalize()


-@pytest.mark.parametrize("m", [1, 2, 3, 32, 45, 64, 222])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 512, 1024])
+@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])
@requires_pplx
def test_pplx_moe(
-    m: int,
-    n: int,
-    k: int,
+    mnk: tuple[int, int, int],
    e: int,
    topk: int,
    dtype: torch.dtype,
    world_dp_size: tuple[int, int],
):
    current_platform.seed_everything(7)
+    m, n, k = mnk
    world_size, dp_size = world_dp_size
    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
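Editor's note on the new parametrization: instead of three independent @pytest.mark.parametrize axes for m, n and k (a full cross product), the tests above take a single "mnk" tuple drawn from a curated list and unpack it inside the test body. A minimal, self-contained sketch of that pattern, with hypothetical names not taken from this PR:

import pytest

# Curated (m, n, k) shape triples, mirroring the role of PPLX_MOE_COMBOS above.
MNK_COMBOS = [(1, 128, 128), (32, 128, 1024), (222, 1024, 2048)]


@pytest.mark.parametrize("mnk", MNK_COMBOS)
def test_mnk_unpack(mnk: tuple[int, int, int]):
    # Each case carries its own shape triple, so only meaningful combinations
    # run instead of the full m x n x k cross product.
    m, n, k = mnk
    assert m > 0 and n > 0 and k > 0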