
Commit 6bfb43a

benchmark: add moe to benchmark (#1497)
Parent: 718033c

4 files changed: +1710 −2 lines

benchmarks/README.md

Lines changed: 170 additions & 1 deletion
@@ -22,6 +22,10 @@ Currently supports testing:
 - `group_gemm_fp8_nt_groupwise` - Group GEMM with FP8 data types using groupwise scaling.
 - `bmm_fp8` - Batched matrix multiplication with FP8 inputs.
 - `mm_fp4` - Matrix multiplication with NVFP4 inputs.
+- `trtllm_fp4_block_scale_moe` - MOE with FP4 quantized weights and block-wise scaling.
+- `trtllm_fp8_block_scale_moe` - MOE with FP8 quantized weights and block-wise scaling.
+- `trtllm_fp8_per_tensor_scale_moe` - MOE with FP8 quantized weights and per-tensor scaling.
+- `cutlass_fused_moe` - CUTLASS fused MoE (base/fp8/nvfp4 variants with optional TP/EP).

 Support surface will expand to other operations such as MLA or non-attention operations in the future.

 ## Quick Start
@@ -101,6 +105,126 @@ python3 flashinfer_benchmark.py \
     --scale_major_mode K \
     --refcheck \
     -vv
+
+# MOE FP4 Block Scale (DeepSeekV3 routing)
+python3 flashinfer_benchmark.py \
+    --routine trtllm_fp4_block_scale_moe \
+    --num_tokens 1024 \
+    --hidden_size 1024 \
+    --intermediate_size 1024 \
+    --num_experts 128 \
+    --top_k 8 \
+    --n_group 8 \
+    --topk_group 4 \
+    --routed_scaling_factor 2.5 \
+    --use_routing_bias \
+    --routing_method deepseek_v3 \
+    --use_shuffled_weight \
+    --verbose
+
+# MOE FP8 Block Scale with DeepSeekV3 routing
+python3 flashinfer_benchmark.py \
+    --routine trtllm_fp8_block_scale_moe \
+    --num_tokens 1024 \
+    --hidden_size 1024 \
+    --intermediate_size 1024 \
+    --num_experts 128 \
+    --top_k 8 \
+    --n_group 8 \
+    --topk_group 4 \
+    --routed_scaling_factor 2.5 \
+    --use_routing_bias \
+    --routing_method deepseek_v3 \
+    --use_shuffled_weight \
+    --verbose
+
+# MOE FP8 Block Scale with Renormalize routing (no groups)
+python3 flashinfer_benchmark.py \
+    --routine trtllm_fp8_block_scale_moe \
+    --num_tokens 1024 \
+    --hidden_size 1024 \
+    --intermediate_size 1024 \
+    --num_experts 128 \
+    --top_k 1 \
+    --routing_method renormalize \
+    --verbose
+
+# CUTLASS Fused MoE (base variant)
+python3 flashinfer_benchmark.py \
+    --routine cutlass_fused_moe \
+    --num_tokens 32 \
+    --hidden_size 128 \
+    --intermediate_size 128 \
+    --num_experts 2 \
+    --top_k 2 \
+    --cutlass_variant base \
+    --input_dtype float16 \
+    --verbose
+
+# CUTLASS Fused MoE (fp8 variant)
+python3 flashinfer_benchmark.py \
+    --routine cutlass_fused_moe \
+    --num_tokens 32 \
+    --hidden_size 128 \
+    --intermediate_size 128 \
+    --num_experts 2 \
+    --top_k 2 \
+    --cutlass_variant fp8 \
+    --input_dtype float16 \
+    --verbose
+
+# CUTLASS Fused MoE (nvfp4 variant, input not quantized)
+python3 flashinfer_benchmark.py \
+    --routine cutlass_fused_moe \
+    --num_tokens 32 \
+    --hidden_size 128 \
+    --intermediate_size 128 \
+    --num_experts 2 \
+    --top_k 2 \
+    --cutlass_variant nvfp4 \
+    --input_dtype float16 \
+    --verbose
+
+# CUTLASS Fused MoE (nvfp4 variant with quantized input)
+python3 flashinfer_benchmark.py \
+    --routine cutlass_fused_moe \
+    --num_tokens 32 \
+    --hidden_size 128 \
+    --intermediate_size 128 \
+    --num_experts 2 \
+    --top_k 2 \
+    --cutlass_variant nvfp4 \
+    --quantized_input \
+    --input_dtype float16 \
+    --verbose
+
+# CUTLASS Fused MoE with Expert Parallel (EP)
+python3 flashinfer_benchmark.py \
+    --routine cutlass_fused_moe \
+    --num_tokens 32 \
+    --hidden_size 128 \
+    --intermediate_size 128 \
+    --num_experts 8 \
+    --top_k 2 \
+    --cutlass_variant base \
+    --input_dtype float16 \
+    --ep_size 4 \
+    --ep_rank 0 \
+    --verbose
+
+# CUTLASS Fused MoE with Tensor Parallel (TP)
+python3 flashinfer_benchmark.py \
+    --routine cutlass_fused_moe \
+    --num_tokens 32 \
+    --hidden_size 128 \
+    --intermediate_size 128 \
+    --num_experts 2 \
+    --top_k 2 \
+    --cutlass_variant base \
+    --input_dtype float16 \
+    --tp_size 2 \
+    --tp_rank 0 \
+    --verbose
 ```

 ### Batch Testing
@@ -120,7 +244,9 @@ The output CSV will contain detailed metrics including:
 ### General Flags
 | Flag | Description |
 |--------------------------|-------------------------------------------------------------------------------------------------------------|
-| `--routine` | Test routine to run: `BatchDecodeWithPagedKVCacheWrapper`, `BatchPrefillWithPagedKVCacheWrapper`, `BatchPrefillWithRaggedKVCacheWrapper`, `BatchMLAPagedAttentionWrapper`, `gemm_fp8_nt_groupwise`, `group_gemm_fp8_nt_groupwise`, `bmm_fp8`, `mm_fp4` |
+| `--routine` | Test routine to run: `BatchDecodeWithPagedKVCacheWrapper`, `BatchPrefillWithPagedKVCacheWrapper`, `BatchPrefillWithRaggedKVCacheWrapper`, `BatchMLAPagedAttentionWrapper`, `gemm_fp8_nt_groupwise`, `group_gemm_fp8_nt_groupwise`, `bmm_fp8`, `mm_fp4`, `trtllm_fp4_block_scale_moe`, `trtllm_fp8_block_scale_moe`, `trtllm_fp8_per_tensor_scale_moe`, `cutlass_fused_moe` (CUTLASS fused MoE; variants: `base`, `fp8`, `nvfp4`) |
 | `--num_iters` | Number of iterations for performance measurement |
 | `--dry_run_iters` | Number of warmup iterations |
 | `--no_cuda_graph` | Disable CUDA graph to execute kernels outside of the graph. |
@@ -165,6 +291,49 @@ The output CSV will contain detailed metrics including:
 | `--mat2_dtype` | Data type for second matrix (for FP8 GEMM, e.g. `fp8_e4m3`) |
 | `--use_128x4_sf_layout` | Use 128x4 scale/format layout for FP4 GEMM (for `mm_fp4` routine) |

+### MOE Flags
+| Flag | Description |
+|--------------------------|-------------------------------------------------------------------------------------------------------------|
+| `--num_tokens` | Number of input tokens |
+| `--hidden_size` | Hidden dimension size |
+| `--intermediate_size` | Intermediate dimension size (FF layer dimension) |
+| `--num_experts` | Total number of experts |
+| `--top_k` | Number of experts to route to per token |
+| `--n_group` | Number of expert groups (for DeepSeek routing). Default: 1 |
+| `--topk_group` | Number of groups to consider for top-k routing. Default: 1 |
+| `--routed_scaling_factor`| Scaling factor for routing. Default: 2.5 |
+| `--local_expert_offset` | Offset of local experts in the global expert space. Default: 0 |
+| `--local_num_experts` | Number of experts handled by this device. Default: equals `num_experts` |
+| `--tile_tokens_dim` | Tile dimension for tokens. Default: 8 |
+| `--routing_method` | Routing method: `renormalize`, `deepseek_v3`, `llama4`, `renormalize_naive`. Default: `deepseek_v3` |
+| `--use_shuffled_weight` | Whether to use the shuffled weight layout |
+| `--weight_layout` | Weight layout: 0=MajorK, 1=MajorMn, 2=BlockMajorK. Default: 0 |
+| `--use_routing_bias` | Whether to use a routing bias |
+| `--use_routing_scales_on_input` | Whether to apply routing scales on the input (for Llama4 routing) |
+| `--input_dtype` | Data type of the input hidden states. Default: `bfloat16` |
+| `--weight_dtype` | Data type of the weights (before quantization). Default: `bfloat16` |
+| `--cutlass_variant` | CUTLASS MoE variant: `base` (no quantization), `fp8` (per-tensor FP8), `nvfp4` (FP4 block scale) |
+| `--quantized_input` | For `nvfp4` only: quantize input activations to FP4 |
+| `--tp_size` | Tensor-parallel world size |
+| `--tp_rank` | Tensor-parallel rank |
+| `--ep_size` | Expert-parallel world size |
+| `--ep_rank` | Expert-parallel rank |
+
+### MOE Routing Method Compatibility
+
+| Routing Method | Requirements | Compatible MOE Types |
+|------------------------|--------------|---------------------|
+| **deepseek_v3** | `top_k <= 8`, `topk_group <= 4`; requires `--n_group`, `--topk_group`, `--routed_scaling_factor`, `--use_routing_bias` | FP4, FP8 Block Scale |
+| **renormalize** | `top_k == 1` for FP8 Block Scale, `top_k <= 8` for FP4. Do NOT use `--n_group` or `--topk_group` | All MOE types |
+| **llama4** | `top_k == 1`; requires `--routed_scaling_factor`, `--use_routing_bias`, `--use_routing_scales_on_input`. Do NOT use `--n_group` or `--topk_group` | FP8 Per-Tensor |
+| **renormalize_naive** | `top_k == 1` for FP8 Block Scale, `top_k <= 8` for FP4. Do NOT use `--n_group` or `--topk_group` | FP4 primarily |
+
+Notes:
+- Group parameters (`--n_group`, `--topk_group`) are ONLY used with the DeepSeekV3 routing method. Using them with other routing methods causes the error "Routing kernel with groups implies DeepSeekV3 routing method."
+- Different MOE kernel implementations have different `top_k` constraints. FP8 MOE kernels (both Block Scale and Per-Tensor) have stricter limits than FP4 for non-DeepSeekV3 routing methods.
+- FP8 MOE kernels require integer values for the group parameters, while FP4 MOE kernels accept them as optional.
+- CUTLASS fused MoE (`cutlass_fused_moe`) ignores `--routing_method`, `--n_group`, and `--topk_group`; it computes routing internally via softmax + top-k from the provided logits.
+
 ## Tester Attention Backend Support Matrix
 The following support surface applies to attention operations in `flashinfer_benchmark.py`
 | Backend | Decode Paged | Prefill Paged | Prefill Ragged | FP8 | Notes |
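
The Quick Start additions above exercise the `deepseek_v3` and `renormalize` routing paths but not `llama4`, which the compatibility table ties to the FP8 per-tensor routine. Below is a minimal sketch of such a run, driving `parse_args`/`run_test` from `benchmarks/flashinfer_benchmark.py` directly (the same flags work on the command line exactly as in the Quick Start examples). The flag combination follows the documented llama4 requirements; the specific values and the programmatic invocation style are illustrative assumptions, not part of this commit.

```python
# Hedged sketch: programmatic invocation of the benchmark driver for the FP8
# per-tensor MoE routine with llama4 routing. Assumes this is run from the
# benchmarks/ directory so flashinfer_benchmark is importable as a module.
from flashinfer_benchmark import parse_args, run_test

llama4_fp8_per_tensor = [
    "--routine", "trtllm_fp8_per_tensor_scale_moe",
    "--num_tokens", "1024",
    "--hidden_size", "1024",
    "--intermediate_size", "1024",
    "--num_experts", "128",
    "--top_k", "1",                    # llama4 routing requires top_k == 1
    "--routed_scaling_factor", "2.5",
    "--use_routing_bias",
    "--use_routing_scales_on_input",   # llama4-specific routing scales
    "--routing_method", "llama4",
    "--verbose",
]

args = parse_args(llama4_fp8_per_tensor)  # parse_args accepts an argv-style list
run_test(args)
```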

benchmarks/flashinfer_benchmark.py

Lines changed: 8 additions & 1 deletion
@@ -8,6 +8,7 @@
     output_column_dict,
 )
 from routines.gemm import parse_gemm_args, run_gemm_test
+from routines.moe import parse_moe_args, run_moe_test


 def run_test(args):
@@ -23,6 +24,8 @@ def run_test(args):
         res = run_attention_test(args)
     elif args.routine in benchmark_apis["gemm"]:
         res = run_gemm_test(args)
+    elif args.routine in benchmark_apis["moe"]:
+        res = run_moe_test(args)
     else:
         raise ValueError(f"Unsupported routine: {args.routine}")

@@ -60,7 +63,9 @@ def parse_args(line=sys.argv[1:]):
         "-R",
         type=str,
         required=True,
-        choices=list(benchmark_apis["attention"]) + list(benchmark_apis["gemm"]),
+        choices=list(benchmark_apis["attention"])
+        + list(benchmark_apis["gemm"])
+        + list(benchmark_apis["moe"]),
     )
     args, _ = parser.parse_known_args(line[:])

@@ -117,6 +122,8 @@ def parse_args(line=sys.argv[1:]):
         args = parse_attention_args(line, parser)
     elif args.routine in benchmark_apis["gemm"]:
         args = parse_gemm_args(line, parser)
+    elif args.routine in benchmark_apis["moe"]:
+        args = parse_moe_args(line, parser)
     else:
         raise ValueError(f"Unsupported routine: {args.routine}")
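
The fourth changed file, presumably `benchmarks/routines/moe.py`, is not shown in this view. The sketch below outlines the interface the dispatcher above relies on: `parse_moe_args(line, parser)` returning a parsed namespace and `run_moe_test(args)` returning result rows. Only the call signatures are taken from the diff; the registered flags follow the MOE Flags table, and everything else (defaults, the dispatch structure, omitted bodies) is an illustrative assumption rather than the committed implementation.

```python
# Hedged sketch of the routines/moe.py surface used by flashinfer_benchmark.py.
def parse_moe_args(line, parser):
    """Register MoE-specific flags on the shared parser, then parse the line.

    Flag set follows the MOE Flags table in benchmarks/README.md; the committed
    file may register more flags or use different defaults.
    """
    parser.add_argument("--num_tokens", type=int)
    parser.add_argument("--hidden_size", type=int)
    parser.add_argument("--intermediate_size", type=int)
    parser.add_argument("--num_experts", type=int)
    parser.add_argument("--top_k", type=int)
    parser.add_argument("--n_group", type=int, default=1)
    parser.add_argument("--topk_group", type=int, default=1)
    parser.add_argument("--routed_scaling_factor", type=float, default=2.5)
    parser.add_argument("--routing_method", type=str, default="deepseek_v3")
    parser.add_argument("--use_routing_bias", action="store_true")
    parser.add_argument("--use_routing_scales_on_input", action="store_true")
    parser.add_argument("--use_shuffled_weight", action="store_true")
    parser.add_argument("--cutlass_variant", type=str, default="base")
    parser.add_argument("--input_dtype", type=str, default="bfloat16")
    # ... remaining flags from the MOE Flags table (weight_layout, tp/ep, etc.) ...
    args, _ = parser.parse_known_args(line)
    return args


def run_moe_test(args):
    """Dispatch to the requested MoE benchmark body and return result rows."""
    supported = {
        "trtllm_fp4_block_scale_moe",
        "trtllm_fp8_block_scale_moe",
        "trtllm_fp8_per_tensor_scale_moe",
        "cutlass_fused_moe",
    }
    if args.routine not in supported:
        raise ValueError(f"Unsupported MoE routine: {args.routine}")
    # The committed routines/moe.py implements the per-routine benchmark bodies
    # here; they are omitted from this sketch because they are not visible above.
    raise NotImplementedError("benchmark bodies omitted in this sketch")
```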

benchmarks/routines/flashinfer_benchmark_utils.py

Lines changed: 22 additions & 0 deletions
@@ -40,6 +40,21 @@
         "mma_sm",
         "use_128x4_sf_layout",
     ],
+    "moe": [
+        "num_tokens",
+        "hidden_size",
+        "intermediate_size",
+        "num_experts",
+        "top_k",
+        "n_group",
+        "topk_group",
+        "routing_method",
+        "use_shuffled_weight",
+        "weight_layout",
+        "use_routing_scales_on_input",
+        "input_dtype",
+        "weight_dtype",
+    ],
     "general": [
         "refcheck",
         "no_cuda_graph",
@@ -52,6 +67,7 @@
     output_column_dict["perf"]
     + output_column_dict["attention"]
     + output_column_dict["gemm"]
+    + output_column_dict["moe"]
     + output_column_dict["general"]
 )

@@ -68,6 +84,12 @@
         "bmm_fp8",
         "mm_fp4",
     ],
+    "moe": [
+        "trtllm_fp4_block_scale_moe",
+        "trtllm_fp8_block_scale_moe",
+        "trtllm_fp8_per_tensor_scale_moe",
+        "cutlass_fused_moe",
+    ],
 }

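
With the new `moe` entries registered, `full_output_columns` (and hence the result CSV header) now includes the MoE columns, and the four routine names above become valid `--routine` choices. A minimal sketch of how a caller could consume these tables follows; the module path `routines.flashinfer_benchmark_utils` is assumed by analogy with the `from routines.gemm import ...` style above, and the output file name is illustrative — the real CSV writer lives in `flashinfer_benchmark.py` and may differ.

```python
# Hedged sketch: consuming the merged column list and routine registry.
import csv

from routines.flashinfer_benchmark_utils import benchmark_apis, full_output_columns

print("Registered MoE routines:", benchmark_apis["moe"])  # the four names added above

rows = []  # in practice, the result rows produced by run_moe_test(args)
with open("moe_results.csv", "w", newline="") as f:  # illustrative output path
    writer = csv.DictWriter(f, fieldnames=full_output_columns, extrasaction="ignore")
    writer.writeheader()
    writer.writerows(rows)
```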
