Skip to content

Commit c90a4ef

Browse files
authored
[Benchmark] Allow using 'python benchmarks/run.py' to run all kernels (#280)
1 parent 6d54c8f commit c90a4ef

File tree

1 file changed

+87
-21
lines changed

1 file changed

+87
-21
lines changed

benchmarks/run.py

Lines changed: 87 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,20 @@
55
Currently supported kernels are listed in `KERNEL_MAPPINGS` in `benchmarks/run.py`.
66
77
Usage:
8-
$ python benchmarks/run.py [tritonbench args...] --kernel <kernel_name>
8+
$ python benchmarks/run.py [tritonbench args...] [--kernel <kernel_name(s)>]
99
1010
Example usage:
11-
$ python benchmarks/run.py --metrics speedup,accuracy --kernel vector_add
11+
$ python benchmarks/run.py --metrics speedup,accuracy --kernel vector_add # Runs vector_add kernel
12+
$ python benchmarks/run.py --metrics speedup,accuracy --kernel vector_add,rms_norm # Runs multiple kernels
13+
$ python benchmarks/run.py --metrics speedup,accuracy # Runs all kernels
1214
"""
1315

1416
from __future__ import annotations
1517

1618
import argparse
19+
import gc
1720
import importlib
21+
import os
1822
from pathlib import Path
1923
import subprocess
2024
import sys
@@ -171,26 +175,16 @@ def check_and_setup_tritonbench() -> None:
171175
sys.exit(1)
172176

173177

174-
def main() -> None:
175-
# Parse command line arguments
176-
parser = argparse.ArgumentParser(description="Run Helion kernels with tritonbench")
177-
parser.add_argument(
178-
"--kernel",
179-
type=str,
180-
required=True,
181-
help="Name of the Helion kernel module (e.g., vector_add)",
182-
)
183-
184-
# Parse known args to get the kernel name, pass rest to tritonbench
185-
args, tritonbench_args = parser.parse_known_args()
186-
187-
# Check and setup tritonbench if needed
188-
check_and_setup_tritonbench()
189-
190-
kernel_name = args.kernel
191-
178+
def run_kernel(kernel_name: str, tritonbench_args: list[str]) -> None:
179+
"""Run a single kernel benchmark."""
192180
# Check if kernel is in the mapping table
193-
assert kernel_name in KERNEL_MAPPINGS
181+
if kernel_name not in KERNEL_MAPPINGS:
182+
print(f"Error: Unknown kernel '{kernel_name}'", file=sys.stderr)
183+
print(
184+
f"Available kernels: {', '.join(KERNEL_MAPPINGS.keys())}", file=sys.stderr
185+
)
186+
sys.exit(1)
187+
194188
tritonbench_module, module_path, func_name = KERNEL_MAPPINGS[kernel_name]
195189

196190
# Import from the mapped module
@@ -274,6 +268,15 @@ def helion_method(
274268
attr.reset()
275269

276270
def _inner() -> Callable[..., Any]:
271+
# Force autotuning unless HELION_USE_DEFAULT_CONFIG=1 is set
272+
# This ensures we run autotuning even if the kernel has pre-specified configs
273+
if os.environ.get("HELION_USE_DEFAULT_CONFIG", "0") != "1":
274+
# Find all Kernel objects in the module and force autotuning
275+
for attr_name in dir(module):
276+
attr = getattr(module, attr_name)
277+
if isinstance(attr, Kernel):
278+
attr.settings.force_autotune = True
279+
277280
return kernel_func(*args)
278281

279282
return _inner
@@ -316,6 +319,69 @@ def _inner() -> Callable[..., Any]:
316319
print("\nBenchmark Results:", file=sys.stderr)
317320
print(op.output, file=sys.stderr)
318321

322+
# Clean up memory after running the kernel
323+
# Delete the operator instance which contains all allocated tensors
324+
del op
325+
326+
# Force garbage collection multiple times to ensure memory is freed
327+
for _ in range(3):
328+
gc.collect()
329+
330+
331+
def main() -> None:
332+
# Parse command line arguments
333+
parser = argparse.ArgumentParser(description="Run Helion kernels with tritonbench")
334+
parser.add_argument(
335+
"--kernel",
336+
type=str,
337+
help="Name(s) of the Helion kernel module(s) to run. Can be a single kernel or comma-separated list (e.g., vector_add or vector_add,rms_norm). If not specified, runs all kernels.",
338+
)
339+
340+
# Parse known args to get the kernel name, pass rest to tritonbench
341+
args, tritonbench_args = parser.parse_known_args()
342+
343+
# Check and setup tritonbench if needed
344+
check_and_setup_tritonbench()
345+
346+
if args.kernel:
347+
# Parse comma-separated kernel names
348+
kernel_names = [k.strip() for k in args.kernel.split(",")]
349+
350+
# Validate all kernel names first
351+
invalid_kernels = [k for k in kernel_names if k not in KERNEL_MAPPINGS]
352+
if invalid_kernels:
353+
print(
354+
f"Error: Unknown kernel(s): {', '.join(invalid_kernels)}",
355+
file=sys.stderr,
356+
)
357+
print(
358+
f"Available kernels: {', '.join(KERNEL_MAPPINGS.keys())}",
359+
file=sys.stderr,
360+
)
361+
sys.exit(1)
362+
363+
# Run specified kernels
364+
if len(kernel_names) == 1:
365+
run_kernel(kernel_names[0], tritonbench_args)
366+
else:
367+
print(
368+
f"Running {len(kernel_names)} kernels: {', '.join(kernel_names)}...\n",
369+
file=sys.stderr,
370+
)
371+
for kernel_name in kernel_names:
372+
print(f"\n{'=' * 60}", file=sys.stderr)
373+
print(f"Kernel: {kernel_name}", file=sys.stderr)
374+
print(f"{'=' * 60}\n", file=sys.stderr)
375+
run_kernel(kernel_name, tritonbench_args.copy())
376+
else:
377+
# Run all kernels
378+
print(f"Running all {len(KERNEL_MAPPINGS)} kernels...\n", file=sys.stderr)
379+
for kernel_name in KERNEL_MAPPINGS:
380+
print(f"\n{'=' * 60}", file=sys.stderr)
381+
print(f"Kernel: {kernel_name}", file=sys.stderr)
382+
print(f"{'=' * 60}\n", file=sys.stderr)
383+
run_kernel(kernel_name, tritonbench_args.copy())
384+
319385

320386
if __name__ == "__main__":
321387
main()

0 commit comments

Comments
 (0)