Commit f79f96f

fix benchmark
1 parent 9f2f0cf commit f79f96f

1 file changed: +14 -9 lines changed

test/benchmark/kernel/benchmark_fused_moe_triton.py

Lines changed: 14 additions & 9 deletions
@@ -177,14 +177,19 @@ def fused_moe_sglang_api(
     a2_scale=None,
     block_shape=None,
 ):
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+    from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
+
+    topk_output = select_experts(
+        hidden_states=x,
+        router_logits=input_gating,
+        topk_config=TopKConfig(top_k=topk, renormalize=False),
+    )
     return fused_moe_sglang(
         x,
         w1,
         w2,
-        input_gating,
-        topk,
-        renormalize=True,
-        inplace=True,
+        topk_output,
         use_fp8_w8a8=use_fp8_w8a8,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
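
The SGLang wrapper no longer hands raw router logits and a top-k count to the fused kernel; it first reduces the logits to a routing decision with select_experts and passes that result through. A minimal sketch of the new call pattern, assuming the select_experts/TopKConfig keywords shown in this diff and using fused_moe_sglang as it is already imported at the top of this benchmark script:

    from sglang.srt.layers.moe.topk import TopKConfig, select_experts

    def run_sglang_moe(x, w1, w2, input_gating, topk):
        # Reduce the router logits to a top-k routing decision once, up front.
        topk_output = select_experts(
            hidden_states=x,
            router_logits=input_gating,
            topk_config=TopKConfig(top_k=topk, renormalize=False),
        )
        # The fused kernel now consumes the routing result, not logits + k.
        return fused_moe_sglang(x, w1, w2, topk_output)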
@@ -193,11 +198,10 @@ def fused_moe_sglang_api(
         block_shape=block_shape,
     )
 
-
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 8, 16, 32, 64, 128],
+        x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192],
         line_arg="provider",
         line_vals=[
             "vllm_fused_moe_triton",
@@ -264,9 +268,9 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     api_func = (
         fused_moe_vllm_api
         if provider == "vllm_fused_moe_triton"
-        else fused_moe_sglang_api
-        if provider == "lightllm_fused_moe_triton"
         else fused_moe_lightllm_api
+        if provider == "lightllm_fused_moe_triton"
+        else fused_moe_sglang_api
     )
     for _ in range(10):
         api_func(
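
Before this fix the chained conditional sent the "lightllm_fused_moe_triton" provider to the SGLang wrapper and fell back to the LightLLM one for everything else; reordering the branches restores the intended mapping. An equivalent, harder-to-misread dispatch (a sketch only, not part of this change) keeps the same fallback-to-SGLang behaviour with a lookup table:

    api_func = {
        "vllm_fused_moe_triton": fused_moe_vllm_api,
        "lightllm_fused_moe_triton": fused_moe_lightllm_api,
    }.get(provider, fused_moe_sglang_api)  # any other provider uses the SGLang path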
@@ -285,7 +289,8 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     torch.cuda.synchronize()
 
     quantiles = [0.5, 0.2, 0.8]
-    ms, min_ms, max_ms = triton.testing.do_bench(
+    do_bench = triton.testing.do_bench if batch_size > 256 else triton.testing.do_bench_cudagraph
+    ms, min_ms, max_ms = do_bench(
         lambda: api_func(
             x,
             w1,
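
Small batches are latency-bound, so the benchmark now times them with triton.testing.do_bench_cudagraph, which captures the call in a CUDA graph and replays it to keep kernel-launch overhead out of the measurement, while larger batches keep the plain do_bench timing loop. A standalone sketch of the same selection, here timing a bare matmul rather than the MoE kernels (assumes a CUDA device and a Triton version that ships do_bench_cudagraph):

    import torch
    import triton

    def bench_matmul(batch_size: int, hidden: int = 4096):
        a = torch.randn(batch_size, hidden, device="cuda", dtype=torch.float16)
        b = torch.randn(hidden, hidden, device="cuda", dtype=torch.float16)
        # CUDA-graph replay for small, launch-overhead-dominated shapes;
        # the ordinary timing loop once the kernel itself dominates.
        do_bench = (
            triton.testing.do_bench
            if batch_size > 256
            else triton.testing.do_bench_cudagraph
        )
        # With quantiles given, both helpers return the requested percentiles
        # in order: median, 20th, 80th.
        ms, min_ms, max_ms = do_bench(lambda: a @ b, quantiles=[0.5, 0.2, 0.8])
        return ms, min_ms, max_ms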
