@@ -308,17 +308,27 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids):
308
308
# torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
309
309
310
310
311
+ def chunk_by_rank (t , r , w ):
312
+ num = t .shape [0 ]
313
+ assert num % w == 0 , f"{ num } , { w } " # for now
314
+ chunk = num // w
315
+ #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}")
316
+ return t [(r * chunk ):(r + 1 )* chunk ]
317
+
318
+
311
319
def torch_pplx_dispatch_combine (pgi , dp_size , a , w1 , w2 , scores , topk ):
312
320
assert torch .cuda .current_device () == pgi .local_rank
313
321
314
322
num_tokens , hidden_dim = a .shape
315
323
num_experts = w1 .shape [0 ]
316
324
block_size = 128
317
325
device = pgi .device
326
+ rank_num_tokens = num_tokens // pgi .world_size
318
327
319
328
max_num_tokens = num_tokens
320
- print (f"device = { device } , max_num_tokens = { max_num_tokens } , topk = { topk } , num_ex = { num_experts } , dp_size = { dp_size } " )
329
+ # print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")
321
330
rank = pgi .rank
331
+ world_size = pgi .world_size
322
332
323
333
ata = AllToAll (
324
334
max_num_tokens = max_num_tokens ,
@@ -342,22 +352,15 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
342
352
343
353
dispatch_combine = PplxDispatchCombine (
344
354
ata ,
345
- max_num_tokens ,
355
+ max_num_tokens , # // world_size?
346
356
pgi .world_size ,
347
357
dp_size ,
348
358
rank ,
349
359
a .dtype ,
350
360
)
351
361
352
- def chunk_by_rank (t , r ):
353
- num = t .shape [0 ]
354
- assert num % pgi .world_size == 0 , f"{ num } , { pgi .world_size } " # for now
355
- chunk = num // pgi .world_size
356
- print (f"chunk { t .shape } , { pgi .world_size } , { r } , { chunk } , { r * chunk } :{ (r + 1 )* chunk } " )
357
- return t [(r * chunk ):(r + 1 )* chunk ]
358
-
359
- a_chunk = chunk_by_rank (a , rank ).to (device )
360
- score_chunk = chunk_by_rank (scores , rank ).to (device )
362
+ a_chunk = chunk_by_rank (a , rank , world_size ).to (device )
363
+ score_chunk = chunk_by_rank (scores , rank , world_size ).to (device )
361
364
chunk_topk_weight , chunk_topk_ids = fused_topk (a_chunk , score_chunk , topk , False )
362
365
363
366
#print(f"chunk_topk_ids = {chunk_topk_ids}")
@@ -391,36 +394,41 @@ def chunk_by_rank(t, r):
391
394
392
395
torch .distributed .barrier ()
393
396
394
- return out [:num_tokens ]
397
+ #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
398
+
399
+ #torch.distributed.all_reduce(out)
400
+
401
+ #print(f"AR OUT {rank}: {out.shape} {out}")
402
+
403
+ return out [:rank_num_tokens ]
395
404
396
405
397
406
def _pplx_dispatch_combine (
398
407
pgi : ProcessGroupInfo ,
399
408
dp_size : int ,
400
- m : int ,
401
- n : int ,
402
- k : int ,
403
- e : int ,
409
+ a : torch . Tensor ,
410
+ w1 : torch . Tensor ,
411
+ w2 : torch . Tensor ,
412
+ score : torch . Tensor ,
404
413
topk : int ,
405
414
dtype : torch .dtype ,
406
415
):
407
416
uid = nvshmem_get_unique_id () if pgi .rank == 0 else nvshmem_alloc_empty_unique_id ()
408
417
torch .distributed .broadcast (uid , src = 0 )
409
418
nvshmem_init (uid , pgi .rank , pgi .world_size )
410
419
411
- a = torch .randn ((m , k ), device = "cuda" , dtype = dtype ) / 10
412
- w1 = torch .randn ((e , 2 * n , k ), device = "cuda" , dtype = dtype ) / 10
413
- w2 = torch .randn ((e , k , n ), device = "cuda" , dtype = dtype ) / 10
414
-
415
- score = torch .randn ((m , e ), device = "cuda" , dtype = dtype )
420
+ m , k = a .shape
421
+ e , _ , n = w2 .shape
416
422
417
423
topk_weight , topk_ids = fused_topk (a , score , topk , False )
418
424
419
- print (f"a { a .shape } " )
420
- a_rep = torch .repeat_interleave (a , topk , dim = 1 )
421
- print (f"a_rep { a_rep .shape } " )
425
+ #print(f"a {a.shape}")
426
+ a_rep = torch .repeat_interleave (a , topk , dim = 0 )
427
+ #print(f"a_rep {a_rep.shape} {a_rep.view(-1, topk, k)}")
428
+
429
+ torch_output = (a_rep .view (- 1 , topk , k ) * topk_weight .view (- 1 , topk , 1 )).to (a .dtype ).sum (dim = 1 )
422
430
423
- torch_output = ( a_rep . view ( - 1 , topk , k ) * topk_weight . view ( - 1 , topk , 1 )). sum ( dim = 1 ). to ( a . dtype )
431
+ #print(f" torch_output {pgi.rank}: {torch_output.shape} {torch_output}" )
424
432
425
433
pplx_output = torch_pplx_dispatch_combine (pgi ,
426
434
dp_size ,
@@ -437,23 +445,25 @@ def _pplx_dispatch_combine(
437
445
print ("OUTPUT" )
438
446
print (pplx_output )
439
447
448
+ torch_output = chunk_by_rank (torch_output , pgi .rank , pgi .world_size ).to (pplx_output .device )
449
+
440
450
torch .testing .assert_close (pplx_output , torch_output , atol = 2e-2 , rtol = 0 )
441
451
442
452
nvshmem_finalize ()
443
453
444
454
445
- # @pytest.mark.parametrize("m", [1, 33 , 64, 222]) #, 1024 * 128])
446
- # @pytest.mark.parametrize("n", [128, 1024, 2048])
447
- # @pytest.mark.parametrize("k", [128, 511 , 1024])
448
- # @pytest.mark.parametrize("e", NUM_EXPERTS)
449
- # @pytest.mark.parametrize("topk", TOP_KS)
450
- # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
451
- @pytest .mark .parametrize ("m" , [128 ]) ##, 32]) #, 1024 * 128])
452
- @pytest .mark .parametrize ("n" , [128 ])
453
- @pytest .mark .parametrize ("k" , [128 ])
454
- @pytest .mark .parametrize ("e" , [8 ]) #NUM_EXPERTS)
455
- @pytest .mark .parametrize ("topk" , [2 ]) #TOP_KS)
456
- @pytest .mark .parametrize ("dtype" , [torch .bfloat16 ])
455
+ @pytest .mark .parametrize ("m" , [2 , 32 , 64 , 222 ]) #, 1024 * 128]) # what is restriction on this?
456
+ @pytest .mark .parametrize ("n" , [128 , 1024 , 2048 ])
457
+ @pytest .mark .parametrize ("k" , [128 , 512 , 1024 ]) # restrictions here?
458
+ @pytest .mark .parametrize ("e" , NUM_EXPERTS )
459
+ @pytest .mark .parametrize ("topk" , TOP_KS )
460
+ @pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
461
+ # @pytest.mark.parametrize("m", [2 ]) ##, 32]) #, 1024 * 128])
462
+ # @pytest.mark.parametrize("n", [128])
463
+ # @pytest.mark.parametrize("k", [128])
464
+ # @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS)
465
+ # @pytest.mark.parametrize("topk", [2]) #TOP_KS)
466
+ # @pytest.mark.parametrize("dtype", [torch.bfloat16])
457
467
def test_pplx_dispatch_combine (
458
468
m : int ,
459
469
n : int ,
@@ -469,8 +479,14 @@ def test_pplx_dispatch_combine(
469
479
else :
470
480
world_size = 2
471
481
dp_size = 1
482
+
483
+ a = torch .randn ((m , k ), device = "cuda" , dtype = dtype ) / 10
484
+ w1 = torch .randn ((e , 2 * n , k ), device = "cuda" , dtype = dtype ) / 10
485
+ w2 = torch .randn ((e , k , n ), device = "cuda" , dtype = dtype ) / 10
486
+ score = torch .randn ((m , e ), device = "cuda" , dtype = dtype )
487
+
472
488
parallel_launch (
473
- world_size , _pplx_dispatch_combine , dp_size , m , n , k , e , topk , dtype
489
+ world_size , _pplx_dispatch_combine , dp_size , a , w1 , w2 , score , topk , dtype
474
490
)
475
491
476
492
@@ -483,6 +499,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
483
499
max_num_tokens = round_up (a .shape [0 ], 128 ) #tokens_per_expert.max()
484
500
print (f"max_num_tokens = { max_num_tokens } , topk = { topk } , num_ex = { num_experts } /{ num_local_experts } " )
485
501
rank = pgi .rank
502
+ world_size = pgi .world_size
486
503
487
504
ata = AllToAll (
488
505
max_num_tokens = max_num_tokens ,
@@ -520,14 +537,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
520
537
experts ,
521
538
)
522
539
523
- def chunk_by_rank (t , r ):
524
- num = t .shape [0 ]
525
- assert num % pgi .world_size == 0 , f"{ num } , { dp_size } " # for now
526
- chunk = num // pgi .world_size
527
- return t [(r * chunk ):(r + 1 )* chunk ]
528
-
529
- a_chunk = chunk_by_rank (a , rank )
530
- chunk_topk_weight , chunk_topk_ids = fused_topk (a_chunk , chunk_by_rank (scores , rank ), topk , False )
540
+ a_chunk = chunk_by_rank (a , rank , world_size )
541
+ score_chunk = chunk_by_rank (scores , rank , world_size )
542
+ chunk_topk_weight , chunk_topk_ids = fused_topk (a_chunk , score_chunk , topk , False )
531
543
532
544
print (f"chunk_topk_ids = { chunk_topk_ids } " )
533
545
0 commit comments