Skip to content

Commit 24d453d

Browse files
authored
[Benchmark] Move per-operator settings from example file to benchmarks/run.py (#403)
1 parent 8727491 commit 24d453d

File tree

4 files changed

+52
-41
lines changed

4 files changed

+52
-41
lines changed

benchmarks/run.py

Lines changed: 52 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,12 @@
3030

3131
# Maps tritonbench op names to Helion kernel examples
3232
# Can map to a single kernel or a list of kernel variants
33-
KERNEL_MAPPINGS: dict[str, tuple[str, str, str] | tuple[str, list[tuple[str, str]]]] = {
33+
# Format options:
34+
# - Single kernel: (tritonbench_module, helion_module, helion_func)
35+
# - Single kernel with args: (tritonbench_module, helion_module, helion_func, args_dict)
36+
# - Multiple kernels: (tritonbench_module, [(helion_module, helion_func), ...])
37+
# - Multiple kernels with args: (tritonbench_module, [(helion_module, helion_func), ...], args_dict)
38+
KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = { # pyright: ignore[reportAssignmentType]
3439
# <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>[, <operator_args_dict>])
3540
"vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
3641
"embedding": (
@@ -47,6 +52,9 @@
4752
"tritonbench.operators.rms_norm.operator",
4853
"examples.rms_norm",
4954
"rms_norm_tritonbench",
55+
{
56+
"num_inputs": 3
57+
}, # TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove the "num_inputs" extra arg.
5058
),
5159
"sum": ("tritonbench.operators.sum.operator", "examples.sum", "sum_tritonbench"),
5260
"softmax": (
@@ -58,6 +66,9 @@
5866
"tritonbench.operators.jagged_mean.operator",
5967
"examples.jagged_mean",
6068
"jagged_mean_tritonbench",
69+
{"B": 32, "M": 8, "seqlen": 64}
70+
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
71+
else {},
6172
),
6273
"fp8_gemm": (
6374
"tritonbench.operators.fp8_gemm.fp8_gemm",
@@ -68,11 +79,17 @@
6879
"tritonbench.operators.flash_attention.operator",
6980
"examples.attention",
7081
"attention",
82+
{
83+
"d_head": 128
84+
}, # Set default head dimension to 128 for TLX attention compatibility
7185
),
7286
"cross_entropy": (
7387
"tritonbench.operators.cross_entropy.operator",
7488
"examples.cross_entropy",
7589
"cross_entropy",
90+
{"B": 4, "T": 512, "v_range": "10,15"}
91+
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1"
92+
else {},
7693
),
7794
"fp8_attention": (
7895
"tritonbench.operators.fp8_attention.operator",
@@ -233,20 +250,40 @@ def run_kernel(
233250

234251
mapping = KERNEL_MAPPINGS[kernel_name]
235252

253+
# Extract operator args if present
254+
operator_args = {}
255+
236256
# Normalize to list of variants format
237-
if len(mapping) == 2 and isinstance(mapping[1], list):
238-
# Multiple variants with shared tritonbench module
257+
if isinstance(mapping[1], list):
258+
# Multiple variants format
239259
tritonbench_module = mapping[0]
240260
variants = mapping[1]
261+
# Check if last element is args dict
262+
if len(mapping) > 2 and isinstance(mapping[2], dict):
263+
operator_args = mapping[2]
241264
else:
242-
# Single kernel with full mapping - convert to list format
243-
assert len(mapping) == 3 # Type narrowing for pyright
244-
tritonbench_module, module_path, func_name = mapping
245-
variants = [(module_path, func_name)]
265+
# Single kernel format
266+
if len(mapping) == 4 and isinstance(mapping[3], dict):
267+
# With args
268+
tritonbench_module = mapping[0]
269+
module_path = mapping[1]
270+
func_name = mapping[2]
271+
operator_args = mapping[3] # pyright: ignore[reportGeneralTypeIssues]
272+
variants = [(module_path, func_name)]
273+
else:
274+
# Without args
275+
assert len(mapping) == 3 # Type narrowing for pyright
276+
tritonbench_module, module_path, func_name = mapping
277+
variants = [(module_path, func_name)]
246278

247279
# Run all variants in the same benchmark
248280
run_kernel_variants(
249-
kernel_name, tritonbench_module, variants, tritonbench_args, input_shard_info
281+
kernel_name,
282+
tritonbench_module,
283+
variants,
284+
tritonbench_args,
285+
input_shard_info,
286+
operator_args,
250287
)
251288

252289

@@ -256,6 +293,7 @@ def run_kernel_variants(
256293
variants: list[tuple[str, str]],
257294
tritonbench_args: list[str],
258295
input_shard_info: tuple[int, int] | None = None,
296+
operator_args: dict[str, Any] | None = None,
259297
) -> None:
260298
"""Run kernel variants in the same benchmark run."""
261299

@@ -280,21 +318,12 @@ def run_kernel_variants(
280318
assert "--op" not in tritonbench_args
281319
tritonbench_args = ["--op", operator_name, *tritonbench_args]
282320

283-
# Collect all module args from all variants
284-
all_module_args = {}
285-
for module_path, _ in variants:
286-
try:
287-
module = importlib.import_module(module_path)
288-
module_args = getattr(module, "TRITONBENCH_ARGS", {})
289-
all_module_args.update(module_args)
290-
except ImportError:
291-
pass
292-
293-
# Add module args to tritonbench_args if not already present
294-
for arg_name, arg_value in all_module_args.items():
295-
arg_flag = f"--{arg_name.replace('_', '-')}"
296-
if arg_flag not in tritonbench_args:
297-
tritonbench_args.extend([arg_flag, str(arg_value)])
321+
# Add operator-specific default args if provided
322+
if operator_args:
323+
for arg_name, arg_value in operator_args.items():
324+
arg_flag = f"--{arg_name.replace('_', '-')}"
325+
if arg_flag not in tritonbench_args:
326+
tritonbench_args.extend([arg_flag, str(arg_value)])
298327

299328
# Parse known args and collect unknown ones for operator
300329
tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)

examples/cross_entropy.py

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,11 @@
11
from __future__ import annotations
22

3-
import os
4-
53
import torch
64

75
import helion
86
from helion._testing import run_example
97
import helion.language as hl
108

11-
# TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable
12-
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
13-
# Low memory configuration
14-
TRITONBENCH_ARGS = {"B": 4, "T": 512, "v_range": "10,15"}
15-
169

1710
@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
1811
def cross_entropy(

examples/jagged_mean.py

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,11 @@
11
from __future__ import annotations
22

3-
import os
4-
53
import torch
64

75
import helion
86
from helion._testing import run_example
97
import helion.language as hl
108

11-
# TritonBench configuration - adjust based on HELION_DEV_LOW_VRAM environment variable
12-
if os.environ.get("HELION_DEV_LOW_VRAM", "0") == "1":
13-
# Low memory configuration
14-
TRITONBENCH_ARGS = {"B": 32, "M": 8, "seqlen": 64}
15-
169

1710
@helion.kernel()
1811
def jagged_mean_kernel(

examples/rms_norm.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,10 +6,6 @@
66
from helion._testing import run_example
77
import helion.language as hl
88

9-
# TritonBench configuration
10-
# TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg.
11-
TRITONBENCH_ARGS = {"num_inputs": 3}
12-
139

1410
@helion.kernel(static_shapes=True)
1511
def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:

0 commit comments

Comments
 (0)