Skip to content

Commit 7db0061

Browse files
committed
dispatch/combine unit test
Signed-off-by: Bill Nell <[email protected]>
1 parent 46d09b7 commit 7db0061

File tree

1 file changed

+58
-46
lines changed

1 file changed

+58
-46
lines changed

tests/kernels/test_pplx_moe.py

Lines changed: 58 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -308,17 +308,27 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids):
308308
# torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
309309

310310

311+
def chunk_by_rank(t, r, w):
    """Return the slice of ``t`` that belongs to rank ``r`` out of ``w`` ranks.

    The first dimension of ``t`` is split into ``w`` equal, contiguous chunks
    and chunk number ``r`` (0-based) is returned via plain slicing.

    Args:
        t: Object exposing ``shape[0]`` and supporting slice indexing along
            its first dimension.
        r: Index of the chunk to return.
        w: Number of chunks to split the first dimension into.

    Raises:
        AssertionError: If ``t.shape[0]`` is not a multiple of ``w``. The
            original code marks this restriction as temporary ("for now").
    """
    num = t.shape[0]
    # Uneven splits are not supported yet, so fail loudly instead of silently
    # dropping the trailing rows.
    assert num % w == 0, f"{num}, {w}"  # for now
    chunk = num // w
    start = r * chunk
    return t[start:start + chunk]
317+
318+
311319
def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
312320
assert torch.cuda.current_device() == pgi.local_rank
313321

314322
num_tokens, hidden_dim = a.shape
315323
num_experts = w1.shape[0]
316324
block_size = 128
317325
device = pgi.device
326+
rank_num_tokens = num_tokens // pgi.world_size
318327

319328
max_num_tokens = num_tokens
320-
print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")
329+
#print(f"device = {device}, max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}, dp_size = {dp_size}")
321330
rank = pgi.rank
331+
world_size = pgi.world_size
322332

323333
ata = AllToAll(
324334
max_num_tokens=max_num_tokens,
@@ -342,22 +352,15 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
342352

343353
dispatch_combine = PplxDispatchCombine(
344354
ata,
345-
max_num_tokens,
355+
max_num_tokens, # // world_size?
346356
pgi.world_size,
347357
dp_size,
348358
rank,
349359
a.dtype,
350360
)
351361

352-
def chunk_by_rank(t, r):
353-
num = t.shape[0]
354-
assert num % pgi.world_size == 0, f"{num}, {pgi.world_size}" # for now
355-
chunk = num // pgi.world_size
356-
print(f"chunk {t.shape}, {pgi.world_size}, {r}, {chunk}, {r*chunk}:{(r + 1)*chunk}")
357-
return t[(r * chunk):(r + 1)*chunk]
358-
359-
a_chunk = chunk_by_rank(a, rank).to(device)
360-
score_chunk = chunk_by_rank(scores, rank).to(device)
362+
a_chunk = chunk_by_rank(a, rank, world_size).to(device)
363+
score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
361364
chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False)
362365

363366
#print(f"chunk_topk_ids = {chunk_topk_ids}")
@@ -391,36 +394,41 @@ def chunk_by_rank(t, r):
391394

392395
torch.distributed.barrier()
393396

394-
return out[:num_tokens]
397+
#print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
398+
399+
#torch.distributed.all_reduce(out)
400+
401+
#print(f"AR OUT {rank}: {out.shape} {out}")
402+
403+
return out[:rank_num_tokens]
395404

396405

397406
def _pplx_dispatch_combine(
398407
pgi: ProcessGroupInfo,
399408
dp_size: int,
400-
m: int,
401-
n: int,
402-
k: int,
403-
e: int,
409+
a: torch.Tensor,
410+
w1: torch.Tensor,
411+
w2: torch.Tensor,
412+
score: torch.Tensor,
404413
topk: int,
405414
dtype: torch.dtype,
406415
):
407416
uid = nvshmem_get_unique_id() if pgi.rank == 0 else nvshmem_alloc_empty_unique_id()
408417
torch.distributed.broadcast(uid, src=0)
409418
nvshmem_init(uid, pgi.rank, pgi.world_size)
410419

411-
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
412-
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
413-
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
414-
415-
score = torch.randn((m, e), device="cuda", dtype=dtype)
420+
m, k = a.shape
421+
e, _, n = w2.shape
416422

417423
topk_weight, topk_ids = fused_topk(a, score, topk, False)
418424

419-
print(f"a {a.shape}")
420-
a_rep = torch.repeat_interleave(a, topk, dim=1)
421-
print(f"a_rep {a_rep.shape}")
425+
#print(f"a {a.shape}")
426+
a_rep = torch.repeat_interleave(a, topk, dim=0)
427+
#print(f"a_rep {a_rep.shape} {a_rep.view(-1, topk, k)}")
428+
429+
torch_output = (a_rep.view(-1, topk, k) * topk_weight.view(-1, topk, 1)).to(a.dtype).sum(dim=1)
422430

423-
torch_output = (a_rep.view(-1, topk, k) * topk_weight.view(-1, topk, 1)).sum(dim=1).to(a.dtype)
431+
#print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}")
424432

425433
pplx_output = torch_pplx_dispatch_combine(pgi,
426434
dp_size,
@@ -437,23 +445,25 @@ def _pplx_dispatch_combine(
437445
print("OUTPUT")
438446
print(pplx_output)
439447

448+
torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device)
449+
440450
torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
441451

442452
nvshmem_finalize()
443453

444454

445-
# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128])
446-
# @pytest.mark.parametrize("n", [128, 1024, 2048])
447-
# @pytest.mark.parametrize("k", [128, 511, 1024])
448-
# @pytest.mark.parametrize("e", NUM_EXPERTS)
449-
# @pytest.mark.parametrize("topk", TOP_KS)
450-
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
451-
@pytest.mark.parametrize("m", [128]) ##, 32]) #, 1024 * 128])
452-
@pytest.mark.parametrize("n", [128])
453-
@pytest.mark.parametrize("k", [128])
454-
@pytest.mark.parametrize("e", [8]) #NUM_EXPERTS)
455-
@pytest.mark.parametrize("topk", [2]) #TOP_KS)
456-
@pytest.mark.parametrize("dtype", [torch.bfloat16])
455+
@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128]) # what is restriction on this?
456+
@pytest.mark.parametrize("n", [128, 1024, 2048])
457+
@pytest.mark.parametrize("k", [128, 512, 1024]) # restrictions here?
458+
@pytest.mark.parametrize("e", NUM_EXPERTS)
459+
@pytest.mark.parametrize("topk", TOP_KS)
460+
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
461+
# @pytest.mark.parametrize("m", [2]) ##, 32]) #, 1024 * 128])
462+
# @pytest.mark.parametrize("n", [128])
463+
# @pytest.mark.parametrize("k", [128])
464+
# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS)
465+
# @pytest.mark.parametrize("topk", [2]) #TOP_KS)
466+
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
457467
def test_pplx_dispatch_combine(
458468
m: int,
459469
n: int,
@@ -469,8 +479,14 @@ def test_pplx_dispatch_combine(
469479
else:
470480
world_size = 2
471481
dp_size = 1
482+
483+
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
484+
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
485+
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
486+
score = torch.randn((m, e), device="cuda", dtype=dtype)
487+
472488
parallel_launch(
473-
world_size, _pplx_dispatch_combine, dp_size, m, n, k, e, topk, dtype
489+
world_size, _pplx_dispatch_combine, dp_size, a, w1, w2, score, topk, dtype
474490
)
475491

476492

@@ -483,6 +499,7 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
483499
max_num_tokens = round_up(a.shape[0], 128) #tokens_per_expert.max()
484500
print(f"max_num_tokens = {max_num_tokens}, topk = {topk}, num_ex = {num_experts}/{num_local_experts}")
485501
rank = pgi.rank
502+
world_size = pgi.world_size
486503

487504
ata = AllToAll(
488505
max_num_tokens=max_num_tokens,
@@ -520,14 +537,9 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
520537
experts,
521538
)
522539

523-
def chunk_by_rank(t, r):
524-
num = t.shape[0]
525-
assert num % pgi.world_size == 0, f"{num}, {dp_size}" # for now
526-
chunk = num // pgi.world_size
527-
return t[(r * chunk):(r + 1)*chunk]
528-
529-
a_chunk = chunk_by_rank(a, rank)
530-
chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, chunk_by_rank(scores, rank), topk, False)
540+
a_chunk = chunk_by_rank(a, rank, world_size)
541+
score_chunk = chunk_by_rank(scores, rank, world_size)
542+
chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False)
531543

532544
print(f"chunk_topk_ids = {chunk_topk_ids}")
533545

0 commit comments

Comments
 (0)