@@ -299,10 +299,13 @@ def test_fused_moe_batched_experts(
    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)


+def rank_chunk(num, r, w):
+    rem = num % w
+    return (num // w) + (1 if r < rem else 0)
+
+
 def chunk_by_rank(t, r, w):
-    num = t.shape[0]
-    assert num % w == 0, f"{num}, {w}"  # for now
-    chunk = num // w
+    chunk = rank_chunk(t.shape[0], r, w)
    #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}")
    return t[(r * chunk):(r + 1)*chunk]

@@ -312,12 +315,11 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):

    num_tokens, hidden_dim = a.shape
    num_experts = w1.shape[0]
-    num_local_experts = w1.shape[0] // pgi.world_size
    block_size = 128
    device = pgi.device
-    rank_num_tokens = num_tokens // pgi.world_size
    rank = pgi.rank
    world_size = pgi.world_size
+    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
    max_num_tokens = num_tokens
    #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")

@@ -354,7 +356,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
    score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
    chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False)

-    print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}")
+    # print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}")

    b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
        a_chunk,
@@ -372,8 +374,8 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
    #max_num = tokens_per_expert.max()
    tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32)

-    print(f"tpe {tokens_per_expert}")
-    print(f"ent {expert_num_tokens}")
+    # print(f"tpe {tokens_per_expert}")
+    # print(f"ent {expert_num_tokens}")

    #torch.set_printoptions(profile="full")
    #torch.distributed.all_reduce(naive_b_a, op=torch.distributed.ReduceOp.MAX)
@@ -501,15 +503,12 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):

    num_tokens, hidden_dim = a.shape
    num_experts = w1.shape[0]
-    num_local_experts = num_experts // pgi.world_size
    block_size = 128
    device = pgi.device
-    rank_num_tokens = num_tokens // pgi.world_size  # TODO even divide
-
-    max_num_tokens = num_tokens
-    #print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")
    rank = pgi.rank
    world_size = pgi.world_size
+    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
+    max_num_tokens = num_tokens

    ata = AllToAll(
        max_num_tokens=max_num_tokens,
@@ -558,6 +557,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):

    out = fused_experts(
        a_chunk,
+        # Chunking weights like this only works for batched format
        chunk_by_rank(w1, rank, world_size),
        chunk_by_rank(w2, rank, world_size),
        chunk_topk_weight,
@@ -571,7 +571,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):

    #print(f"OUT {rank}: {out.shape} {out}")

-    return out[:rank_num_tokens]  # chunk_by_rank?
+    return out[:rank_num_tokens]


def _pplx_moe(
@@ -624,18 +624,13 @@ def _pplx_moe(
    nvshmem_finalize()


-@pytest.mark.parametrize("m", [2, 32, 64, 222])  #, 1024 * 128])
-@pytest.mark.parametrize("n", [128, 1024, 2048])
-@pytest.mark.parametrize("k", [128, 512, 1024])
+# TODO: M == 1 doesn't work
+@pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222])  #, 1024 * 128])
+@pytest.mark.parametrize("n", [128, 1024])  # , 2048])
+@pytest.mark.parametrize("k", [128, 512])  # , 1024])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-# @pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128])
-# @pytest.mark.parametrize("n", [128])
-# @pytest.mark.parametrize("k", [128])
-# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS)
-# @pytest.mark.parametrize("topk", [2]) #TOP_KS)
-# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("world_dp_size", [[2, 1]])  #, [4, 2]])
def test_pplx_moe(
    m: int,
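For reference, a minimal standalone sketch (not part of the diff) of how the new rank_chunk helper divides an uneven token count across ranks: each of the first num % world_size ranks receives one extra token, so the per-rank counts always sum to the total and no rank needs an exact divisor. The 45-token / 2-rank numbers below are illustrative values, not taken from the test.

    def rank_chunk(num, r, w):
        # Ranks with index below the remainder get one extra item.
        rem = num % w
        return (num // w) + (1 if r < rem else 0)

    # 45 tokens over 2 ranks: rank 0 gets 23, rank 1 gets 22.
    counts = [rank_chunk(45, r, 2) for r in range(2)]
    assert counts == [23, 22] and sum(counts) == 45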