Skip to content

Commit 736f530

Browse files
committed
improve one-to-many kernel mapping infra
1 parent 9710f88 commit 736f530

File tree

1 file changed

+18
-21
lines changed

1 file changed

+18
-21
lines changed

benchmarks/run.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
from typing import Callable
2727

2828
# Maps tritonbench op names to Helion kernel examples
29-
# Can map to a single kernel or a list of kernels
30-
KERNEL_MAPPINGS: dict[str, tuple[str, str, str] | list[tuple[str, str, str]]] = {
31-
# <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
29+
# Structure: {tritonbench_op_name: (tritonbench_module, helion_module, helion_func) or [(helion_module, helion_func), ...]}
30+
KERNEL_MAPPINGS: dict[str, tuple[str, str, str] | tuple[str, list[tuple[str, str]]]] = {
31+
# Single kernel mapping: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
3232
# "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
3333
# "embedding": (
3434
# "tritonbench.operators.embedding.operator",
@@ -76,19 +76,14 @@
7676
"examples.fp8_gemm",
7777
"fp8_gemm_tritonbench",
7878
),
79-
"gemm": [
80-
# List of gemm variants
81-
(
82-
"tritonbench.operators.gemm.operator",
83-
"examples.matmul",
84-
"matmul",
85-
),
86-
(
87-
"tritonbench.operators.gemm.operator",
88-
"examples.matmul_split_k",
89-
"matmul_split_k",
90-
),
91-
],
79+
# Multiple kernel mappings: (<tritonbench_module_path>, [(<helion_module>, <helion_func>), ...])
80+
"gemm": (
81+
"tritonbench.operators.gemm.operator",
82+
[
83+
("examples.matmul", "matmul"),
84+
("examples.matmul_split_k", "matmul_split_k"),
85+
],
86+
),
9287
}
9388

9489

@@ -221,10 +216,12 @@ def run_kernel(kernel_name: str, tritonbench_args: list[str]) -> None:
221216

222217
mapping = KERNEL_MAPPINGS[kernel_name]
223218

224-
# Check if it's a list of variants or a single kernel
225-
if isinstance(mapping, list):
226-
# Run each variant
227-
for i, (tritonbench_module, module_path, func_name) in enumerate(mapping):
219+
# Check if it's multiple variants or a single kernel
220+
if len(mapping) == 2 and isinstance(mapping[1], list):
221+
# Multiple variants with shared tritonbench module
222+
tritonbench_module = mapping[0]
223+
variants = mapping[1]
224+
for i, (module_path, func_name) in enumerate(variants):
228225
# Extract variant name from func_name for display
229226
variant_name = func_name
230227
if i > 0:
@@ -233,7 +230,7 @@ def run_kernel(kernel_name: str, tritonbench_args: list[str]) -> None:
233230
print(f"{'=' * 60}\n", file=sys.stderr)
234231
run_single_kernel_variant(kernel_name, tritonbench_module, module_path, func_name, tritonbench_args.copy(), variant_name)
235232
else:
236-
# Single kernel
233+
# Single kernel with full mapping
237234
tritonbench_module, module_path, func_name = mapping
238235
run_single_kernel_variant(kernel_name, tritonbench_module, module_path, func_name, tritonbench_args)
239236

0 commit comments

Comments
 (0)