Skip to content

Commit ac36c08

Browse files
authored
Suite of ops/shapes scraped from TorchBench (#26)
1 parent 7dd105f commit ac36c08

File tree

4 files changed

+213
-22
lines changed

4 files changed

+213
-22
lines changed

BackendBench/backends.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,21 @@ def __contains__(self, key):
1919
return True
2020

2121

22+
def _flag_gems_softmax(*args, **kwargs):
23+
# half_to_float is not supported in flag_gems
24+
import flag_gems
25+
26+
return flag_gems.ops.softmax(*args[:-1], **kwargs)
27+
28+
29+
def _flag_gems_layernorm(*args, **kwargs):
30+
import flag_gems
31+
32+
x, m, v = flag_gems.ops.layer_norm(*args[:-1], **kwargs)
33+
mv_shape = [*x.shape[:-1], 1]
34+
return x, m.view(*mv_shape), v.view(*mv_shape)
35+
36+
2237
class FlagGemsBackend(Backend):
2338
def __init__(self) -> None:
2439
super().__init__("flaggems")
@@ -121,7 +136,7 @@ def __init__(self) -> None:
121136
torch.ops.aten.isnan.default: flag_gems.ops.isnan,
122137
torch.ops.aten.minimum.default: flag_gems.ops.minimum,
123138
torch.ops.aten.maximum.default: flag_gems.ops.maximum,
124-
torch.ops.aten.native_layer_norm.default: flag_gems.ops.layer_norm,
139+
torch.ops.aten.native_layer_norm.default: _flag_gems_layernorm,
125140
torch.ops.aten.native_layer_norm_backward.default: flag_gems.ops.layer_norm_backward,
126141
torch.ops.aten.le.Tensor: flag_gems.ops.le,
127142
torch.ops.aten.le.Scalar: flag_gems.ops.le_scalar,
@@ -177,7 +192,7 @@ def __init__(self) -> None:
177192
torch.ops.aten.silu_backward.default: flag_gems.ops.silu_backward,
178193
torch.ops.aten.sin.default: flag_gems.ops.sin,
179194
torch.ops.aten.sin_.default: flag_gems.ops.sin_,
180-
torch.ops.aten._softmax.default: flag_gems.ops.softmax,
195+
torch.ops.aten._softmax.default: _flag_gems_softmax,
181196
torch.ops.aten._softmax_backward_data.default: flag_gems.ops.softmax_backward,
182197
torch.ops.aten.sort.default: flag_gems.ops.sort,
183198
torch.ops.aten.sub.Tensor: flag_gems.ops.sub,

BackendBench/eval.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,44 @@
22

33
import torch
44

5-
from triton.testing import do_bench
5+
import triton.testing
66

77

88
logger = logging.getLogger(__name__)
99

10+
EXC_MSG = """
11+
Exception raised for {op}:
12+
args: {args}
13+
kwargs: {kwargs}
14+
exc: {exc}
15+
"""
16+
17+
18+
def format_tensor(t):
19+
return f"{t.dtype}{list(t.shape)}"
20+
21+
22+
def format_args(args):
23+
return [format_tensor(arg) if isinstance(arg, torch.Tensor) else arg for arg in args]
24+
25+
26+
def format_kwargs(kwargs):
27+
return {k: format_tensor(v) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
28+
29+
30+
def format_exception(e, op, args, kwargs):
31+
return EXC_MSG.format(op=op, args=format_args(args), kwargs=format_kwargs(kwargs), exc=e)
32+
1033

1134
def allclose(a, b):
1235
if isinstance(a, torch.Tensor):
13-
torch.testing.assert_close(a, b, equal_nan=True)
36+
torch.testing.assert_close(a, b, equal_nan=True, atol=1e-2, rtol=1e-2)
1437
return True
1538
if isinstance(a, (list, tuple)):
1639
return all(allclose(x, y) for x, y in zip(a, b))
1740
return a == b
1841

1942

20-
EXC_MSG = """
21-
Exception raised for {op}:
22-
args: {args}
23-
kwargs: {kwargs}
24-
exc: {exc}
25-
"""
26-
27-
2843
def eval_correctness_test(op, impl, test):
2944
"""Evaluate impl of op against test."""
3045
args, kwargs = test.args, test.kwargs
@@ -33,13 +48,16 @@ def eval_correctness_test(op, impl, test):
3348
res = impl(*args, **kwargs)
3449
return allclose(ref, res)
3550
except Exception as e:
36-
logger.debug(EXC_MSG.format(op=op, args=args, kwargs=kwargs, exc=e))
51+
logger.warning(format_exception(e, op, args, kwargs))
3752
return False
3853

3954

4055
def eval_correctness(op, impl, tests):
4156
correct, total = 0, 0
4257
for test in tests:
58+
logging.debug(
59+
f"Testing {op.__name__} with args {format_args(test.args)} and kwargs {format_kwargs(test.kwargs)}"
60+
)
4361
if eval_correctness_test(op, impl, test):
4462
correct += 1
4563
total += 1
@@ -60,13 +78,20 @@ def cpu_bench(fn, num_runs=100):
6078

6179

6280
def eval_performance(op, impl, tests):
63-
if torch.cuda.is_available():
64-
base_times = [do_bench(lambda: op(*test.args, **test.kwargs)) for test in tests]
65-
test_times = [do_bench(lambda: impl(*test.args, **test.kwargs)) for test in tests]
66-
else:
67-
base_times = [cpu_bench(lambda: op(*test.args, **test.kwargs)) for test in tests]
68-
test_times = [cpu_bench(lambda: impl(*test.args, **test.kwargs)) for test in tests]
69-
81+
bench_fn = triton.testing.do_bench if torch.cuda.is_available() else cpu_bench
82+
base_times = []
83+
test_times = []
84+
for test in tests:
85+
logging.debug(
86+
f"Benchmarking {op.__name__} with args {format_args(test.args)} and kwargs {format_kwargs(test.kwargs)}"
87+
)
88+
base_times.append(bench_fn(lambda: op(*test.args, **test.kwargs)))
89+
try:
90+
allclose(op(*test.args, **test.kwargs), impl(*test.args, **test.kwargs))
91+
except Exception:
92+
test_times.append(base_times[-1])
93+
continue
94+
test_times.append(bench_fn(lambda: impl(*test.args, **test.kwargs)))
7095
speedups = torch.tensor(test_times) / torch.tensor(base_times)
7196
return speedups.log().mean().exp()
7297

BackendBench/torchbench_suite.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
Load aten inputs from serialized txt files.
3+
"""
4+
5+
import re
6+
import math
7+
from collections import defaultdict
8+
from pathlib import Path
9+
10+
import torch
11+
from torch.testing import make_tensor
12+
13+
14+
dtype_abbrs = {
15+
torch.bfloat16: "bf16",
16+
torch.float64: "f64",
17+
torch.float32: "f32",
18+
torch.float16: "f16",
19+
torch.complex32: "c32",
20+
torch.complex64: "c64",
21+
torch.complex128: "c128",
22+
torch.int8: "i8",
23+
torch.int16: "i16",
24+
torch.int32: "i32",
25+
torch.int64: "i64",
26+
torch.bool: "b8",
27+
torch.uint8: "u8",
28+
}
29+
30+
dtype_abbrs_parsing = {value: key for key, value in dtype_abbrs.items()}
31+
32+
_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
33+
34+
35+
def _deserialize_tensor(size, dtype, stride=None, device="cuda"):
36+
if stride is not None:
37+
out = torch.empty_strided(size, stride, dtype=dtype, device=device)
38+
else:
39+
out = torch.empty(size, dtype=dtype, device=device)
40+
if dtype in _FLOATING_TYPES:
41+
return out.copy_(make_tensor(size, dtype=dtype, device=device, low=0, high=1))
42+
return out.copy_(make_tensor(size, dtype=dtype, device=device))
43+
44+
45+
def _deserialize_args(inps):
46+
inps = inps.strip().strip("'")
47+
global_vals = {
48+
"T": _deserialize_tensor,
49+
"th": torch,
50+
"inf": math.inf,
51+
"torch": torch,
52+
**dtype_abbrs_parsing,
53+
}
54+
# f strings introduce quotations we dont want
55+
for key in dtype_abbrs_parsing:
56+
inps = inps.replace(f"'{key}'", key)
57+
return eval(inps.strip().strip("'").strip('"'), global_vals)
58+
59+
60+
class TorchBenchTest:
61+
def __init__(self, *args, **kwargs):
62+
self.args = args
63+
self.kwargs = kwargs
64+
65+
66+
class TorchBenchOpTest:
67+
def __init__(self, op, inputs):
68+
self.op = eval(f"torch.ops.{op}")
69+
self.inputs = inputs
70+
71+
@property
72+
def correctness_tests(self):
73+
for inp in self.inputs:
74+
args, kwargs = _deserialize_args(inp)
75+
yield TorchBenchTest(*args, **kwargs)
76+
77+
@property
78+
def performance_tests(self):
79+
for inp in self.inputs:
80+
args, kwargs = _deserialize_args(inp)
81+
yield TorchBenchTest(*args, **kwargs)
82+
83+
84+
def _parse_inputs(filename, filter, op_inputs):
85+
op = None
86+
87+
with open(filename, "r") as f:
88+
for line in f:
89+
if m := re.match("Operator: (.*)", line):
90+
op = m.group(1)
91+
if op == "aten.sum.SymInt":
92+
op = "aten.sum.dim_IntList"
93+
if m := re.match("cnt: \\d+, (.*)", line):
94+
assert op is not None
95+
args = m.group(1)
96+
if filter is None or any(f in op for f in filter):
97+
op_inputs[op].append(args)
98+
return op_inputs
99+
100+
101+
class TorchBenchTestSuite:
102+
def __init__(self, name, filename, filter=None):
103+
self.name = name
104+
self.optests = defaultdict(list)
105+
if Path(filename).is_dir():
106+
for file_path in Path(filename).glob("**/*.txt"):
107+
_parse_inputs(str(file_path), filter, self.optests)
108+
else:
109+
_parse_inputs(filename, filter, self.optests)
110+
# Deduplicate the strings in self.optests
111+
for op in self.optests:
112+
self.optests[op] = list(set(self.optests[op]))
113+
114+
def __iter__(self):
115+
for op, inputs in self.optests.items():
116+
if any(s in op for s in ["embedding", "scatter", "gather", "index", "nll_loss"]):
117+
# TODO: indexing ops need valid indices
118+
continue
119+
yield TorchBenchOpTest(op, inputs)

scripts/main.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,37 @@
88
import click
99
import torch
1010
from BackendBench.opinfo_suite import OpInfoTestSuite
11+
from BackendBench.torchbench_suite import TorchBenchTestSuite
1112
from BackendBench.suite import SmokeTestSuite
1213
from BackendBench.llm_client import ClaudeKernelGenerator
1314

1415
logger = logging.getLogger(__name__)
1516

1617

18+
def setup_logging(log_level):
19+
"""Configure logging with the specified level."""
20+
numeric_level = getattr(logging, log_level.upper(), None)
21+
if not isinstance(numeric_level, int):
22+
raise ValueError(f"Invalid log level: {log_level}")
23+
24+
logging.basicConfig(
25+
level=numeric_level,
26+
format="[%(asctime)s][%(levelname)s][%(filename)s] %(message)s",
27+
datefmt="%Y-%m-%d %H:%M:%S",
28+
)
29+
30+
1731
@click.command()
32+
@click.option(
33+
"--log-level",
34+
default=os.getenv("LOG_LEVEL", "INFO"),
35+
type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False),
36+
help="Set the logging level",
37+
)
1838
@click.option(
1939
"--suite",
2040
default="smoke",
21-
type=click.Choice(["smoke", "opinfo"]),
41+
type=click.Choice(["smoke", "opinfo", "torchbench"]),
2242
help="Which suite to run",
2343
)
2444
@click.option(
@@ -39,7 +59,14 @@
3959
type=int,
4060
help="Maximum attempts for LLM kernel generation with feedback",
4161
)
42-
def cli(suite, backend, ops, llm_max_attempts):
62+
@click.option(
63+
"--torchbench-data-path",
64+
default="third_party/tritonbench/tritonbench/data/input_configs",
65+
type=str,
66+
help="Path to TorchBench operator data",
67+
)
68+
def cli(log_level, suite, backend, ops, llm_max_attempts, torchbench_data_path):
69+
setup_logging(log_level)
4370
if ops:
4471
ops = ops.split(",")
4572

@@ -62,6 +89,11 @@ def cli(suite, backend, ops, llm_max_attempts):
6289
torch.bfloat16,
6390
filter=ops,
6491
),
92+
"torchbench": lambda: TorchBenchTestSuite(
93+
"torchbench",
94+
torchbench_data_path,
95+
filter=ops,
96+
),
6597
}[suite]()
6698

6799
overall_correctness = []

0 commit comments

Comments
 (0)