@@ -320,10 +320,23 @@ def run_kernel_variants(
320
320
321
321
# Apply operator-specific args if provided, overriding any values already present
322
322
if operator_args :
323
+ print (
324
+ f"Applying custom args for { operator_name } : { operator_args } " ,
325
+ file = sys .stderr ,
326
+ )
327
+ # First, remove any existing occurrences of these args
323
328
for arg_name , arg_value in operator_args .items ():
324
329
arg_flag = f"--{ arg_name .replace ('_' , '-' )} "
325
- if arg_flag not in tritonbench_args :
326
- tritonbench_args .extend ([arg_flag , str (arg_value )])
330
+ # Remove existing arg if present
331
+ while arg_flag in tritonbench_args :
332
+ idx = tritonbench_args .index (arg_flag )
333
+ tritonbench_args .pop (idx ) # Remove flag
334
+ if idx < len (tritonbench_args ) and not tritonbench_args [idx ].startswith (
335
+ "--"
336
+ ):
337
+ tritonbench_args .pop (idx ) # Remove value
338
+ # Add the custom arg
339
+ tritonbench_args .extend ([arg_flag , str (arg_value )])
327
340
328
341
# Parse known args and collect unknown ones for operator
329
342
tb_args , unknown_args = tb_parser .parse_known_args (tritonbench_args )
@@ -429,8 +442,6 @@ def _inner() -> Callable[..., Any] | object:
429
442
file = sys .stderr ,
430
443
)
431
444
432
- from tritonbench .run import _run
433
-
434
445
# Handle input sharding if requested
435
446
if input_shard_info :
436
447
shard_idx , total_shards = input_shard_info
@@ -467,8 +478,16 @@ def _inner() -> Callable[..., Any] | object:
467
478
# Re-parse args with the new input range
468
479
tb_args , unknown_args = tb_parser .parse_known_args (tritonbench_args )
469
480
470
- # Use tritonbench's _run function which handles arg processing
471
- _run (tb_args , unknown_args )
481
+ # Use the public API to load and run the operator
482
+ from tritonbench .operators import load_opbench_by_name
483
+
484
+ op_class = load_opbench_by_name (operator_name )
485
+ benchmark = op_class (tb_args = tb_args , extra_args = unknown_args )
486
+ benchmark .run ()
487
+
488
+ # Print results if available
489
+ if hasattr (benchmark , "output" ):
490
+ print (benchmark .output )
472
491
473
492
# Force garbage collection multiple times to ensure memory is freed
474
493
for _ in range (3 ):
0 commit comments