@@ -410,10 +410,23 @@ def run_kernel_variants(
410
410
411
411
# Add operator-specific default args if provided
412
412
if operator_args :
413
+ print (
414
+ f"Applying custom args for { operator_name } : { operator_args } " ,
415
+ file = sys .stderr ,
416
+ )
417
+ # First, remove any existing occurrences of these args
413
418
for arg_name , arg_value in operator_args .items ():
414
419
arg_flag = f"--{ arg_name .replace ('_' , '-' )} "
415
- if arg_flag not in tritonbench_args :
416
- tritonbench_args .extend ([arg_flag , str (arg_value )])
420
+ # Remove existing arg if present
421
+ while arg_flag in tritonbench_args :
422
+ idx = tritonbench_args .index (arg_flag )
423
+ tritonbench_args .pop (idx ) # Remove flag
424
+ if idx < len (tritonbench_args ) and not tritonbench_args [idx ].startswith (
425
+ "--"
426
+ ):
427
+ tritonbench_args .pop (idx ) # Remove value
428
+ # Add the custom arg
429
+ tritonbench_args .extend ([arg_flag , str (arg_value )])
417
430
418
431
# Parse known args and collect unknown ones for operator
419
432
tb_args , unknown_args = tb_parser .parse_known_args (tritonbench_args )
@@ -583,8 +596,6 @@ def _inner() -> Callable[..., Any] | object:
583
596
file = sys .stderr ,
584
597
)
585
598
586
- from tritonbench .run import _run
587
-
588
599
# Handle input sharding if requested
589
600
if input_shard_info :
590
601
shard_idx , total_shards = input_shard_info
@@ -621,8 +632,16 @@ def _inner() -> Callable[..., Any] | object:
621
632
# Re-parse args with the new input range
622
633
tb_args , unknown_args = tb_parser .parse_known_args (tritonbench_args )
623
634
624
- # Use tritonbench's _run function which handles arg processing
625
- _run (tb_args , unknown_args )
635
+ # Use the public API to load and run the operator
636
+ from tritonbench .operators import load_opbench_by_name
637
+
638
+ op_class = load_opbench_by_name (operator_name )
639
+ benchmark = op_class (tb_args = tb_args , extra_args = unknown_args )
640
+ benchmark .run ()
641
+
642
+ # Print results if available
643
+ if hasattr (benchmark , "output" ):
644
+ print (benchmark .output )
626
645
627
646
# Force garbage collection multiple times to ensure memory is freed
628
647
for _ in range (3 ):
0 commit comments