Skip to content

Commit 76b08f3

Browse files
authored
[Backend Tester] Add backend filtering, improved test discovery (#12624)
Iteratively update the backend test suite to support cleaner test flow registration, backend filtering, and a bit of misc incremental cleanup. The test suite can now be invoked with `python -m executorch.backends.test.suite.runner operators --backend xnnpack`, for example.
1 parent 6d86fa9 commit 76b08f3

File tree

6 files changed

+199
-39
lines changed

6 files changed

+199
-39
lines changed

backends/test/suite/__init__.py

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,13 @@
1212
import unittest
1313

1414
from enum import Enum
15-
from typing import Any, Callable, Tuple
15+
from typing import Callable
16+
17+
import executorch.backends.test.suite.flow
1618

1719
import torch
18-
from executorch.backends.test.harness import Tester
1920
from executorch.backends.test.suite.context import get_active_test_context, TestContext
21+
from executorch.backends.test.suite.flow import TestFlow
2022
from executorch.backends.test.suite.reporting import log_test_summary
2123
from executorch.backends.test.suite.runner import run_test, runner_main
2224

@@ -44,22 +46,20 @@ def is_backend_enabled(backend):
4446
return backend in _ENABLED_BACKENDS
4547

4648

47-
ALL_TEST_FLOWS = []
49+
_ALL_TEST_FLOWS: dict[str, TestFlow] = {}
4850

49-
if is_backend_enabled("xnnpack"):
50-
from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester
5151

52-
XNNPACK_TEST_FLOW = ("xnnpack", XnnpackTester)
53-
ALL_TEST_FLOWS.append(XNNPACK_TEST_FLOW)
52+
def get_test_flows() -> dict[str, TestFlow]:
53+
global _ALL_TEST_FLOWS
5454

55-
if is_backend_enabled("coreml"):
56-
try:
57-
from executorch.backends.apple.coreml.test.tester import CoreMLTester
55+
if not _ALL_TEST_FLOWS:
56+
_ALL_TEST_FLOWS = {
57+
name: f
58+
for name, f in executorch.backends.test.suite.flow.all_flows().items()
59+
if is_backend_enabled(f.backend)
60+
}
5861

59-
COREML_TEST_FLOW = ("coreml", CoreMLTester)
60-
ALL_TEST_FLOWS.append(COREML_TEST_FLOW)
61-
except Exception:
62-
print("Core ML AOT is not available.")
62+
return _ALL_TEST_FLOWS
6363

6464

6565
DTYPES = [
@@ -115,53 +115,51 @@ def _create_tests(cls):
115115
# Expand a test into variants for each registered flow.
116116
def _expand_test(cls, test_name: str):
117117
test_func = getattr(cls, test_name)
118-
for flow_name, tester_factory in ALL_TEST_FLOWS:
119-
_create_test_for_backend(cls, test_func, flow_name, tester_factory)
118+
for flow in get_test_flows().values():
119+
_create_test_for_backend(cls, test_func, flow)
120120
delattr(cls, test_name)
121121

122122

123123
def _make_wrapped_test(
124124
test_func: Callable,
125125
test_name: str,
126-
test_flow: str,
127-
tester_factory: Callable,
126+
flow: TestFlow,
128127
params: dict | None = None,
129128
):
130129
def wrapped_test(self):
131-
with TestContext(test_name, test_flow, params):
130+
with TestContext(test_name, flow.name, params):
132131
test_kwargs = params or {}
133-
test_kwargs["tester_factory"] = tester_factory
132+
test_kwargs["tester_factory"] = flow.tester_factory
134133

135134
test_func(self, **test_kwargs)
136135

136+
wrapped_test._name = test_name
137+
wrapped_test._flow = flow
138+
137139
return wrapped_test
138140

139141

140142
def _create_test_for_backend(
141143
cls,
142144
test_func: Callable,
143-
flow_name: str,
144-
tester_factory: Callable[[torch.nn.Module, Tuple[Any]], Tester],
145+
flow: TestFlow,
145146
):
146147
test_type = getattr(test_func, "test_type", TestType.STANDARD)
147148

148149
if test_type == TestType.STANDARD:
149-
wrapped_test = _make_wrapped_test(
150-
test_func, test_func.__name__, flow_name, tester_factory
151-
)
152-
test_name = f"{test_func.__name__}_{flow_name}"
150+
wrapped_test = _make_wrapped_test(test_func, test_func.__name__, flow)
151+
test_name = f"{test_func.__name__}_{flow.name}"
153152
setattr(cls, test_name, wrapped_test)
154153
elif test_type == TestType.DTYPE:
155154
for dtype in DTYPES:
156155
wrapped_test = _make_wrapped_test(
157156
test_func,
158157
test_func.__name__,
159-
flow_name,
160-
tester_factory,
158+
flow,
161159
{"dtype": dtype},
162160
)
163161
dtype_name = str(dtype)[6:] # strip "torch."
164-
test_name = f"{test_func.__name__}_{dtype_name}_{flow_name}"
162+
test_name = f"{test_func.__name__}_{dtype_name}_{flow.name}"
165163
setattr(cls, test_name, wrapped_test)
166164
else:
167165
raise NotImplementedError(f"Unknown test type {test_type}.")

backends/test/suite/discovery.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import os
10+
import unittest
11+
12+
from types import ModuleType
13+
14+
from executorch.backends.test.suite.flow import TestFlow
15+
16+
#
17+
# This file contains logic related to test discovery and filtering.
18+
#
19+
20+
21+
def discover_tests(
22+
root_module: ModuleType, backends: set[str] | None
23+
) -> unittest.TestSuite:
24+
# Collect all tests using the unittest discovery mechanism then filter down.
25+
26+
# Find the file system path corresponding to the root module.
27+
module_file = root_module.__file__
28+
if module_file is None:
29+
raise RuntimeError(f"Module {root_module} has no __file__ attribute")
30+
31+
loader = unittest.TestLoader()
32+
module_dir = os.path.dirname(module_file)
33+
suite = loader.discover(module_dir)
34+
35+
return _filter_tests(suite, backends)
36+
37+
38+
def _filter_tests(
39+
suite: unittest.TestSuite, backends: set[str] | None
40+
) -> unittest.TestSuite:
41+
# Recursively traverse the test suite and add them to the filtered set.
42+
filtered_suite = unittest.TestSuite()
43+
44+
for child in suite:
45+
if isinstance(child, unittest.TestSuite):
46+
filtered_suite.addTest(_filter_tests(child, backends))
47+
elif isinstance(child, unittest.TestCase):
48+
if _is_test_enabled(child, backends):
49+
filtered_suite.addTest(child)
50+
else:
51+
raise RuntimeError(f"Unexpected test type: {type(child)}")
52+
53+
return filtered_suite
54+
55+
56+
def _is_test_enabled(test_case: unittest.TestCase, backends: set[str] | None) -> bool:
57+
test_method = getattr(test_case, test_case._testMethodName)
58+
59+
if backends is not None:
60+
flow: TestFlow = test_method._flow
61+
return flow.backend in backends
62+
else:
63+
return True

backends/test/suite/flow.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import logging
2+
3+
from dataclasses import dataclass
4+
from typing import Callable
5+
6+
from executorch.backends.test.harness import Tester
7+
8+
logger = logging.getLogger(__name__)
9+
logger.setLevel(logging.INFO)
10+
11+
12+
@dataclass
13+
class TestFlow:
14+
"""
15+
A lowering flow to test. This typically corresponds to a combination of a backend and
16+
a lowering recipe.
17+
"""
18+
19+
name: str
20+
""" The name of the lowering flow. """
21+
22+
backend: str
23+
""" The name of the target backend. """
24+
25+
tester_factory: Callable[[], Tester]
26+
""" A factory function that returns a Tester instance for this lowering flow. """
27+
28+
29+
def create_xnnpack_flow() -> TestFlow | None:
30+
try:
31+
from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester
32+
33+
return TestFlow(
34+
name="xnnpack",
35+
backend="xnnpack",
36+
tester_factory=XnnpackTester,
37+
)
38+
except Exception:
39+
logger.info("Skipping XNNPACK flow registration due to import failure.")
40+
return None
41+
42+
43+
def create_coreml_flow() -> TestFlow | None:
44+
try:
45+
from executorch.backends.apple.coreml.test.tester import CoreMLTester
46+
47+
return TestFlow(
48+
name="coreml",
49+
backend="coreml",
50+
tester_factory=CoreMLTester,
51+
)
52+
except Exception:
53+
logger.info("Skipping Core ML flow registration due to import failure.")
54+
return None
55+
56+
57+
def all_flows() -> dict[str, TestFlow]:
58+
flows = [
59+
create_xnnpack_flow(),
60+
create_coreml_flow(),
61+
]
62+
return {f.name: f for f in flows if f is not None}

backends/test/suite/operators/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,14 @@
55
# LICENSE file in the root directory of this source tree.
66

77
# pyre-unsafe
8+
9+
import os
10+
11+
12+
def load_tests(loader, suite, pattern):
13+
package_dir = os.path.dirname(__file__)
14+
discovered_suite = loader.discover(
15+
start_dir=package_dir, pattern=pattern or "test_*.py"
16+
)
17+
suite.addTests(discovered_suite)
18+
return suite

backends/test/suite/reporting.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections import Counter
22
from dataclasses import dataclass
3-
from enum import IntEnum, nonmember
3+
from enum import IntEnum
44

55

66
class TestResult(IntEnum):
@@ -33,19 +33,15 @@ class TestResult(IntEnum):
3333
UNKNOWN_FAIL = 8
3434
""" The test failed in an unknown or unexpected manner. """
3535

36-
@nonmember
3736
def is_success(self):
3837
return self in {TestResult.SUCCESS, TestResult.SUCCESS_UNDELEGATED}
3938

40-
@nonmember
4139
def is_non_backend_failure(self):
4240
return self in {TestResult.EAGER_FAIL, TestResult.EAGER_FAIL}
4341

44-
@nonmember
4542
def is_backend_failure(self):
4643
return not self.is_success() and not self.is_non_backend_failure()
4744

48-
@nonmember
4945
def display_name(self):
5046
if self == TestResult.SUCCESS:
5147
return "Success (Delegated)"

backends/test/suite/runner.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import argparse
2+
import importlib
23
import unittest
34

45
from typing import Callable
56

67
import torch
78

89
from executorch.backends.test.harness import Tester
10+
from executorch.backends.test.suite.discovery import discover_tests
911
from executorch.backends.test.suite.reporting import (
1012
begin_test_session,
1113
complete_test_session,
@@ -15,6 +17,12 @@
1517
)
1618

1719

20+
# A list of all runnable test suites and the corresponding python package.
21+
NAMED_SUITES = {
22+
"operators": "executorch.backends.test.suite.operators",
23+
}
24+
25+
1826
def run_test( # noqa: C901
1927
model: torch.nn.Module,
2028
inputs: any,
@@ -130,20 +138,42 @@ def parse_args():
130138
prog="ExecuTorch Backend Test Suite",
131139
description="Run ExecuTorch backend tests.",
132140
)
133-
parser.add_argument("test_path", nargs="?", help="Prefix filter for tests to run.")
141+
parser.add_argument(
142+
"suite",
143+
nargs="*",
144+
help="The test suite to run.",
145+
choices=NAMED_SUITES.keys(),
146+
default=["operators"],
147+
)
148+
parser.add_argument(
149+
"-b", "--backend", nargs="*", help="The backend or backends to test."
150+
)
134151
return parser.parse_args()
135152

136153

154+
def test(suite):
155+
if isinstance(suite, unittest.TestSuite):
156+
print(f"Suite: {suite}")
157+
for t in suite:
158+
test(t)
159+
else:
160+
print(f"Leaf: {type(suite)} {suite}")
161+
print(f" {suite.__name__}")
162+
print(f" {callable(suite)}")
163+
164+
137165
def runner_main():
138166
args = parse_args()
139167

140168
begin_test_session()
141169

142-
test_path = args.test_path or "executorch.backends.test.suite.operators"
170+
if len(args.suite) > 1:
171+
raise NotImplementedError("TODO Support multiple suites.")
143172

144-
loader = unittest.TestLoader()
145-
suite = loader.loadTestsFromName(test_path)
146-
unittest.TextTestRunner().run(suite)
173+
test_path = NAMED_SUITES[args.suite[0]]
174+
test_root = importlib.import_module(test_path)
175+
suite = discover_tests(test_root, args.backend)
176+
unittest.TextTestRunner(verbosity=2).run(suite)
147177

148178
summary = complete_test_session()
149179
print_summary(summary)

0 commit comments

Comments
 (0)