@@ -297,18 +297,24 @@ def chunk_by_rank(t, r, w):
     return t[(r * chunk):(r + 1) * chunk]


-def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
+ata = None
+
+def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
     assert torch.cuda.current_device() == pgi.local_rank

+    topk = topk_ids.shape[1]
+
+    #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
+
     num_tokens, hidden_dim = a.shape
-    num_experts = w1.shape[0]
     block_size = 128
     device = pgi.device
     rank = pgi.rank
     world_size = pgi.world_size
-    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
-    max_num_tokens = max(num_tokens, 1)
+    max_num_tokens = rank_chunk(num_tokens, 0, world_size)
+    print(f"MAX_NUM_TOKENS = {max_num_tokens}")

+    global ata
     ata = AllToAll.internode(
         max_num_tokens=max_num_tokens,
         num_experts=num_experts,
@@ -333,21 +339,25 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
     )

     a_chunk = chunk_by_rank(a, rank, world_size).to(device)
-    score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
-    chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk,
-                                                   False)
+    num_tokens = a_chunk.shape[0]
+    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
+    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
+
+    print(f"{rank}: shapes {a_chunk.shape}, {chunk_topk_weight.shape}, {chunk_topk_ids.shape}, E={num_experts}")

     b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
         a_chunk,
         None,
         None,
         chunk_topk_weight,
         chunk_topk_ids,
-        num_experts,  # store at PplxDispatchCombine creation?
+        num_experts,
         None,
         False,
     )

+    #torch.cuda.synchronize()
+
     if False:
         naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids,
                                                       num_experts)
@@ -364,7 +374,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
     b_a = b_a * 1.5

     out = torch.full(
-        (rank_num_tokens, hidden_dim),
+        (max_num_tokens, hidden_dim),
         torch.nan,
         dtype=a.dtype,
         device=device,
@@ -377,60 +387,56 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
         chunk_topk_ids,
         False,
     )
-    torch.cuda.synchronize()

-    ata.destroy()
+    #torch.cuda.synchronize()
+
+    #ata.destroy()

-    return out[:rank_num_tokens]
+    return out[:num_tokens]


 def _pplx_dispatch_combine(
     pgi: ProcessGroupInfo,
     dp_size: int,
-    m,
-    n,
-    k,
-    e,
-    topk: int,
-    dtype: torch.dtype,
+    a,
+    topk_weight,
+    topk_ids,
+    num_experts,
 ):
     uid = nvshmem_get_unique_id(
     ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
     torch.distributed.broadcast(uid, src=0)
     nvshmem_init(uid, pgi.rank, pgi.world_size)
     device = pgi.device

-    a = torch.randn((m, k), device=device, dtype=dtype) / 10
-    w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
-    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
-    score = torch.randn((m, e), device=device, dtype=dtype)
-
-    topk_weight, topk_ids = fused_topk(a, score, topk, False)
+    k = a.shape[1]
+    topk = topk_ids.shape[1]

-    a_rep = torch.repeat_interleave(a, topk, dim=0)
+    a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)

     torch_output = (a_rep.view(-1, topk, k) * 1.5 *
-                    topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype)
+                    topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(a.dtype)

-    pplx_output = torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, score,
-                                              topk)
+    pplx_output = pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts)

     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)

+    print(f"{pgi.rank}: out shapes {pplx_output.shape}, {torch_output.shape}")
+
     torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)

     nvshmem_finalize()


 # TODO: M < world_size doesn't appear to be supported by pplx?
-@pytest.mark.parametrize("m", [4, 32, 64, 222])
+@pytest.mark.parametrize("m", [1, 4, 32, 64, 222])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("world_dp_size", [[2, 1]])  #, [[4, 2]])
+@pytest.mark.parametrize("world_dp_size", [[2, 1]])  # [[4, 2]])
 @requires_pplx
 def test_pplx_dispatch_combine(
     m: int,
@@ -443,22 +449,27 @@ def test_pplx_dispatch_combine(
 ):
     current_platform.seed_everything(7)
     world_size, dp_size = world_dp_size
+    device = "cuda"
+
+    a = torch.randn((m, k), device=device, dtype=dtype) / 10
+    score = torch.randn((m, e), device=device, dtype=dtype)
+
+    topk_weight, topk_ids = fused_topk(a, score, topk, False)

-    parallel_launch(world_size, _pplx_dispatch_combine, dp_size, m, n, k, e,
-                    topk, dtype)
+    parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, topk_weight, topk_ids, e)


-def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
+def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
     assert torch.cuda.current_device() == pgi.local_rank

-    num_tokens, hidden_dim = a.shape
+    hidden_dim = a.shape[1]
     num_experts = w1.shape[0]
     block_size = 128
     device = pgi.device
     rank = pgi.rank
     world_size = pgi.world_size
-    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
-    max_num_tokens = num_tokens
+    topk = topk_ids.shape[1]
+    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)

     ata = AllToAll.internode(
         max_num_tokens=max_num_tokens,
@@ -474,9 +485,6 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
                                 torch.float32.itemsize)),
     )

-    w1 = w1.to(device)
-    w2 = w2.to(device)
-
     dispatch_combine = PplxDispatchCombine(
         ata,
         max_num_tokens,
@@ -493,15 +501,14 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
     )

     a_chunk = chunk_by_rank(a, rank, world_size).to(device)
-    score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
-    chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk,
-                                                   False)
+    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
+    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)

     out = fused_experts(
         a_chunk,
         # Chunking weights like this only works for batched format
-        chunk_by_rank(w1, rank, world_size),
-        chunk_by_rank(w2, rank, world_size),
+        chunk_by_rank(w1, rank, world_size).to(device),
+        chunk_by_rank(w2, rank, world_size).to(device),
         chunk_topk_weight,
         chunk_topk_ids,
         global_num_experts=num_experts)
@@ -510,7 +517,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):

     ata.destroy()

-    return out[:rank_num_tokens]
+    return out


 def _pplx_moe(
@@ -521,7 +528,6 @@ def _pplx_moe(
     w2: torch.Tensor,
     score: torch.Tensor,
     topk: int,
-    dtype: torch.dtype,
 ):
     uid = nvshmem_get_unique_id(
     ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
@@ -534,7 +540,7 @@ def _pplx_moe(
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids = fused_topk(a, score, topk, False)
         torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
-        pplx_output = torch_pplx_moe(pgi, dp_size, a, w1, w2, score, topk)
+        pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)

     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
@@ -544,8 +550,7 @@ def _pplx_moe(
     nvshmem_finalize()


-# TODO: M < world_size doesn't appear to be supported by pplx?
-@pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222])
+@pytest.mark.parametrize("m", [1, 2, 3, 32, 45, 64, 222])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
@@ -569,5 +574,4 @@ def test_pplx_moe(
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
     score = torch.randn((m, e), device="cuda", dtype=dtype)

-    parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
-                    dtype)
+    parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)
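
Both tests slice their inputs per rank with `chunk_by_rank` and size the `AllToAll` buffers with `rank_chunk(num_tokens, 0, world_size)`. A minimal standalone sketch of that chunking scheme is below; `rank_chunk` itself is defined elsewhere in the test file, so the ceil-division behavior assumed here is an illustration, while `chunk_by_rank` mirrors the slicing visible in the first hunk.

```python
# Sketch only: rank_chunk() is not part of this diff, so its ceil-division
# definition below is an assumption; chunk_by_rank() mirrors the diff above.
import torch


def rank_chunk(num: int, r: int, w: int) -> int:
    # Assumed: split num items over w ranks, earlier ranks absorb the remainder.
    rem = num % w
    return (num // w) + (1 if r < rem else 0)


def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    # Slice out this rank's rows of t.
    chunk = rank_chunk(t.shape[0], r, w)
    return t[(r * chunk):(r + 1) * chunk]


if __name__ == "__main__":
    world_size = 2
    num_tokens, hidden_dim = 6, 4  # evenly divisible example
    a = torch.randn((num_tokens, hidden_dim))

    # Rank 0 always receives the largest slice, so rank_chunk(num_tokens, 0,
    # world_size) bounds every rank's chunk and can size per-rank buffers.
    max_num_tokens = rank_chunk(num_tokens, 0, world_size)

    for rank in range(world_size):
        a_chunk = chunk_by_rank(a, rank, world_size)
        assert a_chunk.shape[0] <= max_num_tokens
        print(f"rank {rank}: {a_chunk.shape[0]} of {num_tokens} tokens")
```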