28
28
from vllm .config import VllmConfig , set_current_vllm_config
29
29
from vllm .model_executor .layers .activation import SiluAndMul
30
30
from vllm .model_executor .layers .fused_moe .fused_batched_moe import (
31
- BatchedExperts )
31
+ BatchedExperts , BatchedTritonExperts )
32
32
from vllm .model_executor .layers .fused_moe .fused_moe import fused_topk
33
33
from vllm .model_executor .layers .fused_moe .modular_kernel import (
34
34
FusedMoEModularKernel )
@@ -293,34 +293,26 @@ def rank_chunk(num, r, w):
293
293
294
294
def chunk_by_rank (t , r , w ):
295
295
chunk = rank_chunk (t .shape [0 ], r , w )
296
- #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}")
297
296
return t [(r * chunk ):(r + 1 ) * chunk ]
298
297
299
298
300
- ata = None
301
-
302
299
def pplx_dispatch_combine (pgi , dp_size , a , topk_weight , topk_ids , num_experts ):
303
300
assert torch .cuda .current_device () == pgi .local_rank
304
301
305
302
topk = topk_ids .shape [1 ]
306
-
307
- #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
308
-
309
- num_tokens , hidden_dim = a .shape
303
+ num_tokens , hidden_dim = a .shape [1 ]
310
304
block_size = 128
311
305
device = pgi .device
312
306
rank = pgi .rank
313
307
world_size = pgi .world_size
314
308
max_num_tokens = rank_chunk (num_tokens , 0 , world_size )
315
- print (f"MAX_NUM_TOKENS = { max_num_tokens } " )
316
309
317
- global ata
318
310
ata = AllToAll .internode (
319
311
max_num_tokens = max_num_tokens ,
320
312
num_experts = num_experts ,
321
313
experts_per_token = topk ,
322
314
rank = rank ,
323
- world_size = pgi . world_size ,
315
+ world_size = world_size ,
324
316
dp_size = dp_size ,
325
317
hidden_dim = hidden_dim ,
326
318
hidden_dim_bytes = hidden_dim * a .dtype .itemsize ,
@@ -332,19 +324,15 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
332
324
dispatch_combine = PplxDispatchCombine (
333
325
ata ,
334
326
max_num_tokens ,
335
- pgi . world_size ,
327
+ world_size ,
336
328
dp_size ,
337
329
rank ,
338
- a .dtype ,
339
330
)
340
331
341
332
a_chunk = chunk_by_rank (a , rank , world_size ).to (device )
342
- num_tokens = a_chunk .shape [0 ]
343
333
chunk_topk_weight = chunk_by_rank (topk_weight , rank , world_size ).to (device )
344
334
chunk_topk_ids = chunk_by_rank (topk_ids , rank , world_size ).to (device )
345
335
346
- print (f"{ rank } : shapes { a_chunk .shape } , { chunk_topk_weight .shape } , { chunk_topk_ids .shape } , E={ num_experts } " )
347
-
348
336
b_a , b_a_scale , expert_num_tokens = dispatch_combine .dispatch (
349
337
a_chunk ,
350
338
None ,
@@ -356,21 +344,6 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
356
344
False ,
357
345
)
358
346
359
- #torch.cuda.synchronize()
360
-
361
- if False :
362
- naive_b_a , tokens_per_expert = torch_dispatch (a_chunk , chunk_topk_ids ,
363
- num_experts )
364
-
365
- torch .distributed .all_reduce (tokens_per_expert )
366
- tokens_per_expert = chunk_by_rank (tokens_per_expert , rank ,
367
- world_size ).to (dtype = torch .int32 )
368
-
369
- torch .testing .assert_close (tokens_per_expert ,
370
- expert_num_tokens ,
371
- atol = 0 ,
372
- rtol = 0 )
373
-
374
347
b_a = b_a * 1.5
375
348
376
349
out = torch .full (
@@ -388,9 +361,11 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
388
361
False ,
389
362
)
390
363
391
- # torch.cuda.synchronize()
364
+ torch .cuda .synchronize ()
392
365
393
- #ata.destroy()
366
+ ata .destroy ()
367
+
368
+ num_tokens = a_chunk .shape [0 ]
394
369
395
370
return out [:num_tokens ]
396
371
@@ -399,8 +374,8 @@ def _pplx_dispatch_combine(
399
374
pgi : ProcessGroupInfo ,
400
375
dp_size : int ,
401
376
a ,
402
- topk_weight ,
403
- topk_ids ,
377
+ score ,
378
+ topk ,
404
379
num_experts ,
405
380
):
406
381
uid = nvshmem_get_unique_id (
@@ -409,8 +384,8 @@ def _pplx_dispatch_combine(
409
384
nvshmem_init (uid , pgi .rank , pgi .world_size )
410
385
device = pgi .device
411
386
387
+ topk_weight , topk_ids = fused_topk (a , score , topk , False )
412
388
k = a .shape [1 ]
413
- topk = topk_ids .shape [1 ]
414
389
415
390
a_rep = torch .repeat_interleave (a , topk , dim = 0 ).to (device )
416
391
@@ -422,21 +397,19 @@ def _pplx_dispatch_combine(
422
397
torch_output = chunk_by_rank (torch_output , pgi .rank ,
423
398
pgi .world_size ).to (pplx_output .device )
424
399
425
- print (f"{ pgi .rank } : out shapes { pplx_output .shape } , { torch_output .shape } " )
426
-
427
400
torch .testing .assert_close (pplx_output , torch_output , atol = 2e-2 , rtol = 0 )
428
401
429
402
nvshmem_finalize ()
430
403
431
404
432
- # TODO: M < world_size doesn't appear to be supported by pplx?
433
- @pytest .mark .parametrize ("m" , [1 , 4 , 32 , 64 , 222 ])
405
+ # TODO: this test point does not work for M == 1
406
+ @pytest .mark .parametrize ("m" , [4 , 32 , 64 , 222 ])
434
407
@pytest .mark .parametrize ("n" , [128 , 1024 , 2048 ])
435
408
@pytest .mark .parametrize ("k" , [128 , 512 , 1024 ])
436
409
@pytest .mark .parametrize ("e" , NUM_EXPERTS )
437
410
@pytest .mark .parametrize ("topk" , TOP_KS )
438
411
@pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
439
- @pytest .mark .parametrize ("world_dp_size" , [[2 , 1 ]]) #[[4, 2]])
412
+ @pytest .mark .parametrize ("world_dp_size" , [[2 , 1 ]])
440
413
@requires_pplx
441
414
def test_pplx_dispatch_combine (
442
415
m : int ,
@@ -450,13 +423,10 @@ def test_pplx_dispatch_combine(
450
423
current_platform .seed_everything (7 )
451
424
world_size , dp_size = world_dp_size
452
425
device = "cuda"
453
-
454
426
a = torch .randn ((m , k ), device = device , dtype = dtype ) / 10
455
427
score = torch .randn ((m , e ), device = device , dtype = dtype )
456
428
457
- topk_weight , topk_ids = fused_topk (a , score , topk , False )
458
-
459
- parallel_launch (world_size , _pplx_dispatch_combine , dp_size , a , topk_weight , topk_ids , e )
429
+ parallel_launch (world_size , _pplx_dispatch_combine , dp_size , a , score , topk , e )
460
430
461
431
462
432
def pplx_moe (pgi , dp_size , a , w1 , w2 , topk_weight , topk_ids ):
@@ -476,7 +446,7 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
476
446
num_experts = num_experts ,
477
447
experts_per_token = topk ,
478
448
rank = rank ,
479
- world_size = pgi . world_size ,
449
+ world_size = world_size ,
480
450
dp_size = dp_size ,
481
451
hidden_dim = hidden_dim ,
482
452
hidden_dim_bytes = hidden_dim * a .dtype .itemsize ,
@@ -488,12 +458,12 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
488
458
dispatch_combine = PplxDispatchCombine (
489
459
ata ,
490
460
max_num_tokens ,
491
- pgi . world_size ,
461
+ world_size ,
492
462
dp_size ,
493
463
rank ,
494
464
)
495
465
496
- experts = BatchedExperts (max_num_tokens )
466
+ experts = BatchedExperts (a . shape [ 0 ] )
497
467
498
468
fused_experts = FusedMoEModularKernel (
499
469
dispatch_combine ,
@@ -556,7 +526,7 @@ def _pplx_moe(
556
526
@pytest .mark .parametrize ("e" , NUM_EXPERTS )
557
527
@pytest .mark .parametrize ("topk" , TOP_KS )
558
528
@pytest .mark .parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
559
- @pytest .mark .parametrize ("world_dp_size" , [[2 , 1 ]]) #, [4, 2]])
529
+ @pytest .mark .parametrize ("world_dp_size" , [[2 , 1 ]])
560
530
@requires_pplx
561
531
def test_pplx_moe (
562
532
m : int ,
0 commit comments