@@ -373,7 +373,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
373
373
num_experts , # store at PplxDispatchCombine creation?
374
374
None
375
375
)
376
- torch .cuda .synchronize () # necessary?
376
+ # torch.cuda.synchronize() # necessary?
377
377
378
378
out = torch .full (
379
379
(max_num_tokens , hidden_dim ),
@@ -452,18 +452,12 @@ def _pplx_dispatch_combine(
452
452
nvshmem_finalize ()
453
453
454
454
455
- @pytest .mark .parametrize ("m" , [2 , 32 , 64 , 222 ]) #, 1024 * 128]) # what is restriction on this?
455
+ @pytest .mark .parametrize ("m" , [2 , 32 , 64 , 222 ]) #, 1024 * 128])
456
456
@pytest .mark .parametrize ("n" , [128 , 1024 , 2048 ])
457
- @pytest .mark .parametrize ("k" , [128 , 512 , 1024 ]) # restrictions here ?
457
+ @pytest .mark .parametrize ("k" , [128 , 512 , 1024 ]) # restrictions? % 128 ?
458
458
@pytest .mark .parametrize ("e" , NUM_EXPERTS )
459
459
@pytest .mark .parametrize ("topk" , TOP_KS )
460
460
@pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
461
- # @pytest.mark.parametrize("m", [2]) ##, 32]) #, 1024 * 128])
462
- # @pytest.mark.parametrize("n", [128])
463
- # @pytest.mark.parametrize("k", [128])
464
- # @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS)
465
- # @pytest.mark.parametrize("topk", [2]) #TOP_KS)
466
- # @pytest.mark.parametrize("dtype", [torch.bfloat16])
467
461
def test_pplx_dispatch_combine (
468
462
m : int ,
469
463
n : int ,
@@ -491,13 +485,16 @@ def test_pplx_dispatch_combine(
491
485
492
486
493
487
def torch_pplx_moe (pgi , dp_size , a , w1 , w2 , scores , topk ):
494
- hidden_dim = a .shape [- 1 ]
488
+ assert torch .cuda .current_device () == pgi .local_rank
489
+
490
+ num_tokens , hidden_dim = a .shape
495
491
num_experts = w1 .shape [0 ]
496
- num_local_experts = num_experts // pgi .world_size
497
492
block_size = 128
493
+ device = pgi .device
494
+ rank_num_tokens = num_tokens // pgi .world_size
498
495
499
- max_num_tokens = round_up ( a . shape [ 0 ], 128 ) #tokens_per_expert.max()
500
- print (f"max_num_tokens = { max_num_tokens } , topk = { topk } , num_ex = { num_experts } / { num_local_experts } " )
496
+ max_num_tokens = num_tokens
497
+ # print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size }")
501
498
rank = pgi .rank
502
499
world_size = pgi .world_size
503
500
@@ -523,7 +520,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
523
520
524
521
dispatch_combine = PplxDispatchCombine (
525
522
ata ,
526
- max_num_tokens ,
523
+ max_num_tokens , # // world_size?
527
524
pgi .world_size ,
528
525
dp_size ,
529
526
rank ,
@@ -537,53 +534,34 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
537
534
experts ,
538
535
)
539
536
540
- a_chunk = chunk_by_rank (a , rank , world_size )
541
- score_chunk = chunk_by_rank (scores , rank , world_size )
537
+ a_chunk = chunk_by_rank (a , rank , world_size ). to ( device )
538
+ score_chunk = chunk_by_rank (scores , rank , world_size ). to ( device )
542
539
chunk_topk_weight , chunk_topk_ids = fused_topk (a_chunk , score_chunk , topk , False )
543
540
544
- print (f"chunk_topk_ids = { chunk_topk_ids } " )
541
+ # print(f"chunk_topk_ids = {chunk_topk_ids}")
545
542
546
- # TODO: chunk up by rank
547
- if False :
548
- out = fused_experts (
549
- a_chunk ,
550
- w1 , # chunk?
551
- w2 , # chunk?
552
- chunk_topk_weight ,
553
- chunk_topk_ids ,
554
- global_num_experts = num_local_experts
555
- )
556
- # reduce outputs?
557
- else :
558
- b_a , b_a_scale , expert_num_tokens = dispatch_combine .dispatch (
559
- a_chunk ,
560
- None ,
561
- None ,
562
- chunk_topk_ids ,
563
- num_experts ,
564
- None
565
- )
566
- torch .cuda .synchronize ()
543
+ out = fused_experts (
544
+ a_chunk ,
545
+ w1 , # chunk?
546
+ w2 , # chunk?
547
+ chunk_topk_weight ,
548
+ chunk_topk_ids ,
549
+ global_num_experts = num_experts #? num_local_experts?
550
+ )
567
551
568
- out = torch .full (
569
- (max_num_tokens , hidden_dim ),
570
- torch .nan ,
571
- dtype = a .dtype ,
572
- device = a .device ,
573
- )
552
+ torch .cuda .synchronize ()
574
553
575
- dispatch_combine .combine (
576
- out ,
577
- b_a ,
578
- chunk_topk_weight ,
579
- chunk_topk_ids ,
580
- )
554
+ ata .destroy ()
581
555
582
- torch .cuda . synchronize ()
556
+ torch .distributed . barrier ()
583
557
584
- ata . destroy ( )
558
+ #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}" )
585
559
586
- return out
560
+ #torch.distributed.all_reduce(out)
561
+
562
+ print (f"OUT { rank } : { out .shape } { out } " )
563
+
564
+ return out [:rank_num_tokens ]
587
565
588
566
589
567
def _pplx_moe (
@@ -612,29 +590,29 @@ def _pplx_moe(
612
590
613
591
torch_output = torch_moe2 (a , w1 , w2 , topk_weight , topk_ids )
614
592
615
- triton_output = torch_pplx_moe (pgi ,
616
- dp_size ,
617
- a ,
618
- w1 ,
619
- w2 ,
620
- score ,
621
- topk )
593
+ pplx_output = torch_pplx_moe (pgi ,
594
+ dp_size ,
595
+ a ,
596
+ w1 ,
597
+ w2 ,
598
+ score ,
599
+ topk )
622
600
623
601
if False :
624
602
torch .set_printoptions (profile = "full" )
625
603
print ("BASELINE" )
626
604
print (torch_output )
627
605
print ("OUTPUT" )
628
- print (triton_output )
606
+ print (pplx_output )
629
607
630
- torch .testing .assert_close (triton_output , torch_output , atol = 2e-2 , rtol = 0 )
608
+ torch .testing .assert_close (pplx_output , torch_output , atol = 2e-2 , rtol = 0 )
631
609
632
610
nvshmem_finalize ()
633
611
634
612
635
613
# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128])
636
614
# @pytest.mark.parametrize("n", [128, 1024, 2048])
637
- # @pytest.mark.parametrize("k", [128, 511 , 1024])
615
+ # @pytest.mark.parametrize("k", [128, 512 , 1024])
638
616
# @pytest.mark.parametrize("e", NUM_EXPERTS)
639
617
# @pytest.mark.parametrize("topk", TOP_KS)
640
618
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
0 commit comments