@@ -177,14 +177,19 @@ def fused_moe_sglang_api(
     a2_scale=None,
     block_shape=None,
 ):
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+    from sglang.srt.layers.moe.topk import TopK, TopKConfig, select_experts
+
+    topk_output = select_experts(
+        hidden_states=x,
+        router_logits=input_gating,
+        topk_config=TopKConfig(top_k=topk, renormalize=False),
+    )
     return fused_moe_sglang(
         x,
         w1,
         w2,
-        input_gating,
-        topk,
-        renormalize=True,
-        inplace=True,
+        topk_output,
         use_fp8_w8a8=use_fp8_w8a8,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
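
For reference, select_experts replaces the raw input_gating/topk arguments with a precomputed routing result. A rough, illustrative approximation of what top-k routing produces is sketched below; sglang's actual select_experts and TopKOutput may carry additional fields and options, so this is not the library implementation.

# Illustrative sketch only, not part of the patch: approximates top-k expert routing.
import torch

def naive_select_experts(router_logits, top_k, renormalize=False):
    # Softmax over the expert dimension, then keep the top_k weights/ids per token.
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        # Rescale the kept weights so they sum to 1 per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
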
@@ -193,11 +198,10 @@ def fused_moe_sglang_api(
         block_shape=block_shape,
     )
 
-
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 8, 16, 32, 64, 128],
+        x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192],
         line_arg="provider",
         line_vals=[
             "vllm_fused_moe_triton",
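
The wider x_vals sweep extends the benchmark from a top batch size of 128 to 8192. As a standalone refresher on the pattern, a minimal perf_report/Benchmark script is sketched below; the "size" axis, the single "torch" provider, and the matmul workload are placeholders, not taken from this benchmark.

# Minimal sketch of the triton.testing.perf_report pattern; placeholder names throughout.
import torch
import triton.testing

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["size"],              # swept argument, analogous to batch_size above
        x_vals=[256, 512, 1024, 2048],
        line_arg="provider",           # one plotted line per provider
        line_vals=["torch"],
        line_names=["PyTorch"],
        ylabel="ms",
        plot_name="demo-matmul",
        args={},
    )
)
def demo(size, provider):
    x = torch.randn(size, size, device="cuda", dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: x @ x, quantiles=quantiles)
    return ms, min_ms, max_ms

if __name__ == "__main__":
    demo.run(print_data=True)
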
@@ -264,9 +268,9 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     api_func = (
         fused_moe_vllm_api
         if provider == "vllm_fused_moe_triton"
-        else fused_moe_sglang_api
-        if provider == "lightllm_fused_moe_triton"
         else fused_moe_lightllm_api
+        if provider == "lightllm_fused_moe_triton"
+        else fused_moe_sglang_api
     )
     for _ in range(10):
         api_func(
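
The chained conditional groups right to left, so the old ordering sent the lightllm provider to the sglang API and the sglang provider to the lightllm API; reordering the branches fixes the dispatch. An equivalent, arguably clearer formulation (illustrative only, using the three api functions already defined in this script) is a dict lookup with sglang as the fallback, mirroring the final else branch:

# Illustrative alternative to the chained conditional above; not part of the patch.
api_by_provider = {
    "vllm_fused_moe_triton": fused_moe_vllm_api,
    "lightllm_fused_moe_triton": fused_moe_lightllm_api,
}
# Any other provider value falls back to the sglang implementation.
api_func = api_by_provider.get(provider, fused_moe_sglang_api)
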
@@ -285,7 +289,8 @@ def benchmark(batch_size, provider, model_config, use_fp8=False):
     torch.cuda.synchronize()
 
     quantiles = [0.5, 0.2, 0.8]
-    ms, min_ms, max_ms = triton.testing.do_bench(
+    do_bench = triton.testing.do_bench if batch_size > 256 else triton.testing.do_bench_cudagraph
+    ms, min_ms, max_ms = do_bench(
         lambda: api_func(
             x,
             w1,
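
Switching the timing helper by batch size presumably relies on CUDA-graph replay to hide kernel-launch overhead at the small batch sizes where it would otherwise dominate the measurement. A standalone sketch contrasting the two helpers is below; do_bench_cudagraph ships with recent Triton releases, and the matmul workload is a placeholder rather than the MoE kernels being benchmarked.

# Standalone sketch comparing the two Triton timing helpers selected above.
import torch
import triton.testing

x = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
fn = lambda: x @ x
quantiles = [0.5, 0.2, 0.8]

# Eager benchmarking: every iteration pays full kernel-launch overhead.
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)

# CUDA-graph benchmarking: captures the work once and replays it, so launch
# overhead is largely removed, which matters most at small batch sizes.
ms_cg, min_cg, max_cg = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)

print(f"eager median: {ms:.3f} ms, cudagraph median: {ms_cg:.3f} ms")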