|
5 | 5 | Currently supported kernels are listed in `KERNEL_MAPPINGS` in `benchmarks/run.py`.
|
6 | 6 |
|
7 | 7 | Usage:
|
8 |
| -$ python benchmarks/run.py [tritonbench args...] --kernel <kernel_name> |
| 8 | +$ python benchmarks/run.py [tritonbench args...] [--kernel <kernel_name(s)>] |
9 | 9 |
|
10 | 10 | Example usage:
|
11 |
| -$ python benchmarks/run.py --metrics speedup,accuracy --kernel vector_add |
| 11 | +$ python benchmarks/run.py --metrics speedup,accuracy --kernel vector_add # Runs vector_add kernel |
| 12 | +$ python benchmarks/run.py --metrics speedup,accuracy --kernel vector_add,rms_norm # Runs multiple kernels |
| 13 | +$ python benchmarks/run.py --metrics speedup,accuracy # Runs all kernels |
12 | 14 | """
|
13 | 15 |
|
14 | 16 | from __future__ import annotations
|
15 | 17 |
|
16 | 18 | import argparse
|
| 19 | +import gc |
17 | 20 | import importlib
|
| 21 | +import os |
18 | 22 | from pathlib import Path
|
19 | 23 | import subprocess
|
20 | 24 | import sys
|
@@ -171,26 +175,16 @@ def check_and_setup_tritonbench() -> None:
|
171 | 175 | sys.exit(1)
|
172 | 176 |
|
173 | 177 |
|
174 |
| -def main() -> None: |
175 |
| - # Parse command line arguments |
176 |
| - parser = argparse.ArgumentParser(description="Run Helion kernels with tritonbench") |
177 |
| - parser.add_argument( |
178 |
| - "--kernel", |
179 |
| - type=str, |
180 |
| - required=True, |
181 |
| - help="Name of the Helion kernel module (e.g., vector_add)", |
182 |
| - ) |
183 |
| - |
184 |
| - # Parse known args to get the kernel name, pass rest to tritonbench |
185 |
| - args, tritonbench_args = parser.parse_known_args() |
186 |
| - |
187 |
| - # Check and setup tritonbench if needed |
188 |
| - check_and_setup_tritonbench() |
189 |
| - |
190 |
| - kernel_name = args.kernel |
191 |
| - |
| 178 | +def run_kernel(kernel_name: str, tritonbench_args: list[str]) -> None: |
| 179 | + """Run a single kernel benchmark.""" |
192 | 180 | # Check if kernel is in the mapping table
|
193 |
| - assert kernel_name in KERNEL_MAPPINGS |
| 181 | + if kernel_name not in KERNEL_MAPPINGS: |
| 182 | + print(f"Error: Unknown kernel '{kernel_name}'", file=sys.stderr) |
| 183 | + print( |
| 184 | + f"Available kernels: {', '.join(KERNEL_MAPPINGS.keys())}", file=sys.stderr |
| 185 | + ) |
| 186 | + sys.exit(1) |
| 187 | + |
194 | 188 | tritonbench_module, module_path, func_name = KERNEL_MAPPINGS[kernel_name]
|
195 | 189 |
|
196 | 190 | # Import from the mapped module
|
@@ -274,6 +268,15 @@ def helion_method(
|
274 | 268 | attr.reset()
|
275 | 269 |
|
def _inner() -> Callable[..., Any]:
    """Invoke the Helion kernel function, forcing autotuning first by default.

    Unless the environment variable ``HELION_USE_DEFAULT_CONFIG`` is set to
    ``"1"``, every attribute of ``module`` that is a ``Kernel`` instance gets
    ``settings.force_autotune = True`` so that autotuning happens even when
    the kernel carries pre-specified configs.

    Returns:
        Whatever ``kernel_func(*args)`` returns.
    """
    # NOTE(review): this scan is repeated on every call of _inner; if _inner
    # is the callable being timed, hoisting it into the enclosing function
    # may be preferable -- confirm against the caller.
    use_default_config = os.environ.get("HELION_USE_DEFAULT_CONFIG", "0") == "1"
    if not use_default_config:
        for member_name in dir(module):
            member = getattr(module, member_name)
            if not isinstance(member, Kernel):
                continue
            member.settings.force_autotune = True

    return kernel_func(*args)
|
278 | 281 |
|
279 | 282 | return _inner
|
@@ -316,6 +319,69 @@ def _inner() -> Callable[..., Any]:
|
316 | 319 | print("\nBenchmark Results:", file=sys.stderr)
|
317 | 320 | print(op.output, file=sys.stderr)
|
318 | 321 |
|
| 322 | + # Clean up memory after running the kernel |
| 323 | + # Delete the operator instance which contains all allocated tensors |
| 324 | + del op |
| 325 | + |
| 326 | + # Force garbage collection multiple times to ensure memory is freed |
| 327 | + for _ in range(3): |
| 328 | + gc.collect() |
| 329 | + |
| 330 | + |
def main() -> None:
    """Parse the command line and run the requested Helion kernel benchmarks.

    ``--kernel`` accepts one kernel name or a comma-separated list; when it is
    omitted, every kernel in ``KERNEL_MAPPINGS`` is run. All arguments that
    this parser does not recognise are forwarded to ``run_kernel`` for
    tritonbench.

    Exits with status 1 (after printing to stderr) when ``--kernel`` names an
    unknown kernel or contains no kernel names at all.
    """
    parser = argparse.ArgumentParser(description="Run Helion kernels with tritonbench")
    parser.add_argument(
        "--kernel",
        type=str,
        help="Name(s) of the Helion kernel module(s) to run. Can be a single kernel or comma-separated list (e.g., vector_add or vector_add,rms_norm). If not specified, runs all kernels.",
    )

    # Only --kernel is consumed here; the remaining args go to tritonbench.
    args, tritonbench_args = parser.parse_known_args()

    available = ", ".join(KERNEL_MAPPINGS)

    if args.kernel:
        # Split the comma-separated list, ignoring empty entries produced by
        # stray commas (e.g. "vector_add," or "a,,b").
        stripped = (part.strip() for part in args.kernel.split(","))
        kernel_names = [name for name in stripped if name]

        if not kernel_names:
            print("Error: --kernel did not contain any kernel names", file=sys.stderr)
            print(f"Available kernels: {available}", file=sys.stderr)
            sys.exit(1)

        # Validate every name up front so a typo fails before any kernel runs.
        invalid_kernels = [k for k in kernel_names if k not in KERNEL_MAPPINGS]
        if invalid_kernels:
            print(
                f"Error: Unknown kernel(s): {', '.join(invalid_kernels)}",
                file=sys.stderr,
            )
            print(f"Available kernels: {available}", file=sys.stderr)
            sys.exit(1)

        header = f"Running {len(kernel_names)} kernels: {', '.join(kernel_names)}...\n"
    else:
        kernel_names = list(KERNEL_MAPPINGS)
        header = f"Running all {len(KERNEL_MAPPINGS)} kernels...\n"

    # Set up tritonbench only after the kernel names are known to be valid,
    # so an invalid --kernel value does not pay for the setup step.
    check_and_setup_tritonbench()

    # A single explicitly requested kernel runs without header or banner.
    if args.kernel and len(kernel_names) == 1:
        run_kernel(kernel_names[0], tritonbench_args)
        return

    print(header, file=sys.stderr)
    banner = "=" * 60
    for kernel_name in kernel_names:
        print(f"\n{banner}", file=sys.stderr)
        print(f"Kernel: {kernel_name}", file=sys.stderr)
        print(f"{banner}\n", file=sys.stderr)
        # Each run gets its own copy so one kernel's run cannot alter the
        # argument list seen by the next.
        run_kernel(kernel_name, tritonbench_args.copy())
| 384 | + |
319 | 385 |
|
320 | 386 | if __name__ == "__main__":
|
321 | 387 | main()
|
0 commit comments