Skip to content

Commit 6d54c8f

Browse files
authored
[Benchmark] Fix tritonbench integration due to upstream changes (#278)
1 parent c68181a commit 6d54c8f

File tree

1 file changed

+41
-41
lines changed

1 file changed

+41
-41
lines changed

benchmarks/run.py

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# pyright: reportMissingImports=false
2+
13
"""Performance comparison between Helion, torch.compile, Triton, and PyTorch eager by leveraging TritonBench.
24
35
Currently supported kernels are listed in `KERNEL_MAPPINGS` in `benchmarks/run.py`.
@@ -242,62 +244,60 @@ def main() -> None:
242244
# Parse known args and collect unknown ones for operator
243245
tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
244246

245-
# Register the Helion kernel with tritonbench BEFORE importing the operator
246-
from tritonbench.utils.triton_op import ( # type: ignore[reportMissingImports]
247-
register_benchmark,
248-
)
247+
# Import and run the operator
248+
try:
249+
operator_module = importlib.import_module(tritonbench_module)
250+
Operator = operator_module.Operator
251+
except ImportError as e:
252+
print(
253+
f"Error: Could not import operator '{operator_name}' from tritonbench",
254+
file=sys.stderr,
255+
)
256+
print(f"Tried: {tritonbench_module}", file=sys.stderr)
257+
print(f"Import error: {e}", file=sys.stderr)
258+
sys.exit(1)
249259

250260
# Create the benchmark method
251-
def create_helion_method(
252-
kernel_func: Callable[..., Any],
261+
def helion_method(
262+
self: Any,
263+
*args: Any,
253264
) -> Callable[..., Any]:
254-
def helion_method(
255-
self: Any,
256-
*args: Any,
257-
) -> Callable[..., Any]:
258-
"""Helion implementation."""
265+
"""Helion implementation."""
259266

260-
# Reset all Helion kernels before creating the benchmark function
261-
# so that each input size can go through its own autotuning.
262-
from helion.runtime.kernel import Kernel
267+
# Reset all Helion kernels before creating the benchmark function
268+
# so that each input size can go through its own autotuning.
269+
from helion.runtime.kernel import Kernel
263270

264-
for attr_name in dir(module):
265-
attr = getattr(module, attr_name)
266-
if isinstance(attr, Kernel):
267-
attr.reset()
271+
for attr_name in dir(module):
272+
attr = getattr(module, attr_name)
273+
if isinstance(attr, Kernel):
274+
attr.reset()
268275

269-
def _inner() -> Callable[..., Any]:
270-
return kernel_func(*args)
276+
def _inner() -> Callable[..., Any]:
277+
return kernel_func(*args)
271278

272-
return _inner
279+
return _inner
273280

274-
return helion_method
275-
276-
# Register it as a benchmark first
281+
# Method name for the benchmark
277282
helion_method_name = f"helion_{kernel_name}"
278-
register_benchmark(
283+
284+
# Import register_benchmark API
285+
from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports]
286+
register_benchmark,
287+
)
288+
289+
# Use register_benchmark decorator
290+
decorated_method = register_benchmark(
279291
operator_name=operator_name,
280292
func_name=helion_method_name,
281293
baseline=False,
282294
enabled=True,
295+
fwd_only=False,
283296
label=helion_method_name,
284-
)
285-
286-
# Import and run the operator
287-
try:
288-
operator_module = importlib.import_module(tritonbench_module)
289-
Operator = operator_module.Operator
290-
except ImportError as e:
291-
print(
292-
f"Error: Could not import operator '{operator_name}' from tritonbench",
293-
file=sys.stderr,
294-
)
295-
print(f"Tried: {tritonbench_module}", file=sys.stderr)
296-
print(f"Import error: {e}", file=sys.stderr)
297-
sys.exit(1)
297+
)(helion_method)
298298

299-
# Monkey-patch the Operator class after import
300-
setattr(Operator, helion_method_name, create_helion_method(kernel_func))
299+
# Set the decorated method on the Operator class
300+
setattr(Operator, helion_method_name, decorated_method)
301301

302302
print(
303303
f"Running {operator_name} benchmark with Helion implementation...\n",

0 commit comments

Comments
 (0)