Skip to content

Add moe benchmark routine #1327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft
16 changes: 16 additions & 0 deletions benchmarks/flashinfer_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from routines.attention import parse_attention_args, run_attention_test
from routines.flashinfer_benchmark_utils import full_output_columns, output_column_dict
from routines.gemm import parse_gemm_args, run_gemm_test
from routines.moe import parse_moe_args, run_moe_test


def run_test(args):
Expand All @@ -26,6 +27,12 @@ def run_test(args):
"group_gemm_fp8_nt_groupwise",
]:
res = run_gemm_test(args)
elif args.routine in [
"trtllm_fp4_block_scale_moe",
"trtllm_fp8_block_scale_moe",
"trtllm_fp8_per_tensor_scale_moe",
]:
Comment on lines +30 to +34
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The list of MoE routines is duplicated here and in parse_args (lines 79-81 and 147-150). To improve maintainability and avoid potential inconsistencies, consider defining constants for each routine type (e.g., ATTENTION_ROUTINES, GEMM_ROUTINES, MOE_ROUTINES) at the module level and reusing them.

res = run_moe_test(args)
else:
raise ValueError(f"Unsupported routine: {args.routine}")

Expand Down Expand Up @@ -69,6 +76,9 @@ def parse_args(line=sys.argv[1:]):
"BatchPrefillWithRaggedKVCacheWrapper",
"gemm_fp8_nt_groupwise",
"group_gemm_fp8_nt_groupwise",
"trtllm_fp4_block_scale_moe",
"trtllm_fp8_block_scale_moe",
"trtllm_fp8_per_tensor_scale_moe",
],
)
args, _ = parser.parse_known_args(line[:])
Expand Down Expand Up @@ -133,6 +143,12 @@ def parse_args(line=sys.argv[1:]):
"group_gemm_fp8_nt_groupwise",
]:
args = parse_gemm_args(line, parser)
elif args.routine in [
"trtllm_fp4_block_scale_moe",
"trtllm_fp8_block_scale_moe",
"trtllm_fp8_per_tensor_scale_moe",
]:
args = parse_moe_args(line, parser)
else:
raise ValueError(f"Unsupported routine: {args.routine}")

Expand Down
15 changes: 15 additions & 0 deletions benchmarks/routines/flashinfer_benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@
"out_dtype",
"mma_sm",
],
"moe": [
"num_tokens",
"hidden_size",
"intermediate_size",
"num_experts",
"top_k",
"n_groups",
"top_k_groups",
"routing_method_type",
"routed_scaling_factor",
"tile_tokens_dim",
"use_shuffled_weight",
"weight_layout",
],
"general": [
"refcheck",
"no_cuda_graph",
Expand All @@ -45,5 +59,6 @@
output_column_dict["perf"]
+ output_column_dict["attention"]
+ output_column_dict["gemm"]
+ output_column_dict["moe"]
+ output_column_dict["general"]
)
Loading