Skip to content

Commit 84c47ba

Browse files
GregoryComerConarnar
authored andcommitted
[Backend Tester] Clean up operator test logic (pytorch#12736)
Minor refactoring on the operator test logic - since we now have separate model tests, I've moved operator-test specific test helper logic into the operators directory and updated usages. I also updated discovery slightly to give nicer error messages when things don't import.
1 parent 9112fd6 commit 84c47ba

21 files changed

+260
-175
lines changed

backends/test/suite/__init__.py

Lines changed: 1 addition & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,11 @@
99

1010
import logging
1111
import os
12-
import unittest
13-
14-
from enum import Enum
15-
from typing import Callable
1612

1713
import executorch.backends.test.suite.flow
1814

19-
import torch
20-
from executorch.backends.test.suite.context import get_active_test_context, TestContext
2115
from executorch.backends.test.suite.flow import TestFlow
22-
from executorch.backends.test.suite.reporting import log_test_summary
23-
from executorch.backends.test.suite.runner import run_test, runner_main
16+
from executorch.backends.test.suite.runner import runner_main
2417

2518
logger = logging.getLogger(__name__)
2619
logger.setLevel(logging.INFO)
@@ -62,109 +55,6 @@ def get_test_flows() -> dict[str, TestFlow]:
6255
return _ALL_TEST_FLOWS
6356

6457

65-
DTYPES = [
66-
# torch.int8,
67-
# torch.uint8,
68-
# torch.int16,
69-
# torch.uint16,
70-
# torch.int32,
71-
# torch.uint32,
72-
# torch.int64,
73-
# torch.uint64,
74-
# torch.float16,
75-
torch.float32,
76-
# torch.float64,
77-
]
78-
79-
FLOAT_DTYPES = [
80-
torch.float16,
81-
torch.float32,
82-
torch.float64,
83-
]
84-
85-
86-
# The type of test function. This controls the test generation and expected signature.
87-
# Standard tests are run, as is. Dtype tests get a variant generated for each dtype and
88-
# take an additional dtype parameter.
89-
class TestType(Enum):
90-
STANDARD = 1
91-
DTYPE = 2
92-
93-
94-
# Function annotation for dtype tests. This instructs the test framework to run the test
95-
# for each supported dtype and to pass dtype as a test parameter.
96-
def dtype_test(func):
97-
func.test_type = TestType.DTYPE
98-
return func
99-
100-
101-
# Class annotation for operator tests. This triggers the test framework to register
102-
# the tests.
103-
def operator_test(cls):
104-
_create_tests(cls)
105-
return cls
106-
107-
108-
# Generate test cases for each backend flow.
109-
def _create_tests(cls):
110-
for key in dir(cls):
111-
if key.startswith("test_"):
112-
_expand_test(cls, key)
113-
114-
115-
# Expand a test into variants for each registered flow.
116-
def _expand_test(cls, test_name: str):
117-
test_func = getattr(cls, test_name)
118-
for flow in get_test_flows().values():
119-
_create_test_for_backend(cls, test_func, flow)
120-
delattr(cls, test_name)
121-
122-
123-
def _make_wrapped_test(
124-
test_func: Callable,
125-
test_name: str,
126-
flow: TestFlow,
127-
params: dict | None = None,
128-
):
129-
def wrapped_test(self):
130-
with TestContext(test_name, flow.name, params):
131-
test_kwargs = params or {}
132-
test_kwargs["flow"] = flow
133-
134-
test_func(self, **test_kwargs)
135-
136-
wrapped_test._name = test_name
137-
wrapped_test._flow = flow
138-
139-
return wrapped_test
140-
141-
142-
def _create_test_for_backend(
143-
cls,
144-
test_func: Callable,
145-
flow: TestFlow,
146-
):
147-
test_type = getattr(test_func, "test_type", TestType.STANDARD)
148-
149-
if test_type == TestType.STANDARD:
150-
wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow)
151-
test_name = f"{test_func.__name__}_{flow.name}"
152-
setattr(cls, test_name, wrapped_test)
153-
elif test_type == TestType.DTYPE:
154-
for dtype in DTYPES:
155-
wrapped_test = _make_wrapped_test(
156-
test_func,
157-
test_func.__name__,
158-
flow,
159-
{"dtype": dtype},
160-
)
161-
dtype_name = str(dtype)[6:] # strip "torch."
162-
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
163-
setattr(cls, test_name, wrapped_test)
164-
else:
165-
raise NotImplementedError(f"Unknown test type {test_type}.")
166-
167-
16858
def load_tests(loader, suite, pattern):
16959
package_dir = os.path.dirname(__file__)
17060
discovered_suite = loader.discover(
@@ -174,32 +64,5 @@ def load_tests(loader, suite, pattern):
17464
return suite
17565

17666

177-
class OperatorTest(unittest.TestCase):
178-
def _test_op(self, model, inputs, flow: TestFlow):
179-
context = get_active_test_context()
180-
181-
# This should be set in the wrapped test. See _make_wrapped_test above.
182-
assert context is not None, "Missing test context."
183-
184-
run_summary = run_test(
185-
model,
186-
inputs,
187-
flow,
188-
context.test_name,
189-
context.params,
190-
)
191-
192-
log_test_summary(run_summary)
193-
194-
if not run_summary.result.is_success():
195-
if run_summary.result.is_backend_failure():
196-
raise RuntimeError("Test failure.") from run_summary.error
197-
else:
198-
# Non-backend failure indicates a bad test. Mark as skipped.
199-
raise unittest.SkipTest(
200-
f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
201-
)
202-
203-
20467
if __name__ == "__main__":
20568
runner_main()

backends/test/suite/discovery.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,17 @@ def _filter_tests(
6969
def _is_test_enabled(test_case: unittest.TestCase, test_filter: TestFilter) -> bool:
7070
test_method = getattr(test_case, test_case._testMethodName)
7171

72+
# Handle import / discovery failures - leave them enabled to report nicely at the
73+
# top level. There might be a better way to do this. Internally, unittest seems to
74+
# replace it with a stub method to report the failure.
75+
if "testFailure" in str(test_method):
76+
print(f"Warning: Test {test_case._testMethodName} failed to import.")
77+
return True
78+
7279
if not hasattr(test_method, "_flow"):
73-
print(f"Test missing flow: {test_method}")
80+
raise RuntimeError(
81+
f"Test missing flow: {test_case._testMethodName} {test_method}"
82+
)
7483

7584
flow: TestFlow = test_method._flow
7685

backends/test/suite/operators/__init__.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@
77
# pyre-unsafe
88

99
import os
10+
import unittest
11+
12+
from enum import Enum
13+
from typing import Callable
14+
15+
import torch
16+
from executorch.backends.test.suite import get_test_flows
17+
from executorch.backends.test.suite.context import get_active_test_context, TestContext
18+
from executorch.backends.test.suite.flow import TestFlow
19+
from executorch.backends.test.suite.reporting import log_test_summary
20+
from executorch.backends.test.suite.runner import run_test
1021

1122

1223
def load_tests(loader, suite, pattern):
@@ -16,3 +27,133 @@ def load_tests(loader, suite, pattern):
1627
)
1728
suite.addTests(discovered_suite)
1829
return suite
30+
31+
32+
DTYPES = [
33+
# torch.int8,
34+
# torch.uint8,
35+
# torch.int16,
36+
# torch.uint16,
37+
# torch.int32,
38+
# torch.uint32,
39+
# torch.int64,
40+
# torch.uint64,
41+
# torch.float16,
42+
torch.float32,
43+
# torch.float64,
44+
]
45+
46+
FLOAT_DTYPES = [
47+
torch.float16,
48+
torch.float32,
49+
torch.float64,
50+
]
51+
52+
53+
# The type of test function. This controls the test generation and expected signature.
54+
# Standard tests are run, as is. Dtype tests get a variant generated for each dtype and
55+
# take an additional dtype parameter.
56+
class TestType(Enum):
57+
STANDARD = 1
58+
DTYPE = 2
59+
60+
61+
# Function annotation for dtype tests. This instructs the test framework to run the test
62+
# for each supported dtype and to pass dtype as a test parameter.
63+
def dtype_test(func):
64+
func.test_type = TestType.DTYPE
65+
return func
66+
67+
68+
# Class annotation for operator tests. This triggers the test framework to register
69+
# the tests.
70+
def operator_test(cls):
71+
_create_tests(cls)
72+
return cls
73+
74+
75+
# Generate test cases for each backend flow.
76+
def _create_tests(cls):
77+
for key in dir(cls):
78+
if key.startswith("test_"):
79+
_expand_test(cls, key)
80+
81+
82+
# Expand a test into variants for each registered flow.
83+
def _expand_test(cls, test_name: str):
84+
test_func = getattr(cls, test_name)
85+
for flow in get_test_flows().values():
86+
_create_test_for_backend(cls, test_func, flow)
87+
delattr(cls, test_name)
88+
89+
90+
def _make_wrapped_test(
91+
test_func: Callable,
92+
test_name: str,
93+
flow: TestFlow,
94+
params: dict | None = None,
95+
):
96+
def wrapped_test(self):
97+
with TestContext(test_name, flow.name, params):
98+
test_kwargs = params or {}
99+
test_kwargs["flow"] = flow
100+
101+
test_func(self, **test_kwargs)
102+
103+
wrapped_test._name = test_name
104+
wrapped_test._flow = flow
105+
106+
return wrapped_test
107+
108+
109+
def _create_test_for_backend(
110+
cls,
111+
test_func: Callable,
112+
flow: TestFlow,
113+
):
114+
test_type = getattr(test_func, "test_type", TestType.STANDARD)
115+
116+
if test_type == TestType.STANDARD:
117+
wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow)
118+
test_name = f"{test_func.__name__}_{flow.name}"
119+
setattr(cls, test_name, wrapped_test)
120+
elif test_type == TestType.DTYPE:
121+
for dtype in DTYPES:
122+
wrapped_test = _make_wrapped_test(
123+
test_func,
124+
test_func.__name__,
125+
flow,
126+
{"dtype": dtype},
127+
)
128+
dtype_name = str(dtype)[6:] # strip "torch."
129+
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
130+
setattr(cls, test_name, wrapped_test)
131+
else:
132+
raise NotImplementedError(f"Unknown test type {test_type}.")
133+
134+
135+
class OperatorTest(unittest.TestCase):
136+
def _test_op(self, model, inputs, flow: TestFlow):
137+
context = get_active_test_context()
138+
139+
# This should be set in the wrapped test. See _make_wrapped_test above.
140+
assert context is not None, "Missing test context."
141+
142+
run_summary = run_test(
143+
model,
144+
inputs,
145+
flow,
146+
context.test_name,
147+
context.params,
148+
)
149+
150+
log_test_summary(run_summary)
151+
152+
if not run_summary.result.is_success():
153+
if run_summary.result.is_backend_failure():
154+
raise RuntimeError("Test failure.") from run_summary.error
155+
else:
156+
# Non-backend failure indicates a bad test. Mark as skipped.
157+
raise unittest.SkipTest(
158+
f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
159+
)

backends/test/suite/operators/test_add.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88

99

1010
import torch
11-
12-
from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
1311
from executorch.backends.test.suite.flow import TestFlow
1412

13+
from executorch.backends.test.suite.operators import (
14+
dtype_test,
15+
operator_test,
16+
OperatorTest,
17+
)
18+
1519

1620
class Model(torch.nn.Module):
1721
def forward(self, x, y):

backends/test/suite/operators/test_div.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,14 @@
1010
from typing import Optional
1111

1212
import torch
13-
14-
from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
1513
from executorch.backends.test.suite.flow import TestFlow
1614

15+
from executorch.backends.test.suite.operators import (
16+
dtype_test,
17+
operator_test,
18+
OperatorTest,
19+
)
20+
1721

1822
class Model(torch.nn.Module):
1923
def forward(self, x, y):

backends/test/suite/operators/test_elu.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@
88

99

1010
import torch
11-
12-
from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest
1311
from executorch.backends.test.suite.flow import TestFlow
1412

13+
from executorch.backends.test.suite.operators import (
14+
dtype_test,
15+
operator_test,
16+
OperatorTest,
17+
)
18+
1519

1620
class Model(torch.nn.Module):
1721
def __init__(self, alpha=1.0, inplace=False):

0 commit comments

Comments
 (0)