
Commit 1938bc8

fix reference implementations
Signed-off-by: Bill Nell <[email protected]>
1 parent: 4c40380

File tree

7 files changed: +303 additions, -80 deletions


examples/offline_inference/data_parallel.py
Lines changed: 4 additions & 1 deletion

@@ -115,12 +115,15 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
     # Create an LLM.
     cconfig = CompilationConfig(
         level=0,
+        #cudagraph_capture_sizes=[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208],
+        #cudagraph_capture_sizes=[512,256,1],
     )
     llm = LLM(model=model,
               tensor_parallel_size=GPUs_per_dp_rank,
               enforce_eager=enforce_eager,
               enable_expert_parallel=True,
-              compilation_config=cconfig)
+              compilation_config=cconfig,
+              )
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
     for i, output in enumerate(outputs):
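For reference, the commented-out knob above can be enabled by passing explicit sizes; a minimal sketch (placeholder model name and capture sizes, assuming the same CompilationConfig import this example already uses):

    from vllm import LLM
    from vllm.config import CompilationConfig

    # Placeholder values; pick capture sizes matching your expected batch sizes.
    cconfig = CompilationConfig(
        level=0,
        cudagraph_capture_sizes=[512, 256, 1],
    )
    llm = LLM(model="placeholder/model",
              enable_expert_parallel=True,
              compilation_config=cconfig)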

tests/kernels/moe/test_batched_moe.py
Lines changed: 3 additions & 3 deletions

@@ -62,9 +62,9 @@ def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,


 @pytest.mark.parametrize("num_experts", [16, 32])
-@pytest.mark.parametrize("max_tokens_per_expert", [512])
-@pytest.mark.parametrize("K", [256])
-@pytest.mark.parametrize("N", [512])
+@pytest.mark.parametrize("max_tokens_per_expert", [32, 64, 128, 192, 224, 256, 512])
+@pytest.mark.parametrize("K", [128, 256, 1024])
+@pytest.mark.parametrize("N", [128, 256, 512, 1024])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int,
                     N: int, dtype: torch.dtype):
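For context, test_batched_mm exercises a matmul over ragged per-expert batches; a naive reference for that shape contract might look like the sketch below (the helper name, the (E, N, K) weight layout, and num_expert_tokens are assumptions, not necessarily this file's ref_impl):

    import torch

    def naive_batched_mm(A: torch.Tensor,                 # (E, max_tokens_per_expert, K)
                         B: torch.Tensor,                 # (E, N, K)
                         num_expert_tokens: torch.Tensor  # (E,) valid rows per expert
                         ) -> torch.Tensor:
        E, max_tokens, K = A.shape
        N = B.shape[1]
        C = torch.zeros((E, max_tokens, N), dtype=A.dtype, device=A.device)
        for e in range(E):
            n = int(num_expert_tokens[e].item())
            # Only the first n rows of each expert's batch hold real tokens;
            # the rest is zero padding and is left untouched.
            C[e, :n, :] = A[e, :n, :] @ B[e].transpose(0, 1)
        return C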

tests/kernels/moe/test_pplx_moe.py
Lines changed: 30 additions & 8 deletions

@@ -28,6 +28,7 @@
 from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+    BatchedDispatchCombine,
     BatchedExperts)
 from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
 from vllm.model_executor.layers.fused_moe.modular_kernel import (
@@ -170,7 +171,7 @@ def torch_dispatch(
     assert topk_ids.dim() == 2
     assert topk_ids.shape[0] == a.shape[0]

-    num_tokens = a.shape[0]
+    num_tokens, hidden_dim = a.shape
     topk = topk_ids.shape[1]

     tokens_per_expert = torch.bincount(topk_ids.view(-1),
@@ -181,7 +182,7 @@
     if max_num_tokens is None:
         max_num_tokens = int(tokens_per_expert.max().item())

-    b_a = torch.zeros((num_experts, max_num_tokens, a.shape[1]),
+    b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim),
                       dtype=a.dtype,
                       device=a.device)

@@ -198,7 +199,7 @@


 def torch_combine(b_out, topk_weight, topk_ids):
-    num_tokens, topk = topk_ids.shape
+    num_tokens = topk_ids.shape[0]
     num_experts = b_out.shape[0]
     K = b_out.shape[-1]
     out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device)
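Taken together, torch_dispatch and torch_combine implement a scatter/gather round trip; a self-contained sketch of that semantics (standalone illustration with identity experts, not this file's exact helpers):

    import torch

    def dispatch_combine_roundtrip(a, topk_weight, topk_ids, num_experts):
        # Dispatch: scatter each (token, expert) pair into a per-expert batch.
        num_tokens, hidden_dim = a.shape
        counts = torch.bincount(topk_ids.view(-1), minlength=num_experts)
        max_tokens = int(counts.max().item())
        b_a = torch.zeros((num_experts, max_tokens, hidden_dim),
                          dtype=a.dtype, device=a.device)
        slot = [0] * num_experts
        for t in range(num_tokens):
            for e in topk_ids[t].tolist():
                b_a[e, slot[e]] = a[t]
                slot[e] += 1

        # ... an expert kernel would transform b_a in place here ...

        # Combine: gather each token's expert rows back, weighted by the router.
        out = torch.zeros_like(a)
        slot = [0] * num_experts
        for t in range(num_tokens):
            for k, e in enumerate(topk_ids[t].tolist()):
                out[t] += topk_weight[t, k].to(a.dtype) * b_a[e, slot[e]]
                slot[e] += 1
        return out  # with identity experts: out[t] == a[t] * topk_weight[t].sum()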
@@ -240,6 +241,22 @@ def torch_batched_moe(a, w1, w2, topk_weight, topk_ids):
     return torch_combine(out, topk_weight, topk_ids)


+def batched_moe(a, w1, w2, topk_weight, topk_ids):
+    num_experts = w1.shape[0]
+
+    fused_experts = FusedMoEModularKernel(
+        BatchedDispatchCombine(a.shape[0], world_size=1, dp_size=1, rank=0),
+        BatchedExperts(a.shape[0])
+    )
+
+    return fused_experts(a,
+                         w1,
+                         w2,
+                         topk_weight,
+                         topk_ids,
+                         num_experts)
+
+
 # TODO: same as torch_moe but with fused_topk factored out.
 def torch_moe2(a, w1, w2, topk_weight, topk_ids):
     M, K = a.shape
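The new batched_moe wraps the same computation in FusedMoEModularKernel, composing BatchedDispatchCombine with BatchedExperts for a single-rank setup. A hypothetical driver for it (shapes are made up, and the (e, 2n, k)/(e, k, n) weight layout is assumed from the surrounding tests):

    m, n, k, e, topk = 32, 256, 128, 8, 2
    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=torch.bfloat16) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=torch.bfloat16) / 10
    score = torch.randn((m, e), device="cuda", dtype=torch.bfloat16)

    topk_weight, topk_ids = fused_topk(a, score, topk, False)
    # Both paths should agree within the test's tolerance below.
    torch.testing.assert_close(torch_moe2(a, w1, w2, topk_weight, topk_ids),
                               batched_moe(a, w1, w2, topk_weight, topk_ids),
                               atol=2e-2, rtol=0)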
@@ -262,7 +279,7 @@ def torch_moe2(a, w1, w2, topk_weight, topk_ids):
 @pytest.mark.parametrize("k", [128, 511, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 def test_fused_moe_batched_experts(
     m: int,
     n: int,
@@ -280,10 +297,13 @@ def test_fused_moe_batched_experts(

     with set_current_vllm_config(vllm_config):
         topk_weight, topk_ids = fused_topk(a, score, topk, False)
-        torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
-        triton_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
+        baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids)
+        torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids)
+        batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids)

-    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+    torch.testing.assert_close(baseline_output, torch_output, atol=2e-2, rtol=0)
+    torch.set_printoptions(profile="full")
+    torch.testing.assert_close(baseline_output, batched_output, atol=2e-2, rtol=0)


 def rank_chunk(num, r, w):
@@ -473,6 +493,8 @@ def pplx_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids):
         experts,
     )

+    # TODO: workers with the same dp_rank must use the exact same inputs.
+
     a_chunk = chunk_by_rank(a, rank, world_size).to(device)
     chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device)
     chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device)
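One way to honor that TODO (hypothetical; not what this commit does) would be to derive each worker's random inputs from a seed keyed only on its dp_rank, so every worker in the same DP group builds identical tensors:

    # Hypothetical: the mapping from global rank to dp_rank is an assumption
    # about the rank layout, and m/k/dtype stand in for the test's parameters.
    dp_rank = rank // (world_size // dp_size)
    gen = torch.Generator()
    gen.manual_seed(1234 + dp_rank)
    a = torch.randn((m, k), generator=gen, dtype=dtype) / 10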
@@ -528,7 +550,7 @@ def _pplx_moe(
 @pytest.mark.parametrize("k", [128, 512, 1024])
 @pytest.mark.parametrize("e", NUM_EXPERTS)
 @pytest.mark.parametrize("topk", TOP_KS)
-@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
 @pytest.mark.parametrize("world_dp_size", [[2, 1]])
 @requires_pplx
 def test_pplx_moe(
