Skip to content
Closed
16 changes: 16 additions & 0 deletions benchmarks/flashinfer_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from routines.attention import parse_attention_args, run_attention_test
from routines.flashinfer_benchmark_utils import full_output_columns, output_column_dict
from routines.gemm import parse_gemm_args, run_gemm_test
from routines.moe import parse_moe_args, run_moe_test


def run_test(args):
Expand All @@ -26,6 +27,12 @@ def run_test(args):
"group_gemm_fp8_nt_groupwise",
]:
res = run_gemm_test(args)
elif args.routine in [
"trtllm_fp4_block_scale_moe",
"trtllm_fp8_block_scale_moe",
"trtllm_fp8_per_tensor_scale_moe",
]:
Comment on lines +30 to +34
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The list of MoE routines is duplicated here and in parse_args (lines 79-81 and 147-150). To improve maintainability and avoid potential inconsistencies, consider defining constants for each routine type (e.g., ATTENTION_ROUTINES, GEMM_ROUTINES, MOE_ROUTINES) at the module level and reusing them.

res = run_moe_test(args)
else:
raise ValueError(f"Unsupported routine: {args.routine}")

Expand Down Expand Up @@ -69,6 +76,9 @@ def parse_args(line=sys.argv[1:]):
"BatchPrefillWithRaggedKVCacheWrapper",
"gemm_fp8_nt_groupwise",
"group_gemm_fp8_nt_groupwise",
"trtllm_fp4_block_scale_moe",
"trtllm_fp8_block_scale_moe",
"trtllm_fp8_per_tensor_scale_moe",
],
)
args, _ = parser.parse_known_args(line[:])
Expand Down Expand Up @@ -133,6 +143,12 @@ def parse_args(line=sys.argv[1:]):
"group_gemm_fp8_nt_groupwise",
]:
args = parse_gemm_args(line, parser)
elif args.routine in [
"trtllm_fp4_block_scale_moe",
"trtllm_fp8_block_scale_moe",
"trtllm_fp8_per_tensor_scale_moe",
]:
args = parse_moe_args(line, parser)
else:
raise ValueError(f"Unsupported routine: {args.routine}")

Expand Down
15 changes: 15 additions & 0 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@
"out_dtype",
"mma_sm",
],
"moe": [
"num_tokens",
"hidden_size",
"intermediate_size",
"num_experts",
"top_k",
"n_groups",
"top_k_groups",
"routing_method_type",
"routed_scaling_factor",
"tile_tokens_dim",
"use_shuffled_weight",
"weight_layout",
],
"general": [
"refcheck",
"no_cuda_graph",
Expand All @@ -45,5 +59,6 @@
output_column_dict["perf"]
+ output_column_dict["attention"]
+ output_column_dict["gemm"]
+ output_column_dict["moe"]
+ output_column_dict["general"]
)
Loading