Skip to content

Commit 044b993

Browse files
authored
Add --topn option to select N largest shapes for each operator (#29)
1 parent 9d53e89 commit 044b993

File tree

3 files changed

+98
-15
lines changed

3 files changed

+98
-15
lines changed

BackendBench/torchbench_suite.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
Load aten inputs from serialized txt files.
33
"""
44

5-
import re
65
import math
6+
import re
77
from collections import defaultdict
88
from pathlib import Path
99

@@ -33,13 +33,14 @@
3333

3434

3535
def _deserialize_tensor(size, dtype, stride=None, device="cuda"):
36-
if stride is not None:
37-
out = torch.empty_strided(size, stride, dtype=dtype, device=device)
38-
else:
39-
out = torch.empty(size, dtype=dtype, device=device)
36+
kwargs = {}
4037
if dtype in _FLOATING_TYPES:
41-
return out.copy_(make_tensor(size, dtype=dtype, device=device, low=0, high=1))
42-
return out.copy_(make_tensor(size, dtype=dtype, device=device))
38+
kwargs.update({"low": 0, "high": 1})
39+
if stride is not None:
40+
extent = 1 + sum((size - 1) * stride for size, stride in zip(size, stride))
41+
data = make_tensor(extent, dtype=dtype, device=device, **kwargs)
42+
return data.as_strided(size, stride)
43+
return make_tensor(size, dtype=dtype, device=device, **kwargs)
4344

4445

4546
def _deserialize_args(inps):
@@ -63,20 +64,40 @@ def __init__(self, *args, **kwargs):
6364
self.kwargs = kwargs
6465

6566

67+
def _args_size(args):
68+
size = 0
69+
for arg in args:
70+
if isinstance(arg, torch.Tensor):
71+
size += arg.numel() * arg.element_size()
72+
elif isinstance(arg, (tuple, list)):
73+
size += _args_size(arg)
74+
return size
75+
76+
6677
class TorchBenchOpTest:
67-
def __init__(self, op, inputs):
78+
def __init__(self, op, inputs, topn):
6879
self.op = eval(f"torch.ops.{op}")
6980
self.inputs = inputs
81+
self.topn = topn
82+
83+
def tests(self):
84+
inputs_and_sizes = []
85+
for inp in self.inputs:
86+
args, kwargs = _deserialize_args(inp)
87+
size = _args_size(args) + _args_size(list(kwargs.values()))
88+
inputs_and_sizes.append((size, inp))
89+
ret = [x[1] for x in sorted(inputs_and_sizes, reverse=True)]
90+
return ret if self.topn is None else ret[: self.topn]
7091

7192
@property
7293
def correctness_tests(self):
73-
for inp in self.inputs:
94+
for inp in self.tests():
7495
args, kwargs = _deserialize_args(inp)
7596
yield TorchBenchTest(*args, **kwargs)
7697

7798
@property
7899
def performance_tests(self):
79-
for inp in self.inputs:
100+
for inp in self.tests():
80101
args, kwargs = _deserialize_args(inp)
81102
yield TorchBenchTest(*args, **kwargs)
82103

@@ -99,8 +120,9 @@ def _parse_inputs(filename, filter, op_inputs):
99120

100121

101122
class TorchBenchTestSuite:
102-
def __init__(self, name, filename, filter=None):
123+
def __init__(self, name, filename, filter=None, topn=None):
103124
self.name = name
125+
self.topn = topn
104126
self.optests = defaultdict(list)
105127
if Path(filename).is_dir():
106128
for file_path in Path(filename).glob("**/*.txt"):
@@ -113,7 +135,21 @@ def __init__(self, name, filename, filter=None):
113135

114136
def __iter__(self):
115137
for op, inputs in self.optests.items():
116-
if any(s in op for s in ["embedding", "scatter", "gather", "index", "nll_loss"]):
138+
if any(
139+
s in op
140+
for s in [
141+
"embedding",
142+
"scatter",
143+
"gather",
144+
"index",
145+
"nll_loss",
146+
"im2col_backward",
147+
"col2im_backward",
148+
"native_layer_norm_backward",
149+
"upsample_nearest2d_backward.vec",
150+
"upsample_bilinear2d_backward.vec",
151+
]
152+
):
117153
# TODO: indexing ops need valid indices
118154
continue
119-
yield TorchBenchOpTest(op, inputs)
155+
yield TorchBenchOpTest(op, inputs, self.topn)

scripts/main.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import BackendBench.eval as eval
88
import click
99
import torch
10+
from BackendBench.llm_client import ClaudeKernelGenerator
1011
from BackendBench.opinfo_suite import OpInfoTestSuite
11-
from BackendBench.torchbench_suite import TorchBenchTestSuite
1212
from BackendBench.suite import SmokeTestSuite
13-
from BackendBench.llm_client import ClaudeKernelGenerator
13+
from BackendBench.torchbench_suite import TorchBenchTestSuite
1414

1515
logger = logging.getLogger(__name__)
1616

@@ -53,6 +53,13 @@ def setup_logging(log_level):
5353
type=str,
5454
help="Comma-separated list of ops to run",
5555
)
56+
@click.option(
57+
"--topn-inputs",
58+
"--topn",
59+
default=None,
60+
type=int,
61+
help="Select the top N largest inputs for each op (default: all inputs)",
62+
)
5663
@click.option(
5764
"--llm-max-attempts",
5865
default=5,
@@ -82,6 +89,7 @@ def cli(
8289
suite,
8390
backend,
8491
ops,
92+
topn_inputs,
8593
llm_max_attempts,
8694
kernel_agent_workers,
8795
kernel_agent_max_rounds,
@@ -122,6 +130,7 @@ def cli(
122130
"torchbench",
123131
torchbench_data_path,
124132
filter=ops,
133+
topn=topn_inputs,
125134
),
126135
}[suite]()
127136

test/test_torchbench_suite.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import torch
2+
from BackendBench.torchbench_suite import TorchBenchOpTest
3+
4+
5+
class TestOpTest:
6+
def test_op_test(self):
7+
op_test = TorchBenchOpTest(
8+
"aten.relu.default", ["((T([32, 128, 512], f16, None, 'cpu'),), {})"], None
9+
)
10+
for test in op_test.correctness_tests:
11+
args, kwargs = test.args, test.kwargs
12+
arg, *extras = args
13+
assert arg.shape == torch.Size([32, 128, 512])
14+
assert arg.dtype == torch.float16
15+
assert kwargs == {}
16+
assert extras == []
17+
18+
torch.testing.assert_close(torch.relu(arg), op_test.op(arg))
19+
20+
def test_topn(self):
21+
op_test = TorchBenchOpTest(
22+
"aten.relu.default",
23+
[
24+
"((T([32, 128, 512], f16, None, 'cpu'),), {})",
25+
"((T([32, 256, 512], f16, None, 'cpu'),), {})",
26+
],
27+
1,
28+
)
29+
assert len(op_test.tests()) == 1
30+
for test in op_test.correctness_tests:
31+
args, kwargs = test.args, test.kwargs
32+
arg, *extras = args
33+
assert arg.shape == torch.Size([32, 256, 512])
34+
assert arg.dtype == torch.float16
35+
assert kwargs == {}
36+
assert extras == []
37+
38+
torch.testing.assert_close(torch.relu(arg), op_test.op(arg))

0 commit comments

Comments (0)