9
9
import torch
10
10
import traceback
11
11
12
- from torch .nn import Parameter
13
- from torch .nn import functional as F
14
12
from torch .multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
15
- from typing import Callable , Concatenate , ParamSpec
13
+ from typing import Callable , Concatenate , ParamSpec , Tuple
16
14
17
15
from pplx_kernels import AllToAll
18
16
from pplx_kernels .nvshmem import (
25
23
import vllm .model_executor .layers .fused_moe # noqa
26
24
from tests .kernels .utils import (compute_max_diff , opcheck , stack_and_dev ,
27
25
torch_moe , torch_moe_single )
28
- from vllm import _custom_ops as ops
26
+ # from vllm import _custom_ops as ops
29
27
from vllm .config import VllmConfig , set_current_vllm_config
30
- from vllm .model_executor .layers .fused_moe import fused_moe
28
+ # from vllm.model_executor.layers.fused_moe import fused_moe
31
29
#from vllm.model_executor.layers.fused_moe.fused_batched_moe import fused_batched_experts
32
30
from vllm .model_executor .layers .fused_moe .fused_moe import (
33
31
fused_topk , moe_align_block_size )
34
- from vllm .model_executor .layers .fused_moe .moe_torch_iterative import (
35
- fused_moe as iterative_moe )
36
- from vllm .model_executor .layers .quantization .utils .marlin_utils_test import (
37
- marlin_quantize )
38
- from vllm .model_executor .layers .quantization .utils .quant_utils import (
39
- quantize_weights )
40
- from vllm .model_executor .models .mixtral import MixtralMoE
41
32
from vllm .platforms import current_platform
42
- from vllm .scalar_type import scalar_types
43
- from vllm .utils import round_up
44
33
45
34
from vllm .model_executor .layers .activation import SiluAndMul
46
35
47
36
from vllm .model_executor .layers .fused_moe .fused_moe import TritonExperts , BatchedDispatchCombine , BatchedExperts , fused_experts
48
- from vllm .model_executor .layers .fused_moe .modular_kernel import FusedMoEModularKernel , FusedMoEQuantizeDispatchCombine
37
+ from vllm .model_executor .layers .fused_moe .modular_kernel import FusedMoEModularKernel
49
38
from vllm .model_executor .layers .fused_moe .pplx_dispatch_combine import PplxDispatchCombine
50
39
51
40
NUM_EXPERTS = [8 , 64 ]
@@ -373,7 +362,8 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
373
362
num_experts , # store at PplxDispatchCombine creation?
374
363
None
375
364
)
376
- #torch.cuda.synchronize() # necessary?
365
+
366
+ b_a = b_a * 1.5
377
367
378
368
out = torch .full (
379
369
(max_num_tokens , hidden_dim ),
@@ -392,7 +382,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
392
382
393
383
ata .destroy ()
394
384
395
- torch .distributed .barrier ()
385
+ # torch.distributed.barrier()
396
386
397
387
#print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
398
388
@@ -406,27 +396,34 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
406
396
def _pplx_dispatch_combine (
407
397
pgi : ProcessGroupInfo ,
408
398
dp_size : int ,
409
- a : torch .Tensor ,
410
- w1 : torch .Tensor ,
411
- w2 : torch .Tensor ,
412
- score : torch .Tensor ,
399
+ m , n , k , e ,
400
+ #a: torch.Tensor,
401
+ #w1: torch.Tensor,
402
+ #w2: torch.Tensor,
403
+ #score: torch.Tensor,
413
404
topk : int ,
414
405
dtype : torch .dtype ,
415
406
):
416
407
uid = nvshmem_get_unique_id () if pgi .rank == 0 else nvshmem_alloc_empty_unique_id ()
417
408
torch .distributed .broadcast (uid , src = 0 )
418
409
nvshmem_init (uid , pgi .rank , pgi .world_size )
410
+ device = pgi .device
419
411
420
- m , k = a .shape
421
- e , _ , n = w2 .shape
412
+ a = torch .randn ((m , k ), device = device , dtype = dtype ) / 10
413
+ w1 = torch .randn ((e , 2 * n , k ), device = device , dtype = dtype ) / 10
414
+ w2 = torch .randn ((e , k , n ), device = device , dtype = dtype ) / 10
415
+ score = torch .randn ((m , e ), device = device , dtype = dtype )
416
+
417
+ #m, k = a.shape
418
+ #e, _, n = w2.shape
422
419
423
420
topk_weight , topk_ids = fused_topk (a , score , topk , False )
424
421
425
422
#print(f"a {a.shape}")
426
423
a_rep = torch .repeat_interleave (a , topk , dim = 0 )
427
424
#print(f"a_rep {a_rep.shape} {a_rep.view(-1, topk, k)}")
428
425
429
- torch_output = (a_rep .view (- 1 , topk , k ) * topk_weight .view (- 1 , topk , 1 )).to ( a . dtype ). sum (dim = 1 )
426
+ torch_output = (a_rep .view (- 1 , topk , k ) * 1.5 * topk_weight .view (- 1 , topk , 1 )).sum (dim = 1 ). to ( a . dtype )
430
427
431
428
#print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}")
432
429
@@ -452,35 +449,28 @@ def _pplx_dispatch_combine(
452
449
nvshmem_finalize ()
453
450
454
451
455
- @pytest .mark .parametrize ("m" , [2 , 32 , 64 , 222 ]) #, 1024 * 128])
452
+ @pytest .mark .parametrize ("m" , [4 , 32 , 64 , 222 ]) #, 1024 * 128])
456
453
@pytest .mark .parametrize ("n" , [128 , 1024 , 2048 ])
457
454
@pytest .mark .parametrize ("k" , [128 , 512 , 1024 ]) # restrictions? % 128?
458
455
@pytest .mark .parametrize ("e" , NUM_EXPERTS )
459
456
@pytest .mark .parametrize ("topk" , TOP_KS )
460
457
@pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
458
+ @pytest .mark .parametrize ("world_dp_size" , [[2 , 1 ]]) #, [[4, 2]])
461
459
def test_pplx_dispatch_combine (
462
460
m : int ,
463
461
n : int ,
464
462
k : int ,
465
463
e : int ,
466
464
topk : int ,
467
465
dtype : torch .dtype ,
466
+ world_dp_size : Tuple [int , int ],
468
467
):
469
468
current_platform .seed_everything (7 )
470
- if False :
471
- world_size = 4
472
- dp_size = 2
473
- else :
474
- world_size = 2
475
- dp_size = 1
476
-
477
- a = torch .randn ((m , k ), device = "cuda" , dtype = dtype ) / 10
478
- w1 = torch .randn ((e , 2 * n , k ), device = "cuda" , dtype = dtype ) / 10
479
- w2 = torch .randn ((e , k , n ), device = "cuda" , dtype = dtype ) / 10
480
- score = torch .randn ((m , e ), device = "cuda" , dtype = dtype )
469
+ world_size , dp_size = world_dp_size
481
470
482
471
parallel_launch (
483
- world_size , _pplx_dispatch_combine , dp_size , a , w1 , w2 , score , topk , dtype
472
+ #world_size, _pplx_dispatch_combine, dp_size, a, w1, w2, score, topk, dtype
473
+ world_size , _pplx_dispatch_combine , dp_size , m , n , k , e , topk , dtype
484
474
)
485
475
486
476
@@ -489,9 +479,10 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
489
479
490
480
num_tokens , hidden_dim = a .shape
491
481
num_experts = w1 .shape [0 ]
482
+ num_local_experts = num_experts // pgi .world_size
492
483
block_size = 128
493
484
device = pgi .device
494
- rank_num_tokens = num_tokens // pgi .world_size
485
+ rank_num_tokens = num_tokens // pgi .world_size # TODO even divide
495
486
496
487
max_num_tokens = num_tokens
497
488
#print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")
@@ -518,6 +509,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
518
509
),
519
510
)
520
511
512
+ w1 = w1 .to (device )
513
+ w2 = w2 .to (device )
514
+
521
515
dispatch_combine = PplxDispatchCombine (
522
516
ata ,
523
517
max_num_tokens , # // world_size?
@@ -538,73 +532,77 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
538
532
score_chunk = chunk_by_rank (scores , rank , world_size ).to (device )
539
533
chunk_topk_weight , chunk_topk_ids = fused_topk (a_chunk , score_chunk , topk , False )
540
534
541
- #print(f"chunk_topk_ids = {chunk_topk_ids}")
535
+ #print(f"chunk_topk_ids {rank} {chunk_topk_ids.shape} {chunk_topk_ids.view(-1) }")
542
536
543
537
out = fused_experts (
544
538
a_chunk ,
545
- w1 , # chunk?
546
- w2 , # chunk?
539
+ w1 ,
540
+ w2 ,
547
541
chunk_topk_weight ,
548
542
chunk_topk_ids ,
549
- global_num_experts = num_experts #? num_local_experts?
543
+ global_num_experts = num_local_experts #? num_local_experts?
550
544
)
551
545
552
546
torch .cuda .synchronize ()
553
547
554
548
ata .destroy ()
555
549
556
- torch .distributed .barrier ()
550
+ # torch.distributed.barrier()
557
551
558
552
#print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
559
553
560
554
#torch.distributed.all_reduce(out)
561
555
562
- print (f"OUT { rank } : { out .shape } { out } " )
556
+ # print(f"OUT {rank}: {out.shape} {out}")
563
557
564
558
return out [:rank_num_tokens ]
565
559
566
560
567
561
def _pplx_moe (
568
562
pgi : ProcessGroupInfo ,
569
563
dp_size : int ,
570
- m : int ,
571
- n : int ,
572
- k : int ,
573
- e : int ,
564
+ a : torch . Tensor ,
565
+ w1 : torch . Tensor ,
566
+ w2 : torch . Tensor ,
567
+ score : torch . Tensor ,
574
568
topk : int ,
575
569
dtype : torch .dtype ,
576
570
):
577
571
uid = nvshmem_get_unique_id () if pgi .rank == 0 else nvshmem_alloc_empty_unique_id ()
578
572
torch .distributed .broadcast (uid , src = 0 )
579
573
nvshmem_init (uid , pgi .rank , pgi .world_size )
580
574
581
- a = torch .randn ((m , k ), device = "cuda" , dtype = dtype ) / 10
582
- w1 = torch .randn ((e , 2 * n , k ), device = "cuda" , dtype = dtype ) / 10
583
- w2 = torch .randn ((e , k , n ), device = "cuda" , dtype = dtype ) / 10
575
+ m , k = a .shape
576
+ e , _ , n = w2 .shape
584
577
585
- score = torch .randn (( m , e ), device = "cuda" , dtype = dtype )
578
+ torch .set_printoptions ( profile = "full" )
586
579
587
580
vllm_config = VllmConfig ()
588
581
with set_current_vllm_config (vllm_config ):
589
582
topk_weight , topk_ids = fused_topk (a , score , topk , False )
590
583
584
+ #print(f"topk_ids {pgi.rank} {topk_ids.shape} {topk_ids.view(-1)}")
585
+
591
586
torch_output = torch_moe2 (a , w1 , w2 , topk_weight , topk_ids )
592
587
593
- pplxd_output = torch_pplx_moe (pgi ,
594
- dp_size ,
595
- a ,
596
- w1 ,
597
- w2 ,
598
- score ,
599
- topk )
588
+ pplx_output = torch_pplx_moe (pgi ,
589
+ dp_size ,
590
+ a ,
591
+ w1 ,
592
+ w2 ,
593
+ score ,
594
+ topk )
595
+
596
+ #print(f"torch_output {pgi.rank}: {torch_output}")
600
597
601
598
if False :
602
- torch .set_printoptions (profile = "full" )
603
599
print ("BASELINE" )
604
600
print (torch_output )
605
601
print ("OUTPUT" )
606
602
print (pplx_output )
607
603
604
+ torch_output = chunk_by_rank (torch_output , pgi .rank , pgi .world_size ).to (pplx_output .device )
605
+
608
606
torch .testing .assert_close (pplx_output , torch_output , atol = 2e-2 , rtol = 0 )
609
607
610
608
nvshmem_finalize ()
@@ -616,28 +614,31 @@ def _pplx_moe(
616
614
# @pytest.mark.parametrize("e", NUM_EXPERTS)
617
615
# @pytest.mark.parametrize("topk", TOP_KS)
618
616
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
619
- @pytest .mark .parametrize ("m" , [128 ]) ##, 32]) #, 1024 * 128])
617
+ @pytest .mark .parametrize ("m" , [64 ]) ##, 32]) #, 1024 * 128])
620
618
@pytest .mark .parametrize ("n" , [128 ])
621
619
@pytest .mark .parametrize ("k" , [128 ])
622
620
@pytest .mark .parametrize ("e" , [8 ]) #NUM_EXPERTS)
623
621
@pytest .mark .parametrize ("topk" , [2 ]) #TOP_KS)
624
622
@pytest .mark .parametrize ("dtype" , [torch .bfloat16 ])
623
+ @pytest .mark .parametrize ("world_dp_size" , [[2 , 1 ]]) #, [4, 2]])
625
624
def test_pplx_moe (
626
625
m : int ,
627
626
n : int ,
628
627
k : int ,
629
628
e : int ,
630
629
topk : int ,
631
630
dtype : torch .dtype ,
631
+ world_dp_size : Tuple [int , int ],
632
632
):
633
633
current_platform .seed_everything (7 )
634
- if False :
635
- world_size = 4
636
- dp_size = 2
637
- else :
638
- world_size = 2
639
- dp_size = 1
634
+ world_size , dp_size = world_dp_size
635
+ a = torch . randn (( m , k ), device = "cuda" , dtype = dtype ) / 10
636
+ w1 = torch . randn (( e , 2 * n , k ), device = "cuda" , dtype = dtype ) / 10
637
+ w2 = torch . randn (( e , k , n ), device = "cuda" , dtype = dtype ) / 10
638
+ score = torch . randn (( m , e ), device = "cuda" , dtype = dtype )
639
+
640
640
parallel_launch (
641
- world_size , _pplx_moe , dp_size , m , n , k , e , topk , dtype
641
+ world_size , _pplx_moe , dp_size , a , w1 , w2 , score , topk , dtype
642
+ #world_size, _pplx_moe, dp_size, m, n, k, e, topk, dtype
642
643
)
643
644
0 commit comments