
Commit a3f8d0f

Add basic smoke test to CI (#3)
1 parent a6d3149 commit a3f8d0f

File tree

7 files changed (+63, −10 lines)


.github/workflows/smoke-test.yml

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+name: Smoke Test
+
+on:
+  push:
+
+jobs:
+  smoke-test:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+        with:
+          python-version: '3.x'
+
+      - name: Cache pip dependencies
+        uses: actions/cache@v3
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
+          restore-keys: |
+            ${{ runner.os }}-pip-
+
+      - name: Install dependencies
+        run: |
+          pip install -r requirements.txt
+
+      - name: Run smoke test
+        run: |
+          PYTHONPATH=. python scripts/main.py --suite smoke --backend aten

BackendBench/eval.py

Lines changed: 20 additions & 2 deletions
@@ -44,9 +44,27 @@ def eval_correctness(op, impl, tests):
     return correct / total


+def cpu_bench(fn, num_runs=100):
+    """Simple CPU benchmarking using time.perf_counter."""
+    import time
+
+    for _ in range(10):
+        fn()
+
+    start = time.perf_counter()
+    for _ in range(num_runs):
+        fn()
+    return (time.perf_counter() - start) / num_runs
+
+
 def eval_performance(op, impl, tests):
-    base_times = [do_bench(lambda: op(*test.args, **test.kwargs)) for test in tests]
-    test_times = [do_bench(lambda: impl(*test.args, **test.kwargs)) for test in tests]
+    if torch.cuda.is_available():
+        base_times = [do_bench(lambda: op(*test.args, **test.kwargs)) for test in tests]
+        test_times = [do_bench(lambda: impl(*test.args, **test.kwargs)) for test in tests]
+    else:
+        base_times = [cpu_bench(lambda: op(*test.args, **test.kwargs)) for test in tests]
+        test_times = [cpu_bench(lambda: impl(*test.args, **test.kwargs)) for test in tests]
+
     speedups = torch.tensor(test_times) / torch.tensor(base_times)
     # geometric mean of speedups
     return speedups.log().mean().exp()
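For context, here is a minimal standalone sketch of how the new CPU timing path behaves on its own. The two callables being compared (torch.relu vs. torch.clamp) are illustrative stand-ins chosen for this example, not part of BackendBench:

# Standalone sketch; only assumes PyTorch is installed.
import time

import torch


def cpu_bench(fn, num_runs=100):
    # Warm-up runs so one-time costs don't skew the average.
    for _ in range(10):
        fn()
    start = time.perf_counter()
    for _ in range(num_runs):
        fn()
    return (time.perf_counter() - start) / num_runs


x = torch.randn(1 << 20)
base = cpu_bench(lambda: torch.relu(x))            # stand-in for the reference op
test = cpu_bench(lambda: torch.clamp(x, min=0.0))  # stand-in for the candidate impl
print(f"speedup: {base / test:.2f}x")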

BackendBench/opinfo_suite.py

Lines changed: 1 addition & 3 deletions
@@ -67,9 +67,7 @@ def build_op_tests(device, dtype, filter=None):
             if allclose(ref, res):
                 op_indices[tracer.ops[0]].append(idx)
         except Exception:
-            logger.debug(
-                f"opinfo {op.name} couldn't run underlying op {tracer.ops[0]}"
-            )
+            logger.debug(f"opinfo {op.name} couldn't run underlying op {tracer.ops[0]}")
     else:
         logger.debug(f"opinfo {op.name} has {len(tracer.ops)} ops")

BackendBench/suite.py

Lines changed: 2 additions & 2 deletions
@@ -42,10 +42,10 @@ def __iter__(self):
             OpTest(
                 torch.ops.aten.relu.default,
                 [
-                    Test(randn(2, device="cuda")),
+                    Test(randn(2, device="cpu")),
                 ],
                 [
-                    Test(randn(2**28, device="cuda")),
+                    Test(randn(2**28, device="cpu")),
                 ],
             )
         ],
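For illustration, another operator could be added to the smoke suite following the same pattern. The abs entry below is hypothetical; it assumes the Test/OpTest constructors used above (inputs captured positionally by Test, with the correctness-test list followed by the performance-test list in OpTest) and an import path matching the file location:

# Hypothetical extra smoke-suite entry, mirroring the relu example above.
import torch
from torch import randn

from BackendBench.suite import OpTest, Test  # assumed import path

abs_op_test = OpTest(
    torch.ops.aten.abs.default,
    [
        Test(randn(2, device="cpu")),      # correctness input
    ],
    [
        Test(randn(2**28, device="cpu")),  # performance input
    ],
)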

pyproject.toml

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+[tool.ruff]
+line-length = 100
+
+[tool.ruff.format]

requirements.txt

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+torch
+click
+numpy
+expecttest

scripts/main.py

Lines changed: 1 addition & 3 deletions
@@ -70,9 +70,7 @@ def cli(suite, backend, ops):

    mean_correctness = torch.tensor(overall_correctness).mean().item()
    geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()
-    print(
-        f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}"
-    )
+    print(f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}")
    print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")
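As a quick sanity check on the two reported scores, the same aggregation on made-up per-operator numbers looks like this (the input lists are illustrative only, not real results):

import torch

# Illustrative per-operator results only.
overall_correctness = [1.0, 0.5, 0.75]  # pass rate per operator
overall_performance = [2.0, 0.5, 1.0]   # speedup per operator

mean_correctness = torch.tensor(overall_correctness).mean().item()
geomean_perf = torch.tensor(overall_performance).log().mean().exp().item()

print(f"correctness score (mean pass rate over all operators): {mean_correctness:.2f}")  # 0.75
print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")     # 1.00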
