Commit 886045e

fix M=1 pplx test
Signed-off-by: Bill Nell <[email protected]>
1 parent 3e8a0e3 · commit 886045e

File tree

1 file changed: +19 -49 lines


tests/kernels/moe/test_pplx_moe.py

Lines changed: 19 additions & 49 deletions
@@ -28,7 +28,7 @@
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
-    BatchedExperts)
+    BatchedExperts, BatchedTritonExperts)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
     FusedMoEModularKernel)
@@ -293,34 +293,26 @@ def rank_chunk(num, r, w):
 
 def chunk_by_rank(t, r, w):
     chunk = rank_chunk(t.shape[0], r, w)
-    #print(f"chunk {t.shape}, {w}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}")
     return t[(r * chunk):(r + 1) * chunk]
 
 
-ata = None
-
 def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
     assert torch.cuda.current_device() == pgi.local_rank
 
     topk = topk_ids.shape[1]
-
-    #tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
-
-    num_tokens, hidden_dim = a.shape
+    num_tokens, hidden_dim = a.shape[1]
     block_size = 128
     device = pgi.device
     rank = pgi.rank
     world_size = pgi.world_size
     max_num_tokens = rank_chunk(num_tokens, 0, world_size)
-    print(f"MAX_NUM_TOKENS = {max_num_tokens}")
 
-    global ata
     ata = AllToAll.internode(
         max_num_tokens=max_num_tokens,
         num_experts=num_experts,
         experts_per_token=topk,
         rank=rank,
-        world_size=pgi.world_size,
+        world_size=world_size,
         dp_size=dp_size,
         hidden_dim=hidden_dim,
         hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
@@ -332,19 +324,15 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
     dispatch_combine = PplxDispatchCombine(
         ata,
         max_num_tokens,
-        pgi.world_size,
+        world_size,
         dp_size,
         rank,
-        a.dtype,
     )
 
     a_chunk = chunk_by_rank(a, rank, world_size).to(device)
-    num_tokens = a_chunk.shape[0]
     chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
     chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
 
-    print(f"{rank}: shapes {a_chunk.shape}, {chunk_topk_weight.shape}, {chunk_topk_ids.shape}, E={num_experts}")
-
     b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
         a_chunk,
         None,
@@ -356,21 +344,6 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
         False,
     )
 
-    #torch.cuda.synchronize()
-
-    if False:
-        naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids,
-                                                      num_experts)
-
-        torch.distributed.all_reduce(tokens_per_expert)
-        tokens_per_expert = chunk_by_rank(tokens_per_expert, rank,
-                                          world_size).to(dtype=torch.int32)
-
-        torch.testing.assert_close(tokens_per_expert,
-                                   expert_num_tokens,
-                                   atol=0,
-                                   rtol=0)
-
     b_a = b_a * 1.5
 
     out = torch.full(
@@ -388,9 +361,11 @@ def pplx_dispatch_combine(pgi, dp_size, a, topk_weight, topk_ids, num_experts):
         False,
     )
 
-    #torch.cuda.synchronize()
+    torch.cuda.synchronize()
 
-    #ata.destroy()
+    ata.destroy()
+
+    num_tokens = a_chunk.shape[0]
 
     return out[:num_tokens]
 
@@ -399,8 +374,8 @@ def _pplx_dispatch_combine(
     pgi: ProcessGroupInfo,
     dp_size: int,
     a,
-    topk_weight,
-    topk_ids,
+    score,
+    topk,
     num_experts,
 ):
     uid = nvshmem_get_unique_id(
@@ -409,8 +384,8 @@ def _pplx_dispatch_combine(
     nvshmem_init(uid, pgi.rank, pgi.world_size)
     device = pgi.device
 
+    topk_weight, topk_ids = fused_topk(a, score, topk, False)
     k = a.shape[1]
-    topk = topk_ids.shape[1]
 
     a_rep = torch.repeat_interleave(a, topk, dim=0).to(device)
 
@@ -422,21 +397,19 @@ def _pplx_dispatch_combine(
     torch_output = chunk_by_rank(torch_output, pgi.rank,
                                  pgi.world_size).to(pplx_output.device)
 
-    print(f"{pgi.rank}: out shapes {pplx_output.shape}, {torch_output.shape}")
-
     torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
 
     nvshmem_finalize()
 
 
-# TODO: M < world_size doesn't appear to be supported by pplx?
-@pytest.mark.parametrize("m", [1, 4, 32, 64, 222])
+# TODO: this test point does not work for M == 1
+@pytest.mark.parametrize("m", [4, 32, 64, 222])
 @pytest.mark.parametrize("n", [128, 1024, 2048])
 @pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #[[4, 2]])
+@pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @requires_pplx
 def test_pplx_dispatch_combine(
     m: int,
@@ -450,13 +423,10 @@ def test_pplx_dispatch_combine(
     current_platform.seed_everything(7)
     world_size, dp_size = world_dp_size
     device = "cuda"
-
     a = torch.randn((m, k), device=device, dtype=dtype) / 10
     score = torch.randn((m, e), device=device, dtype=dtype)
 
-    topk_weight, topk_ids = fused_topk(a, score, topk, False)
-
-    parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, topk_weight, topk_ids, e)
+    parallel_launch(world_size, _pplx_dispatch_combine, dp_size, a, score, topk, e)
 
 
 def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
@@ -476,7 +446,7 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
         num_experts=num_experts,
         experts_per_token=topk,
         rank=rank,
-        world_size=pgi.world_size,
+        world_size=world_size,
         dp_size=dp_size,
         hidden_dim=hidden_dim,
         hidden_dim_bytes=hidden_dim * a.dtype.itemsize,
@@ -488,12 +458,12 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
     dispatch_combine = PplxDispatchCombine(
         ata,
         max_num_tokens,
-        pgi.world_size,
+        world_size,
         dp_size,
         rank,
     )
 
-    experts = BatchedExperts(max_num_tokens)
+    experts = BatchedExperts(a.shape[0])
 
     fused_experts = FusedMoEModularKernel(
         dispatch_combine,
@@ -556,7 +526,7 @@ def _pplx_moe(
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
+@pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @requires_pplx
 def test_pplx_moe(
     m: int,
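
As context for the parametrization change above, here is a minimal, self-contained sketch (plain PyTorch, no pplx or vLLM imports) of how the test splits work across ranks. The `chunk_by_rank` body mirrors the helper shown in the diff; the `rank_chunk` body and the softmax/top-k stand-in for `fused_topk` are assumptions for illustration only.

```python
import torch


def rank_chunk(num: int, r: int, w: int) -> int:
    # Assumed implementation: ceil-divide `num` rows across `w` ranks.
    return (num + w - 1) // w


def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor:
    # Same slicing as the test helper: take rank r's block of rows.
    chunk = rank_chunk(t.shape[0], r, w)
    return t[(r * chunk):(r + 1) * chunk]


if __name__ == "__main__":
    world_size, topk, num_experts = 2, 2, 8
    for m in (1, 4):
        a = torch.randn((m, 128))
        score = torch.randn((m, num_experts))
        # Stand-in for vllm's fused_topk: per-token softmax over experts, then top-k.
        topk_weight, topk_ids = torch.softmax(score, dim=-1).topk(topk, dim=-1)
        for rank in range(world_size):
            a_chunk = chunk_by_rank(a, rank, world_size)
            ids_chunk = chunk_by_rank(topk_ids, rank, world_size)
            print(f"m={m} rank={rank}: a_chunk={tuple(a_chunk.shape)} "
                  f"topk_ids_chunk={tuple(ids_chunk.shape)}")
        # With m=1 and world_size=2, rank 1 receives an empty (0, 128) chunk,
        # which appears to be why m=1 is dropped from the parametrization above.
```

The commit also moves the `fused_topk` call into `_pplx_dispatch_combine`, so each spawned rank computes `topk_weight`/`topk_ids` from the full `a` and `score` and then chunks them by rank, rather than receiving tensors that were top-k'd once in the parent process.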
