26
26
from typing import Callable
27
27
28
28
# Maps tritonbench op names to Helion kernel examples
29
- # Can map to a single kernel or a list of kernels
30
- KERNEL_MAPPINGS : dict [str , tuple [str , str , str ] | list [ tuple [str , str , str ]]] = {
31
- # <tritonbench_op_name> : (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
29
+ # Structure: {tritonbench_op_name: (tritonbench_module, helion_module, helion_func) or [(helion_module, helion_func), ...]}
30
+ KERNEL_MAPPINGS : dict [str , tuple [str , str , str ] | tuple [str , list [ tuple [ str , str ] ]]] = {
31
+ # Single kernel mapping : (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
32
32
# "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
33
33
# "embedding": (
34
34
# "tritonbench.operators.embedding.operator",
76
76
"examples.fp8_gemm" ,
77
77
"fp8_gemm_tritonbench" ,
78
78
),
79
- "gemm" : [
80
- # List of gemm variants
81
- (
82
- "tritonbench.operators.gemm.operator" ,
83
- "examples.matmul" ,
84
- "matmul" ,
85
- ),
86
- (
87
- "tritonbench.operators.gemm.operator" ,
88
- "examples.matmul_split_k" ,
89
- "matmul_split_k" ,
90
- ),
91
- ],
79
+ # Multiple kernel mappings: (<tritonbench_module_path>, [(<helion_module>, <helion_func>), ...])
80
+ "gemm" : (
81
+ "tritonbench.operators.gemm.operator" ,
82
+ [
83
+ ("examples.matmul" , "matmul" ),
84
+ ("examples.matmul_split_k" , "matmul_split_k" ),
85
+ ],
86
+ ),
92
87
}
93
88
94
89
@@ -221,10 +216,12 @@ def run_kernel(kernel_name: str, tritonbench_args: list[str]) -> None:
221
216
222
217
mapping = KERNEL_MAPPINGS [kernel_name ]
223
218
224
- # Check if it's a list of variants or a single kernel
225
- if isinstance (mapping , list ):
226
- # Run each variant
227
- for i , (tritonbench_module , module_path , func_name ) in enumerate (mapping ):
219
+ # Check if it's multiple variants or a single kernel
220
+ if len (mapping ) == 2 and isinstance (mapping [1 ], list ):
221
+ # Multiple variants with shared tritonbench module
222
+ tritonbench_module = mapping [0 ]
223
+ variants = mapping [1 ]
224
+ for i , (module_path , func_name ) in enumerate (variants ):
228
225
# Extract variant name from func_name for display
229
226
variant_name = func_name
230
227
if i > 0 :
@@ -233,7 +230,7 @@ def run_kernel(kernel_name: str, tritonbench_args: list[str]) -> None:
233
230
print (f"{ '=' * 60 } \n " , file = sys .stderr )
234
231
run_single_kernel_variant (kernel_name , tritonbench_module , module_path , func_name , tritonbench_args .copy (), variant_name )
235
232
else :
236
- # Single kernel
233
+ # Single kernel with full mapping
237
234
tritonbench_module , module_path , func_name = mapping
238
235
run_single_kernel_variant (kernel_name , tritonbench_module , module_path , func_name , tritonbench_args )
239
236
0 commit comments