|
| 1 | +# pyright: reportMissingImports=false |
| 2 | + |
1 | 3 | """Performance comparison between Helion, torch.compile, Triton, and PyTorch eager by leveraging TritonBench.
|
2 | 4 |
|
3 | 5 | Currently supported kernels are listed in `KERNEL_MAPPINGS` in `benchmarks/run.py`.
|
@@ -242,62 +244,60 @@ def main() -> None:
|
242 | 244 | # Parse known args and collect unknown ones for operator
|
243 | 245 | tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
|
244 | 246 |
|
245 |
| - # Register the Helion kernel with tritonbench BEFORE importing the operator |
246 |
| - from tritonbench.utils.triton_op import ( # type: ignore[reportMissingImports] |
247 |
| - register_benchmark, |
248 |
| - ) |
| 247 | + # Import and run the operator |
| 248 | + try: |
| 249 | + operator_module = importlib.import_module(tritonbench_module) |
| 250 | + Operator = operator_module.Operator |
| 251 | + except ImportError as e: |
| 252 | + print( |
| 253 | + f"Error: Could not import operator '{operator_name}' from tritonbench", |
| 254 | + file=sys.stderr, |
| 255 | + ) |
| 256 | + print(f"Tried: {tritonbench_module}", file=sys.stderr) |
| 257 | + print(f"Import error: {e}", file=sys.stderr) |
| 258 | + sys.exit(1) |
249 | 259 |
|
250 | 260 | # Create the benchmark method
|
251 |
| - def create_helion_method( |
252 |
| - kernel_func: Callable[..., Any], |
| 261 | + def helion_method( |
| 262 | + self: Any, |
| 263 | + *args: Any, |
253 | 264 | ) -> Callable[..., Any]:
|
254 |
| - def helion_method( |
255 |
| - self: Any, |
256 |
| - *args: Any, |
257 |
| - ) -> Callable[..., Any]: |
258 |
| - """Helion implementation.""" |
| 265 | + """Helion implementation.""" |
259 | 266 |
|
260 |
| - # Reset all Helion kernels before creating the benchmark function |
261 |
| - # so that each input size can go through its own autotuning. |
262 |
| - from helion.runtime.kernel import Kernel |
| 267 | + # Reset all Helion kernels before creating the benchmark function |
| 268 | + # so that each input size can go through its own autotuning. |
| 269 | + from helion.runtime.kernel import Kernel |
263 | 270 |
|
264 |
| - for attr_name in dir(module): |
265 |
| - attr = getattr(module, attr_name) |
266 |
| - if isinstance(attr, Kernel): |
267 |
| - attr.reset() |
| 271 | + for attr_name in dir(module): |
| 272 | + attr = getattr(module, attr_name) |
| 273 | + if isinstance(attr, Kernel): |
| 274 | + attr.reset() |
268 | 275 |
|
269 |
| - def _inner() -> Callable[..., Any]: |
270 |
| - return kernel_func(*args) |
| 276 | + def _inner() -> Callable[..., Any]: |
| 277 | + return kernel_func(*args) |
271 | 278 |
|
272 |
| - return _inner |
| 279 | + return _inner |
273 | 280 |
|
274 |
| - return helion_method |
275 |
| - |
276 |
| - # Register it as a benchmark first |
| 281 | + # Method name for the benchmark |
277 | 282 | helion_method_name = f"helion_{kernel_name}"
|
278 |
| - register_benchmark( |
| 283 | + |
| 284 | + # Import register_benchmark API |
| 285 | + from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports] |
| 286 | + register_benchmark, |
| 287 | + ) |
| 288 | + |
| 289 | + # Use register_benchmark decorator |
| 290 | + decorated_method = register_benchmark( |
279 | 291 | operator_name=operator_name,
|
280 | 292 | func_name=helion_method_name,
|
281 | 293 | baseline=False,
|
282 | 294 | enabled=True,
|
| 295 | + fwd_only=False, |
283 | 296 | label=helion_method_name,
|
284 |
| - ) |
285 |
| - |
286 |
| - # Import and run the operator |
287 |
| - try: |
288 |
| - operator_module = importlib.import_module(tritonbench_module) |
289 |
| - Operator = operator_module.Operator |
290 |
| - except ImportError as e: |
291 |
| - print( |
292 |
| - f"Error: Could not import operator '{operator_name}' from tritonbench", |
293 |
| - file=sys.stderr, |
294 |
| - ) |
295 |
| - print(f"Tried: {tritonbench_module}", file=sys.stderr) |
296 |
| - print(f"Import error: {e}", file=sys.stderr) |
297 |
| - sys.exit(1) |
| 297 | + )(helion_method) |
298 | 298 |
|
299 |
| - # Monkey-patch the Operator class after import |
300 |
| - setattr(Operator, helion_method_name, create_helion_method(kernel_func)) |
| 299 | + # Set the decorated method on the Operator class |
| 300 | + setattr(Operator, helion_method_name, decorated_method) |
301 | 301 |
|
302 | 302 | print(
|
303 | 303 | f"Running {operator_name} benchmark with Helion implementation...\n",
|
|
0 commit comments