Skip to content

Commit ea8c7fd

Browse files
authored
[Benchmark] Avoid using _run in TritonBench integration (#444)
1 parent f105b05 commit ea8c7fd

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

benchmarks/run.py

Lines changed: 25 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -320,10 +320,23 @@ def run_kernel_variants(
320320

321321
# Add operator-specific default args if provided
322322
if operator_args:
323+
print(
324+
f"Applying custom args for {operator_name}: {operator_args}",
325+
file=sys.stderr,
326+
)
327+
# First, remove any existing occurrences of these args
323328
for arg_name, arg_value in operator_args.items():
324329
arg_flag = f"--{arg_name.replace('_', '-')}"
325-
if arg_flag not in tritonbench_args:
326-
tritonbench_args.extend([arg_flag, str(arg_value)])
330+
# Remove existing arg if present
331+
while arg_flag in tritonbench_args:
332+
idx = tritonbench_args.index(arg_flag)
333+
tritonbench_args.pop(idx) # Remove flag
334+
if idx < len(tritonbench_args) and not tritonbench_args[idx].startswith(
335+
"--"
336+
):
337+
tritonbench_args.pop(idx) # Remove value
338+
# Add the custom arg
339+
tritonbench_args.extend([arg_flag, str(arg_value)])
327340

328341
# Parse known args and collect unknown ones for operator
329342
tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
@@ -429,8 +442,6 @@ def _inner() -> Callable[..., Any] | object:
429442
file=sys.stderr,
430443
)
431444

432-
from tritonbench.run import _run
433-
434445
# Handle input sharding if requested
435446
if input_shard_info:
436447
shard_idx, total_shards = input_shard_info
@@ -467,8 +478,16 @@ def _inner() -> Callable[..., Any] | object:
467478
# Re-parse args with the new input range
468479
tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
469480

470-
# Use tritonbench's _run function which handles arg processing
471-
_run(tb_args, unknown_args)
481+
# Use the public API to load and run the operator
482+
from tritonbench.operators import load_opbench_by_name
483+
484+
op_class = load_opbench_by_name(operator_name)
485+
benchmark = op_class(tb_args=tb_args, extra_args=unknown_args)
486+
benchmark.run()
487+
488+
# Print results if available
489+
if hasattr(benchmark, "output"):
490+
print(benchmark.output)
472491

473492
# Force garbage collection multiple times to ensure memory is freed
474493
for _ in range(3):

0 commit comments

Comments
 (0)