Skip to content

Commit a91908a

Browse files
committed
Initial commit
0 parents  commit a91908a

File tree

7 files changed

+587
-0
lines changed

7 files changed

+587
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__pycache__/

BackendBench/backends.py

Lines changed: 280 additions & 0 deletions
Large diffs are not rendered by default.

BackendBench/eval.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import logging
2+
3+
import torch
4+
from triton.testing import do_bench
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
def allclose(a, b):
10+
if isinstance(a, torch.Tensor):
11+
torch.testing.assert_close(a, b, equal_nan=True)
12+
return True
13+
if isinstance(a, (list, tuple)):
14+
return all(allclose(x, y) for x, y in zip(a, b))
15+
return a == b
16+
17+
18+
EXC_MSG = """
19+
Exception raised for {op}:
20+
args: {args}
21+
kwargs: {kwargs}
22+
exc: {exc}
23+
"""
24+
25+
26+
def eval_correctness_test(op, impl, test):
27+
"""Evaluate impl of op against test."""
28+
args, kwargs = test.args, test.kwargs
29+
ref = op(*args, **kwargs)
30+
try:
31+
res = impl(*args, **kwargs)
32+
return allclose(ref, res)
33+
except Exception as e:
34+
logger.debug(EXC_MSG.format(op=op, args=args, kwargs=kwargs, exc=e))
35+
return False
36+
37+
38+
def eval_correctness(op, impl, tests):
39+
correct, total = 0, 0
40+
for test in tests:
41+
if eval_correctness_test(op, impl, test):
42+
correct += 1
43+
total += 1
44+
return correct / total
45+
46+
47+
def eval_performance(op, impl, tests):
48+
base_times = [do_bench(lambda: op(*test.args, **test.kwargs)) for test in tests]
49+
test_times = [do_bench(lambda: impl(*test.args, **test.kwargs)) for test in tests]
50+
speedups = torch.tensor(test_times) / torch.tensor(base_times)
51+
# geometric mean of speedups
52+
return speedups.log().mean().exp()
53+
54+
55+
def eval_one_op(op, impl, correctness_tests, performance_tests):
56+
"""Evaluate impl of op against correctness_tests and performance_tests."""
57+
return eval_correctness(op, impl, correctness_tests), eval_performance(
58+
op, impl, performance_tests
59+
)

BackendBench/opinfo_suite.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import logging
2+
from collections import defaultdict
3+
4+
import torch
5+
from torch.testing._internal.common_methods_invocations import op_db
6+
from torch.utils._python_dispatch import TorchDispatchMode
7+
8+
from .eval import allclose
9+
from .suite import OpTest, Test, TestSuite
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class OpInfoTest:
15+
def __init__(self, *args, **kwargs):
16+
self.args = args
17+
self.kwargs = kwargs
18+
19+
20+
class OpInfoOpTest(OpTest):
21+
def __init__(self, op, correctness_tests, indices):
22+
self.op = op
23+
self._correctness_tests = correctness_tests
24+
self.indices = set(indices)
25+
self.performance_tests = []
26+
27+
@property
28+
def correctness_tests(self):
29+
for idx, test in enumerate(self._correctness_tests):
30+
if idx in self.indices:
31+
# print(f"{idx} {test.input=} {test.args=} {test.kwargs=}")
32+
yield OpInfoTest(test.input, *test.args, **test.kwargs)
33+
34+
35+
class OpTracerMode(TorchDispatchMode):
36+
def __init__(self):
37+
self.ops = []
38+
self.args = []
39+
self.kwargs = []
40+
41+
def __torch_dispatch__(self, fn, types, args=(), kwargs={}):
42+
self.ops.append(fn)
43+
self.args.append(args)
44+
self.kwargs.append(kwargs)
45+
return fn(*args, **kwargs)
46+
47+
48+
def build_op_tests(device, dtype, filter=None):
49+
op_info_op_tests = []
50+
for op in op_db:
51+
if filter and op.name not in filter:
52+
continue
53+
if "." in op.name and "nn.functional" not in op.name:
54+
continue
55+
if dtype not in op.supported_dtypes(device):
56+
continue
57+
if op.name in ["nonzero_static"]:
58+
continue
59+
60+
op_indices = defaultdict(list)
61+
for idx, test in enumerate(op.sample_inputs(device, dtype)):
62+
# print(f"{idx=} {test.input=} {test.args=} {test.kwargs=}")
63+
with OpTracerMode() as tracer:
64+
ref = op.op(test.input, *test.args, **test.kwargs)
65+
if len(tracer.ops) == 1:
66+
try:
67+
res = tracer.ops[0](test.input, *test.args, **test.kwargs)
68+
if allclose(ref, res):
69+
op_indices[tracer.ops[0]].append(idx)
70+
except Exception:
71+
logger.debug(
72+
f"opinfo {op.name} couldn't run underlying op {tracer.ops[0]}"
73+
)
74+
else:
75+
logger.debug(f"opinfo {op.name} has {len(tracer.ops)} ops")
76+
77+
for overload, indices in op_indices.items():
78+
if len(indices) > 0:
79+
op_info_op_tests.append(
80+
OpInfoOpTest(overload, op.sample_inputs(device, dtype), indices)
81+
)
82+
83+
return op_info_op_tests
84+
85+
86+
class OpInfoTestSuite(TestSuite):
87+
def __init__(self, name, device, dtype, filter=None):
88+
super().__init__(name, build_op_tests(device, dtype, filter))

BackendBench/suite.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
2+
3+
4+
def randn(*args, **kwargs):
5+
return lambda: torch.randn(*args, **kwargs)
6+
7+
8+
class Test:
9+
def __init__(self, *args, **kwargs):
10+
self._args = args
11+
self._kwargs = kwargs
12+
13+
@property
14+
def args(self):
15+
return [arg() for arg in self._args]
16+
17+
@property
18+
def kwargs(self):
19+
return {k: v() for k, v in self._kwargs.items()}
20+
21+
22+
class OpTest:
23+
def __init__(self, op, correctness_tests, performance_tests):
24+
self.op = op
25+
self.correctness_tests = correctness_tests
26+
self.performance_tests = performance_tests
27+
28+
29+
class TestSuite:
30+
def __init__(self, name, optests):
31+
self.name = name
32+
self.optests = optests
33+
34+
def __iter__(self):
35+
for optest in self.optests:
36+
yield optest
37+
38+
39+
SmokeTestSuite = TestSuite(
40+
"smoke",
41+
[
42+
OpTest(
43+
torch.ops.aten.relu.default,
44+
[
45+
Test(randn(2, device="cuda")),
46+
],
47+
[
48+
Test(randn(2**28, device="cuda")),
49+
],
50+
)
51+
],
52+
)

README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Usage:
2+
3+
Run a simple smoke test (relu) with the default ATen backend:
4+
```bash
5+
python scripts/main.py --suite smoke --backend aten
6+
```
7+
8+
Run the smoke test with FlagGems:
9+
```bash
10+
python scripts/main.py --suite smoke --backend flag_gems
11+
```
12+
13+
Run opinfo tests (correctness only) with ATen
14+
```bash
15+
python scripts/main.py --suite opinfo --backend aten
16+
```
17+
18+
Run a filtered set of opinfo tests with FlagGems
19+
```bash
20+
python scripts/main.py --suite opinfo --backend flag_gems --ops "add,sub"
21+
```
22+
23+
Run all the opinfo tests with FlagGems (takes a few minutes)
24+
```bash
25+
python scripts/main.py --suite opinfo --backend flag_gems
26+
```

scripts/main.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import logging
2+
import sys
3+
4+
import BackendBench.backends as backends
5+
import BackendBench.eval as eval
6+
import click
7+
import torch
8+
from BackendBench.opinfo_suite import OpInfoTestSuite
9+
from BackendBench.suite import SmokeTestSuite
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
@click.command()
15+
@click.option(
16+
"--suite",
17+
default="smoke",
18+
type=click.Choice(["smoke", "opinfo"]),
19+
help="Which suite to run",
20+
)
21+
@click.option(
22+
"--backend",
23+
default="aten",
24+
type=click.Choice(["aten", "flag_gems"]),
25+
help="Which backend to run",
26+
)
27+
@click.option(
28+
"--ops",
29+
default=None,
30+
type=str,
31+
help="Comma-separated list of ops to run",
32+
)
33+
def cli(suite, backend, ops):
34+
if ops:
35+
ops = ops.split(",")
36+
37+
backend = {
38+
"aten": backends.AtenBackend,
39+
"flag_gems": backends.FlagGemsBackend,
40+
}[backend]()
41+
42+
suite = {
43+
"smoke": lambda: SmokeTestSuite,
44+
"opinfo": lambda: OpInfoTestSuite(
45+
"opinfo_cuda_bfloat16",
46+
"cuda",
47+
torch.bfloat16,
48+
filter=ops,
49+
),
50+
}[suite]()
51+
52+
overall_correctness = []
53+
overall_performance = []
54+
55+
for test in suite:
56+
if test.op not in backend:
57+
continue
58+
59+
logger.debug(test.op)
60+
61+
correctness, perf = eval.eval_one_op(
62+
test.op,
63+
backend[test.op],
64+
test.correctness_tests,
65+
test.performance_tests,
66+
)
67+
overall_correctness.append(correctness)
68+
overall_performance.append(perf)
69+
70+
logger.debug(f"max memory allocated: {torch.cuda.max_memory_allocated():,}")
71+
72+
mean_correctness = torch.tensor(overall_correctness).mean().item()
73+
geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
74+
print(
75+
f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}"
76+
)
77+
print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")
78+
79+
80+
if __name__ == "__main__":
81+
cli()

0 commit comments

Comments
 (0)