Commit 86c2055

wip
Signed-off-by: Bill Nell <[email protected]>
1 parent fe1974a commit 86c2055

File tree: 3 files changed, +93 -77 lines


tests/kernels/test_pplx_moe.py

Lines changed: 90 additions & 74 deletions
@@ -164,18 +164,19 @@ def torch_dispatch(
     a: torch.Tensor,
     topk_ids: torch.Tensor,
     num_experts: int
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, torch.Tensor]:
     assert topk_ids.dim() == 2
     assert topk_ids.shape[0] == a.shape[0]
 
     num_tokens = a.shape[0]
     topk = topk_ids.shape[1]
 
     tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
-
     max_num_tokens = tokens_per_expert.max()
+
     b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
                       dtype=a.dtype, device=a.device)
+
     #print(f"b_a shape {b_a.shape}")
 
     token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device)
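Note: with this change torch_dispatch returns the per-expert token counts alongside the batched activations. A minimal, self-contained sketch of that batching pattern (illustrative names, not the exact test code):

```python
# Scatter each token's hidden state into a per-expert slot and return the
# per-expert token counts next to the batched tensor.
import torch

def batch_by_expert(a: torch.Tensor, topk_ids: torch.Tensor, num_experts: int):
    # a: (num_tokens, hidden), topk_ids: (num_tokens, topk)
    tokens_per_expert = torch.bincount(topk_ids.view(-1), minlength=num_experts)
    max_num_tokens = int(tokens_per_expert.max())
    b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
                      dtype=a.dtype, device=a.device)
    slot = torch.zeros(num_experts, dtype=torch.int64, device=a.device)
    for token in range(topk_ids.shape[0]):
        for expert in topk_ids[token]:
            e = int(expert)
            b_a[e, int(slot[e])] = a[token]
            slot[e] += 1
    return b_a, tokens_per_expert

a = torch.randn(6, 4)
topk_ids = torch.randint(0, 3, (6, 2))
b_a, counts = batch_by_expert(a, topk_ids, num_experts=3)
assert counts.sum() == topk_ids.numel()
```

The second return value is what the new cross-checks further down compare against.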
@@ -242,59 +243,58 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids):
         topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)
 
 
-# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128])
-# @pytest.mark.parametrize("n", [128, 1024, 2048])
-# @pytest.mark.parametrize("k", [128, 511, 1024])
-# @pytest.mark.parametrize("e", NUM_EXPERTS)
-# @pytest.mark.parametrize("topk", TOP_KS)
-# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-# def test_fused_moe_batched_experts(
-#     m: int,
-#     n: int,
-#     k: int,
-#     e: int,
-#     topk: int,
-#     dtype: torch.dtype,
-# ):
-#     current_platform.seed_everything(7)
-
-#     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
-#     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
-#     w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
-
-#     score = torch.randn((m, e), device="cuda", dtype=dtype)
-
-#     vllm_config = VllmConfig()
-#     with set_current_vllm_config(vllm_config):
-#         topk_weight, topk_ids = fused_topk(a, score, topk, False)
-
-#         torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
-
-#         if True:
-#             triton_output = torch_batched_moe(a,
-#                                               w1,
-#                                               w2,
-#                                               topk_weight,
-#                                               topk_ids)
-#         else:
-#             b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e)
-#             triton_output = fused_batched_experts(
-#                 b_a,
-#                 w1,
-#                 w2,
-#                 topk_weight,
-#                 topk_ids,
-#                 global_num_experts=e
-#             )
-
-#     if False:
-#         torch.set_printoptions(profile="full")
-#         print("BASELINE")
-#         print(torch_output)
-#         print("OUTPUT")
-#         print(triton_output)
-
-#     torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+@pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 511, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+def test_fused_moe_batched_experts(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+):
+    current_platform.seed_everything(7)
+
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    vllm_config = VllmConfig()
+    with set_current_vllm_config(vllm_config):
+        topk_weight, topk_ids = fused_topk(a, score, topk, False)
+
+        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
+
+        if True:
+            triton_output = torch_batched_moe(a,
+                                              w1,
+                                              w2,
+                                              topk_weight,
+                                              topk_ids)
+        else:
+            b_a, tokens_per_expert = batch_by_experts(a, topk_ids, e)
+            triton_output = fused_batched_experts(
+                b_a,
+                w1,
+                w2,
+                topk_weight,
+                topk_ids,
+                global_num_experts=e
+            )
+
+    if False:
+        torch.set_printoptions(profile="full")
+        print("BASELINE")
+        print(torch_output)
+        print("OUTPUT")
+        print(triton_output)
+
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
 
 
 def chunk_by_rank(t, r, w):
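The re-enabled test above drives torch_batched_moe against the torch_moe2 reference on random inputs. For readers without the full file, a hedged sketch of what such a dense per-token reference computes, assuming vLLM's usual fused gate/up layout w1: (E, 2N, K) and w2: (E, K, N); this is not the exact torch_moe2 body:

```python
# For every (token, chosen expert) pair: w1 -> SiLU-and-mul -> w2, then the
# topk_weight-weighted sum over the chosen experts.
import torch
import torch.nn.functional as F

def reference_moe(a, w1, w2, topk_weight, topk_ids):
    M, K = a.shape
    topk = topk_ids.shape[1]
    out = torch.zeros(M, topk, K, dtype=a.dtype, device=a.device)
    for i in range(M):
        for j in range(topk):
            e = int(topk_ids[i, j])
            gate_up = a[i] @ w1[e].t()                    # (2N,)
            gate, up = gate_up.chunk(2)                   # SiLU-and-mul activation
            out[i, j] = (F.silu(gate) * up) @ w2[e].t()   # back to (K,)
    return (out * topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1)

E, M, K, N, topk = 4, 6, 8, 16, 2
a = torch.randn(M, K)
w1 = torch.randn(E, 2 * N, K)
w2 = torch.randn(E, K, N)
topk_weight = torch.rand(M, topk)
topk_ids = torch.randint(0, E, (M, topk))
out = reference_moe(a, w1, w2, topk_weight, topk_ids)     # (M, K)
```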
@@ -310,6 +310,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
 
     num_tokens, hidden_dim = a.shape
     num_experts = w1.shape[0]
+    num_local_experts = w1.shape[0] // pgi.world_size
     block_size = 128
     device = pgi.device
     rank_num_tokens = num_tokens // pgi.world_size
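The added num_local_experts assumes the experts are split evenly across ranks; a small worked example of that arithmetic:

```python
# Under expert parallelism, each rank owns an equal share of the experts
# (assumed even split, as the integer division by pgi.world_size suggests).
num_experts = 8
world_size = 2
num_local_experts = num_experts // world_size  # 4 experts owned per rank
```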
@@ -352,7 +353,7 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
     score_chunk = chunk_by_rank(scores, rank, world_size).to(device)
     chunk_topk_weight, chunk_topk_ids = fused_topk(a_chunk, score_chunk, topk, False)
 
-    #print(f"chunk_topk_ids = {chunk_topk_ids}")
+    #print(f"chunk_topk_ids = {chunk_topk_ids.view(-1)}")
 
     b_a, b_a_scale, expert_num_tokens = dispatch_combine.dispatch(
         a_chunk,
@@ -363,6 +364,25 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
         None
     )
 
+    #topk_weight, topk_ids = fused_topk(a_chunk, score_chunk, topk, False)
+    naive_b_a, tokens_per_expert = torch_dispatch(a_chunk, chunk_topk_ids, num_experts)
+
+    torch.distributed.all_reduce(tokens_per_expert)
+    #max_num = tokens_per_expert.max()
+    tokens_per_expert = chunk_by_rank(tokens_per_expert, rank, world_size).to(dtype=torch.int32)
+
+    #print(f"tpe {tokens_per_expert}")
+    #print(f"ent {expert_num_tokens}")
+
+    #naive_b_a = chunk_by_rank(naive_b_a, rank, world_size)
+
+    #torch.set_printoptions(profile="full")
+    #print("b_a", b_a[:naive_b_a.shape[1]])
+    #print("naive_b_a", naive_b_a)
+
+    torch.testing.assert_close(tokens_per_expert, expert_num_tokens, atol=0, rtol=0)
+    #torch.testing.assert_close(b_a[:, :naive_b_a.shape[1]], naive_b_a, atol=2e-2, rtol=0)
+
     b_a = b_a * 1.5
 
     out = torch.full(
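The new assertions cross-check the pplx dispatcher's bookkeeping against the naive torch_dispatch path: per-expert counts are all-reduced over ranks, the local rank's slice is taken, and the result must match expert_num_tokens exactly. A toy version of that check (the values and the even chunk_by_rank split are illustrative):

```python
import torch

def chunk_by_rank(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Assumed even split of dim 0 across ranks.
    chunk = t.shape[0] // world_size
    return t[rank * chunk:(rank + 1) * chunk]

# Stand-ins for the all-reduced global counts and the dispatcher's report.
global_tokens_per_expert = torch.tensor([3, 1, 4, 0, 2, 2, 1, 3], dtype=torch.int32)
rank, world_size = 1, 2
expected_local = chunk_by_rank(global_tokens_per_expert, rank, world_size)
expert_num_tokens = torch.tensor([2, 2, 1, 3], dtype=torch.int32)  # from dispatch
torch.testing.assert_close(expected_local, expert_num_tokens, atol=0, rtol=0)
```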
@@ -382,8 +402,6 @@ def torch_pplx_dispatch_combine(pgi, dp_size, a, w1, w2, scores, topk):
 
     ata.destroy()
 
-    #torch.distributed.barrier()
-
     #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
 
     #torch.distributed.all_reduce(out)
@@ -547,8 +565,6 @@ def torch_pplx_moe(pgi, dp_size, a, w1, w2, scores, topk):
 
     ata.destroy()
 
-    #torch.distributed.barrier()
-
     #print(f"OUT {rank}: {out.shape} {out[:rank_num_tokens]}")
 
     #torch.distributed.all_reduce(out)
@@ -593,8 +609,6 @@ def _pplx_moe(
                                   score,
                                   topk)
 
-    #print(f"torch_output {pgi.rank}: {torch_output}")
-
     if False:
         print("BASELINE")
         print(torch_output)
@@ -603,23 +617,25 @@ def _pplx_moe(
 
     torch_output = chunk_by_rank(torch_output, pgi.rank, pgi.world_size).to(pplx_output.device)
 
+    #print(f"torch_output {pgi.rank}: {torch_output.shape} {torch_output}")
+
     torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0)
 
     nvshmem_finalize()
 
 
-# @pytest.mark.parametrize("m", [1, 33, 64, 222]) #, 1024 * 128])
-# @pytest.mark.parametrize("n", [128, 1024, 2048])
-# @pytest.mark.parametrize("k", [128, 512, 1024])
-# @pytest.mark.parametrize("e", NUM_EXPERTS)
-# @pytest.mark.parametrize("topk", TOP_KS)
-# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
-@pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128])
-@pytest.mark.parametrize("n", [128])
-@pytest.mark.parametrize("k", [128])
-@pytest.mark.parametrize("e", [8]) #NUM_EXPERTS)
-@pytest.mark.parametrize("topk", [2]) #TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("m", [2, 32, 64, 222]) #, 1024 * 128])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 512, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize("m", [64]) ##, 32]) #, 1024 * 128])
+# @pytest.mark.parametrize("n", [128])
+# @pytest.mark.parametrize("k", [128])
+# @pytest.mark.parametrize("e", [8]) #NUM_EXPERTS)
+# @pytest.mark.parametrize("topk", [2]) #TOP_KS)
+# @pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]])
 def test_pplx_moe(
     m: int,

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 2 additions & 2 deletions
@@ -1858,8 +1858,8 @@ def workspace_shapes(
         a: torch.Tensor,
     ) -> Tuple[int, int, torch.dtype]:
         max_num_tokens = a.shape[1]
-        workspace13 = num_experts * max_num_tokens * K * 2 # *2 = HACK!!!!!
-        workspace2 = max_num_tokens * (N // 2)
+        workspace13 = num_experts * max_num_tokens * K * topk * 2 # TODO: *2 is a hack
+        workspace2 = max_num_tokens * N
         return (workspace13, workspace2, a_dtype)
 
     def apply(
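For scale, a worked example of the two workspace formulas before and after this change; the variable names come from the workspace_shapes signature above, the concrete numbers are illustrative:

```python
# With num_experts=8, max_num_tokens=64, K=1024, N=2048, topk=2 the change grows
# workspace13 by the topk factor and gives workspace2 the full N columns.
num_experts, max_num_tokens, K, N, topk = 8, 64, 1024, 2048, 2

workspace13_old = num_experts * max_num_tokens * K * 2          # 1,048,576 elements
workspace2_old = max_num_tokens * (N // 2)                      # 65,536 elements

workspace13_new = num_experts * max_num_tokens * K * topk * 2   # 2,097,152 elements
workspace2_new = max_num_tokens * N                             # 131,072 elements
```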

vllm/model_executor/layers/fused_moe/pplx_dispatch_combine.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def dispatch(
             dtype=a1q.dtype,
             device=device,
         )
-        expert_x.fill_(torch.nan) # debugging, remove later
+        expert_x.fill_(0) #torch.nan # debugging, remove later
 
         logger.debug(f"GOT HERE B {self.rank}")
 
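The pplx_dispatch_combine change swaps the NaN debugging fill for zeros. One plausible reason, shown with a toy buffer: padded slots that no token wrote into still flow through later batched math, so NaN padding propagates while zero padding stays inert (shapes and the matmul here are illustrative, not the dispatcher's actual compute):

```python
import torch

expert_x = torch.full((2, 4, 8), torch.nan)   # (local experts, max tokens, hidden)
expert_x[0, :2] = 1.0                         # only two slots actually written
w = torch.randn(8, 8)
print(torch.isnan(expert_x @ w).any())        # True: NaN padding leaks into outputs

expert_x.fill_(0)
expert_x[0, :2] = 1.0
print(torch.isnan(expert_x @ w).any())        # False: zero padding stays inert
```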