Skip to content

Commit 13af4b3

Browse files
authored
Add unit test for waterfalling errors (#73)
No op
1 parent 95e1a03 commit 13af4b3

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

test/test_adverse_cases.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pytest
2+
from BackendBench.torchbench_suite import TorchBenchOpTest
3+
from BackendBench.eval import eval_one_op
4+
import BackendBench.backends as backends
5+
import torch
6+
7+
8+
class TestAdaptiveAvgPool2dBackward:
9+
# todo: @jiannanWang unskip this test
10+
@pytest.mark.skip(reason="Not ready for testing yet as it'd brick the gpu")
11+
def test_adaptive_avg_pool2d_backward_gpu(self):
12+
"""Test on GPU with eval_one_op."""
13+
op_test_should_error = TorchBenchOpTest(
14+
"aten._adaptive_avg_pool2d_backward.default",
15+
["((T([512, 4096, 56, 56], f16), T([512, 4096, 56, 56], f16)), {})"],
16+
None,
17+
)
18+
19+
op_test_should_succeed = TorchBenchOpTest(
20+
"aten.addmm.default",
21+
["((T([14, 14], f32), T([14, 14], f32), T([14, 14], f32)), {})"],
22+
None,
23+
)
24+
25+
# run test that should brick the gpu due to an illegal memory access
26+
backend = backends.AtenBackend()
27+
with pytest.raises(RuntimeError):
28+
_, _ = eval_one_op(
29+
op_test_should_error.op,
30+
backend[op_test_should_error.op],
31+
list(op_test_should_error.correctness_tests),
32+
list(op_test_should_error.performance_tests),
33+
)
34+
35+
# add these in case code changes in eval_one_op. There shouldn't be any errors here
36+
torch.cuda.synchronize()
37+
torch.cuda.empty_cache()
38+
39+
# tests that a simple op works afterwards to make sure we recover after an illegal memory access
40+
correctness, _ = eval_one_op(
41+
op_test_should_succeed.op,
42+
backend[op_test_should_succeed.op],
43+
list(op_test_should_succeed.correctness_tests),
44+
list(op_test_should_succeed.performance_tests),
45+
)
46+
47+
assert correctness == 1.0
48+
49+
50+
if __name__ == "__main__":
51+
pytest.main([__file__, "-v", "-s"])

0 commit comments

Comments
 (0)