Skip to content

Commit 5c7808d

Browse files
committed
[Benchmark] Avoid using _run in TritonBench integration
1 parent 7958869 commit 5c7808d

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

benchmarks/run.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -410,10 +410,23 @@ def run_kernel_variants(
410410

411411
# Add operator-specific default args if provided
412412
if operator_args:
413+
print(
414+
f"Applying custom args for {operator_name}: {operator_args}",
415+
file=sys.stderr,
416+
)
417+
# First, remove any existing occurrences of these args
413418
for arg_name, arg_value in operator_args.items():
414419
arg_flag = f"--{arg_name.replace('_', '-')}"
415-
if arg_flag not in tritonbench_args:
416-
tritonbench_args.extend([arg_flag, str(arg_value)])
420+
# Remove existing arg if present
421+
while arg_flag in tritonbench_args:
422+
idx = tritonbench_args.index(arg_flag)
423+
tritonbench_args.pop(idx) # Remove flag
424+
if idx < len(tritonbench_args) and not tritonbench_args[idx].startswith(
425+
"--"
426+
):
427+
tritonbench_args.pop(idx) # Remove value
428+
# Add the custom arg
429+
tritonbench_args.extend([arg_flag, str(arg_value)])
417430

418431
# Parse known args and collect unknown ones for operator
419432
tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
@@ -583,8 +596,6 @@ def _inner() -> Callable[..., Any] | object:
583596
file=sys.stderr,
584597
)
585598

586-
from tritonbench.run import _run
587-
588599
# Handle input sharding if requested
589600
if input_shard_info:
590601
shard_idx, total_shards = input_shard_info
@@ -621,8 +632,16 @@ def _inner() -> Callable[..., Any] | object:
621632
# Re-parse args with the new input range
622633
tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
623634

624-
# Use tritonbench's _run function which handles arg processing
625-
_run(tb_args, unknown_args)
635+
# Use the public API to load and run the operator
636+
from tritonbench.operators import load_opbench_by_name
637+
638+
op_class = load_opbench_by_name(operator_name)
639+
benchmark = op_class(tb_args=tb_args, extra_args=unknown_args)
640+
benchmark.run()
641+
642+
# Print results if available
643+
if hasattr(benchmark, "output"):
644+
print(benchmark.output)
626645

627646
# Force garbage collection multiple times to ensure memory is freed
628647
for _ in range(3):

0 commit comments

Comments
 (0)