
Commit 3e8a0e3

fix M=1 pplx test
Signed-off-by: Bill Nell <[email protected]>
1 parent: f74ab61
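
The test changes below compute the top-k routing once on the full input in the parent process and chunk it per rank (instead of each rank re-running fused_topk on its local slice), and they size the pplx all-to-all buffers from rank_chunk(num_tokens, 0, world_size) rather than the global token count. For orientation, here is a minimal sketch of the chunking helpers the diff relies on; only the final line of chunk_by_rank is visible in the hunk context, so the bodies below are an assumption, not the file's exact code:

    import torch

    def rank_chunk(num, r, w):
        # Assumed ceiling-division split: rank r's share of num rows across w
        # ranks, so rank_chunk(num, 0, w) is an upper bound for every rank.
        chunk = (num + w - 1) // w
        return max(0, min(chunk, num - r * chunk))

    def chunk_by_rank(t, r, w):
        # Rank r's rows of t; the return line matches the diff's context line.
        chunk = (t.shape[0] + w - 1) // w
        return t[(r * chunk):(r + 1) * chunk]

    a = torch.randn(1, 8)  # M = 1, the newly parametrized case
    print(chunk_by_rank(a, 0, 2).shape[0], chunk_by_rank(a, 1, 2).shape[0])  # 1 0

Under these assumptions, m=1 with world_size=2 leaves rank 1 with zero tokens; pre-chunking topk_weight/topk_ids and trimming the output to the rank's actual chunk (return out[:num_tokens]) lets that case run, which appears to be what this commit targets.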

2 files changed: +56 -52 lines

tests/kernels/moe/test_pplx_moe.py

Lines changed: 55 additions & 51 deletions
@@ -297,18 +297,24 @@ def chunk_by_rank(t, r, w):
     return t[(r * chunk):(r + 1) * chunk]
 
 
-def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
+ata = None
+
+def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
     assert torch.cuda.current_device() == pgi.local_rank
 
+    topk = topk_ids.shape[1]
+
+    #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
+
     num_tokens, hidden_dim = a.shape
-    num_experts = w1.shape[0]
     block_size = 128
     device = pgi.device
     rank = pgi.rank
     world_size = pgi.world_size
-    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
-    max_num_tokens = max(num_tokens, 1)
+    max_num_tokens = rank_chunk(num_tokens, 0, world_size)
+    print(f"MAX_NUM_TOKENS = {max_num_tokens}")
 
+    global ata
     ata = AllToAll.internode(
         max_num_tokens=max_num_tokens,
         num_experts=num_experts,
@@ -333,21 +339,25 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
     )
 
     a_chunk = chunk_by_rank(a, rank, world_size).to(device)
-    score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
-    chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk,
-                                                   False)
+    num_tokens = a_chunk.shape[0]
+    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
+    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
+
+    print(f"{rank}: shapes {a_chunk.shape}, {chunk_topk_weight.shape}, {chunk_topk_ids.shape}, E={num_experts}")
 
     b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
         a_chunk,
         None,
         None,
         chunk_topk_weight,
         chunk_topk_ids,
-        num_experts, # store at PplxDispatchCombine creation?
+        num_experts,
         None,
         False,
     )
 
+    #torch.cuda.synchronize()
+
     if False:
         naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids,
                                                       num_experts)
@@ -364,7 +374,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
     b_a = b_a * 1.5
 
     out = torch.full(
-        (rank_num_tokens, hidden_dim),
+        (max_num_tokens, hidden_dim),
         torch.nan,
         dtype=a.dtype,
         device=device,
@@ -377,60 +387,56 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
         chunk_topk_ids,
         False,
     )
-    torch.cuda.synchronize()
 
-    ata.destroy()
+    #torch.cuda.synchronize()
+
+    #ata.destroy()
 
-    return out[:rank_num_tokens]
+    return out[:num_tokens]
 
 
 def _pplx_dispatch_combine(
     pgi: ProcessGroupInfo,
     dp_size: int,
-    m,
-    n,
-    k,
-    e,
-    topk: int,
-    dtype: torch.dtype,
+    a,
+    topk_weight,
+    topk_ids,
+    num_experts,
 ):
     uid = nvshmem_get_unique_id(
     ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
     torch.distributed.broadcast(uid, src=0)
     nvshmem_init(uid, pgi.rank, pgi.world_size)
     device = pgi.device
 
-    a = torch.randn((m, k), device=device, dtype=dtype) / 10
-    w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 10
-    w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10
-    score = torch.randn((m, e), device=device, dtype=dtype)
-
-    topk_weight, topk_ids = fused_topk(a, score, topk, False)
+    k = a.shape[1]
+    topk = topk_ids.shape[1]
 
-    a_rep = torch.repeat_interleave(a, topk, dim=0)
+    a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)
 
     torch_output = (a_rep.view(-1, topk, k) * 1.5 *
-                    topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype)
+                    topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to(a.dtype)
 
-    pplx_output = torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, score,
-                                              topk)
+    pplx_output = pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts)
 
     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
 
+    print(f"{pgi.rank}: out shapes {pplx_output.shape}, {torch_output.shape}")
+
     torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
 
     nvshmem_finalize()
 
 
 # TODO: M < world_size doesn't appear to be supported by pplx?
-@pytest.mark.parametrize("m", [4, 32, 64, 222])
+@pytest.mark.parametrize("m", [1, 4, 32, 64, 222])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [[4, 2]])
+@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #[[4, 2]])
 @requires_pplx
 def test_pplx_dispatch_combine(
     m: int,
@@ -443,22 +449,27 @@ def test_pplx_dispatch_combine(
 ):
     current_platform.seed_everything(7)
     world_size, dp_size = world_dp_size
+    device = "cuda"
+
+    a = torch.randn((m, k), device=device, dtype=dtype) / 10
+    score = torch.randn((m, e), device=device, dtype=dtype)
+
+    topk_weight, topk_ids = fused_topk(a, score, topk, False)
 
-    parallel_launch(world_size, _pplx_dispatch_combine, dp_size, m, n, k, e,
-                    topk, dtype)
+    parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, topk_weight, topk_ids, e)
 
 
-def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
+def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
     assert torch.cuda.current_device() == pgi.local_rank
 
-    num_tokens, hidden_dim = a.shape
+    hidden_dim = a.shape[1]
     num_experts = w1.shape[0]
     block_size = 128
     device = pgi.device
     rank = pgi.rank
     world_size = pgi.world_size
-    rank_num_tokens = rank_chunk(num_tokens, rank, world_size)
-    max_num_tokens = num_tokens
+    topk = topk_ids.shape[1]
+    max_num_tokens = rank_chunk(a.shape[0], 0, world_size)
 
     ata = AllToAll.internode(
         max_num_tokens=max_num_tokens,
@@ -474,9 +485,6 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
                                 torch.float32.itemsize)),
     )
 
-    w1 = w1.to(device)
-    w2 = w2.to(device)
-
     dispatch_combine = PplxDispatchCombine(
         ata,
         max_num_tokens,
@@ -493,15 +501,14 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
     )
 
     a_chunk = chunk_by_rank(a, rank, world_size).to(device)
-    score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
-    chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk,
-                                                   False)
+    chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
+    chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
 
     out = fused_experts(
         a_chunk,
         # Chunking weights like this only works for batched format
-        chunk_by_rank(w1, rank, world_size),
-        chunk_by_rank(w2, rank, world_size),
+        chunk_by_rank(w1, rank, world_size).to(device),
+        chunk_by_rank(w2, rank, world_size).to(device),
         chunk_topk_weight,
         chunk_topk_ids,
         global_num_experts=num_experts)
@@ -510,7 +517,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
 
     ata.destroy()
 
-    return out[:rank_num_tokens]
+    return out
 
 
 def _pplx_moe(
@@ -521,7 +528,6 @@ def _pplx_moe(
     w2: torch.Tensor,
     score: torch.Tensor,
     topk: int,
-    dtype: torch.dtype,
 ):
     uid = nvshmem_get_unique_id(
     ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
@@ -534,7 +540,7 @@ def _pplx_moe(
     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids = fused_topk(a, score, topk, False)
         torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
-        pplx_output = torch_pplx_moe(pgi, dp_size, a, w1, w2, score, topk)
+        pplx_output = pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids)
 
     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
@@ -544,8 +550,7 @@ def _pplx_moe(
     nvshmem_finalize()
 
 
-# TODO: M < world_size doesn't appear to be supported by pplx?
-@pytest.mark.parametrize("m", [2, 3, 32, 45, 64, 222])
+@pytest.mark.parametrize("m", [1, 2, 3, 32, 45, 64, 222])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
@@ -569,5 +574,4 @@ def test_pplx_moe(
     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
     score = torch.randn((m, e), device="cuda", dtype=dtype)
 
-    parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk,
-                    dtype)
+    parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk)

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 1 addition & 1 deletion
@@ -251,7 +251,7 @@ def set_dispatch_combine(
             self, dispatch_combine: FusedMoEQuantizeDispatchCombine) -> bool:
         assert self.fused_experts == fused_experts
 
-        experts: FusedMoEPermuteExpertsUnpermute = None
+        experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
 
         if isinstance(dispatch_combine,
                       (BatchedDispatchCombine, PplxDispatchCombine)):
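
The layer.py change tightens a type annotation: a variable initialized to None needs an Optional (union-with-None) type to be well-typed. A standalone sketch of the pattern, using a hypothetical placeholder class in place of FusedMoEPermuteExpertsUnpermute (Optional comes from the standard typing module):

    from typing import Optional

    class Experts:  # hypothetical stand-in for the real experts class
        pass

    # Annotating with Optional[...] makes the initial None assignment well-typed.
    experts: Optional[Experts] = None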
