Skip to content

Commit cb7320d

Browse files
committed
forgot file
Signed-off-by: Bill Nell <[email protected]>
1 parent 7db0061 commit cb7320d

File tree

1 file changed

+41 -63 lines changed

tests/kernels/test_pplx_moe.py

Lines changed: 41 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -373,7 +373,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
373373
num_experts, # store at PplxDispatchCombine creation?
374374
None
375375
)
376-
torch.cuda.synchronize() # necessary?
376+
#torch.cuda.synchronize() # necessary?
377377

378378
out = torch.full(
379379
(max_num_tokens, hidden_dim),
@@ -452,18 +452,12 @@ def _pplx_dispatch_combine(
452452
nvshmem_finalize()
453453

454454

455-
@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) # what is restriction on this?
455+
@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128])
456456
@pytest.mark.parametrize("n", [128, 1024, 2048])
457-
@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions here?
457+
@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions? % 128?
458458
@pytest.mark.parametrize("e", NUM_EXPERTS)
459459
@pytest.mark.parametrize("topk", TOP_KS)
460460
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
461-
# @pytest.mark.parametrize("m", [2]) ##, 32]) #, 1024 * 128])
462-
# @pytest.mark.parametrize("n", [128])
463-
# @pytest.mark.parametrize("k", [128])
464-
# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS)
465-
# @pytest.mark.parametrize("topk", [2]) #TOP_KS)
466-
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
467461
def test_pplx_dispatch_combine(
468462
m: int,
469463
n: int,
@@ -491,13 +485,16 @@ def test_pplx_dispatch_combine(
491485

492486

493487
def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
494-
hidden_dim = a.shape[-1]
488+
assert torch.cuda.current_device() == pgi.local_rank
489+
490+
num_tokens, hidden_dim = a.shape
495491
num_experts = w1.shape[0]
496-
num_local_experts = num_experts // pgi.world_size
497492
block_size = 128
493+
device = pgi.device
494+
rank_num_tokens = num_tokens // pgi.world_size
498495

499-
max_num_tokens = round_up(a.shape[0], 128) #tokens_per_expert.max()
500-
print(f"max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}/{num_local_experts}")
496+
max_num_tokens = num_tokens
497+
#print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")
501498
rank = pgi.rank
502499
world_size = pgi.world_size
503500

@@ -523,7 +520,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
523520

524521
dispatch_combine = PplxDispatchCombine(
525522
ata,
526-
max_num_tokens,
523+
max_num_tokens, # // world_size?
527524
pgi.world_size,
528525
dp_size,
529526
rank,
@@ -537,53 +534,34 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
537534
experts,
538535
)
539536

540-
a_chunk = chunk_by_rank(a, rank, world_size)
541-
score_chunk = chunk_by_rank(scores, rank, world_size)
537+
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
538+
score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
542539
chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False)
543540

544-
print(f"chunk_topk_ids = {chunk_topk_ids}")
541+
#print(f"chunk_topk_ids = {chunk_topk_ids}")
545542

546-
# TODO: chunk up by rank
547-
if False:
548-
out = fused_experts(
549-
a_chunk,
550-
w1, # chunk?
551-
w2, # chunk?
552-
chunk_topk_weight,
553-
chunk_topk_ids,
554-
global_num_experts=num_local_experts
555-
)
556-
# reduce outputs?
557-
else:
558-
b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
559-
a_chunk,
560-
None,
561-
None,
562-
chunk_topk_ids,
563-
num_experts,
564-
None
565-
)
566-
torch.cuda.synchronize()
543+
out = fused_experts(
544+
a_chunk,
545+
w1, # chunk?
546+
w2, # chunk?
547+
chunk_topk_weight,
548+
chunk_topk_ids,
549+
global_num_experts=num_experts #? num_local_experts?
550+
)
567551

568-
out = torch.full(
569-
(max_num_tokens, hidden_dim),
570-
torch.nan,
571-
dtype=a.dtype,
572-
device=a.device,
573-
)
552+
torch.cuda.synchronize()
574553

575-
dispatch_combine.combine(
576-
out,
577-
b_a,
578-
chunk_topk_weight,
579-
chunk_topk_ids,
580-
)
554+
ata.destroy()
581555

582-
torch.cuda.synchronize()
556+
torch.distributed.barrier()
583557

584-
ata.destroy()
558+
#print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
585559

586-
return out
560+
#torch.distributed.all_reduce(out)
561+
562+
print(f"OUT {rank}: {out.shape} {out}")
563+
564+
return out[:rank_num_tokens]
587565

588566

589567
def _pplx_moe(
@@ -612,29 +590,29 @@ def _pplx_moe(
612590

613591
torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
614592

615-
triton_output = torch_pplx_moe(pgi,
616-
dp_size,
617-
a,
618-
w1,
619-
w2,
620-
score,
621-
topk)
593+
pplxd_output = torch_pplx_moe(pgi,
594+
dp_size,
595+
a,
596+
w1,
597+
w2,
598+
score,
599+
topk)
622600

623601
if False:
624602
torch.set_printoptions(profile="full")
625603
print("BASELINE")
626604
print(torch_output)
627605
print("OUTPUT")
628-
print(triton_output)
606+
print(pplx_output)
629607

630-
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
608+
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
631609

632610
nvshmem_finalize()
633611

634612

635613
# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128])
636614
# @pytest.mark.parametrize("n", [128, 1024, 2048])
637-
# @pytest.mark.parametrize("k", [128, 511, 1024])
615+
# @pytest.mark.parametrize("k", [128, 512, 1024])
638616
# @pytest.mark.parametrize("e", NUM_EXPERTS)
639617
# @pytest.mark.parametrize("topk", TOP_KS)
640618
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])

0 commit comments

Comments (0)