@@ -30,7 +30,12 @@
 
 # Maps tritonbench op names to Helion kernel examples
 # Can map to a single kernel or a list of kernel variants
-KERNEL_MAPPINGS: dict[str, tuple[str, str, str] | tuple[str, list[tuple[str, str]]]] = {
+# Format options:
+# - Single kernel: (tritonbench_module, helion_module, helion_func)
+# - Single kernel with args: (tritonbench_module, helion_module, helion_func, args_dict)
+# - Multiple kernels: (tritonbench_module, [(helion_module, helion_func), ...])
+# - Multiple kernels with args: (tritonbench_module, [(helion_module, helion_func), ...], args_dict)
+KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {  # pyright: ignore[reportAssignmentType]
     # <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
     "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
     "embedding": (
@@ -47,6 +52,9 @@
         "tritonbench.operators.rms_norm.operator",
         "examples.rms_norm",
         "rms_norm_tritonbench",
+        {
+            "num_inputs": 3
+        },  # TODO(yf225): reduction dim size = 8192 currently throws error
     ),
     "sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
     "softmax": (
@@ -58,6 +66,9 @@
         "tritonbench.operators.jagged_mean.operator",
         "examples.jagged_mean",
         "jagged_mean_tritonbench",
+        {"B": 32, "M": 8, "seqlen": 64}
+        if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
+        else {},
     ),
     "fp8_gemm": (
         "tritonbench.operators.fp8_gemm.fp8_gemm",
@@ -68,11 +79,17 @@
         "tritonbench.operators.flash_attention.operator",
         "examples.attention",
         "attention",
+        {
+            "d_head": 128
+        },  # Set default head dimension to 128 for TLX attention compatibility
     ),
     "cross_entropy": (
         "tritonbench.operators.cross_entropy.operator",
         "examples.cross_entropy",
         "cross_entropy",
+        {"B": 4, "T": 512, "v_range": "10,15"}
+        if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
+        else {},
     ),
     "fp8_attention": (
         "tritonbench.operators.fp8_attention.operator",
@@ -233,20 +250,40 @@ def run_kernel(
 
     mapping = KERNEL_MAPPINGS[kernel_name]
 
+    # Extract operator args if present
+    operator_args = {}
+
     # Normalize to list of variants format
-    if len(mapping) == 2 and isinstance(mapping[1], list):
-        # Multiple variants with shared tritonbench module
+    if isinstance(mapping[1], list):
+        # Multiple variants format
         tritonbench_module = mapping[0]
         variants = mapping[1]
+        # Check if last element is args dict
+        if len(mapping) > 2 and isinstance(mapping[2], dict):
+            operator_args = mapping[2]
     else:
-        # Single kernel with full mapping - convert to list format
-        assert len(mapping) == 3  # Type narrowing for pyright
-        tritonbench_module, module_path, func_name = mapping
-        variants = [(module_path, func_name)]
+        # Single kernel format
+        if len(mapping) == 4 and isinstance(mapping[3], dict):
+            # With args
+            tritonbench_module = mapping[0]
+            module_path = mapping[1]
+            func_name = mapping[2]
+            operator_args = mapping[3]  # pyright: ignore[reportGeneralTypeIssues]
+            variants = [(module_path, func_name)]
+        else:
+            # Without args
+            assert len(mapping) == 3  # Type narrowing for pyright
+            tritonbench_module, module_path, func_name = mapping
+            variants = [(module_path, func_name)]
 
     # Run all variants in the same benchmark
     run_kernel_variants(
-        kernel_name, tritonbench_module, variants, tritonbench_args, input_shard_info
+        kernel_name,
+        tritonbench_module,
+        variants,
+        tritonbench_args,
+        input_shard_info,
+        operator_args,
     )
 
 
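The branching above reduces every mapping shape to the same (tritonbench_module, variants, operator_args) triple. A minimal standalone sketch of that normalization, with illustrative names rather than the script's own API:

def normalize_mapping(mapping: tuple) -> tuple[str, list[tuple[str, str]], dict]:
    """Illustrative re-statement of the normalization logic in the hunk above."""
    operator_args: dict = {}
    if isinstance(mapping[1], list):
        # Multiple variants, optionally followed by an args dict
        tritonbench_module, variants = mapping[0], mapping[1]
        if len(mapping) > 2 and isinstance(mapping[2], dict):
            operator_args = mapping[2]
    elif len(mapping) == 4 and isinstance(mapping[3], dict):
        # Single kernel with args
        tritonbench_module = mapping[0]
        variants = [(mapping[1], mapping[2])]
        operator_args = mapping[3]
    else:
        # Single kernel without args
        tritonbench_module, module_path, func_name = mapping
        variants = [(module_path, func_name)]
    return tritonbench_module, variants, operator_args

# A 4-tuple with an args dict normalizes to a one-variant list plus args:
assert normalize_mapping(
    ("tb.op", "examples.attention", "attention", {"d_head": 128})
) == ("tb.op", [("examples.attention", "attention")], {"d_head": 128})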
@@ -256,6 +293,7 @@ def run_kernel_variants(
     variants: list[tuple[str, str]],
     tritonbench_args: list[str],
     input_shard_info: tuple[int, int] | None = None,
+    operator_args: dict[str, Any] | None = None,
 ) -> None:
     """Run kernel variants in the same benchmark run."""
 
@@ -280,21 +318,12 @@ def run_kernel_variants(
     assert "--op" not in tritonbench_args
     tritonbench_args = ["--op", operator_name, *tritonbench_args]
 
-    # Collect all module args from all variants
-    all_module_args = {}
-    for module_path, _ in variants:
-        try:
-            module = importlib.import_module(module_path)
-            module_args = getattr(module, "TRITONBENCH_ARGS", {})
-            all_module_args.update(module_args)
-        except ImportError:
-            pass
-
-    # Add module args to tritonbench_args if not already present
-    for arg_name, arg_value in all_module_args.items():
-        arg_flag = f"--{arg_name.replace('_', '-')}"
-        if arg_flag not in tritonbench_args:
-            tritonbench_args.extend([arg_flag, str(arg_value)])
+    # Add operator-specific default args if provided
+    if operator_args:
+        for arg_name, arg_value in operator_args.items():
+            arg_flag = f"--{arg_name.replace('_', '-')}"
+            if arg_flag not in tritonbench_args:
+                tritonbench_args.extend([arg_flag, str(arg_value)])
 
     # Parse known args and collect unknown ones for operator
     tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
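This hunk drops the per-module TRITONBENCH_ARGS import scan in favor of the explicit per-mapping dict: each key becomes a --flag (underscores turned into hyphens) only when the user has not already passed that flag. A minimal sketch of that behavior; the helper name apply_default_args is illustrative, not part of the script:

def apply_default_args(tritonbench_args: list[str], operator_args: dict) -> list[str]:
    """Sketch of the flag-injection loop above: defaults never override user flags."""
    args = list(tritonbench_args)
    for arg_name, arg_value in operator_args.items():
        arg_flag = f"--{arg_name.replace('_', '-')}"
        if arg_flag not in args:
            args.extend([arg_flag, str(arg_value)])
    return args

# {"num_inputs": 3} becomes "--num-inputs 3" only when the user did not set it:
assert apply_default_args(["--op", "rms_norm"], {"num_inputs": 3}) == [
    "--op", "rms_norm", "--num-inputs", "3"
]
assert apply_default_args(["--num-inputs", "8"], {"num_inputs": 3}) == [
    "--num-inputs", "8"
]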