Skip to content

Commit fbe335b

Browse files
committed
[Backend Tester] Add initial reporting skeleton
1 parent dd4488d commit fbe335b

File tree

4 files changed

+412
-36
lines changed

4 files changed

+412
-36
lines changed

backends/test/suite/__init__.py

Lines changed: 68 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
import torch
1818
from executorch.backends.test.harness import Tester
19+
from executorch.backends.test.suite.context import get_active_test_context, TestContext
20+
from executorch.backends.test.suite.reporting import log_test_summary
21+
from executorch.backends.test.suite.runner import run_test, runner_main
1922

2023
logger = logging.getLogger(__name__)
2124
logger.setLevel(logging.INFO)
@@ -60,17 +63,17 @@ def is_backend_enabled(backend):
6063

6164

6265
DTYPES = [
63-
torch.int8,
64-
torch.uint8,
65-
torch.int16,
66-
torch.uint16,
67-
torch.int32,
68-
torch.uint32,
69-
torch.int64,
70-
torch.uint64,
71-
torch.float16,
66+
# torch.int8,
67+
# torch.uint8,
68+
# torch.int16,
69+
# torch.uint16,
70+
# torch.int32,
71+
# torch.uint32,
72+
# torch.int64,
73+
# torch.uint64,
74+
# torch.float16,
7275
torch.float32,
73-
torch.float64,
76+
# torch.float64,
7477
]
7578

7679
FLOAT_DTYPES = [
@@ -117,16 +120,19 @@ def _expand_test(cls, test_name: str):
117120
delattr(cls, test_name)
118121

119122

120-
def _make_wrapped_test(test_func, *args, **kwargs):
123+
def _make_wrapped_test(
124+
test_func: Callable,
125+
test_name: str,
126+
test_flow: str,
127+
tester_factory: Callable,
128+
params: dict | None = None,
129+
):
121130
def wrapped_test(self):
122-
test_func(self, *args, **kwargs)
131+
with TestContext(test_name, test_flow, params):
132+
test_kwargs = params or {}
133+
test_kwargs["tester_factory"] = tester_factory
123134

124-
return wrapped_test
125-
126-
127-
def _make_wrapped_dtype_test(test_func, dtype, tester_factory):
128-
def wrapped_test(self):
129-
test_func(self, dtype, tester_factory)
135+
test_func(self, **test_kwargs)
130136

131137
return wrapped_test
132138

@@ -140,37 +146,63 @@ def _create_test_for_backend(
140146
test_type = getattr(test_func, "test_type", TestType.STANDARD)
141147

142148
if test_type == TestType.STANDARD:
143-
wrapped_test = _make_wrapped_test(test_func, tester_factory)
149+
wrapped_test = _make_wrapped_test(
150+
test_func, test_func.__name__, flow_name, tester_factory
151+
)
144152
test_name = f"{test_func.__name__}_{flow_name}"
145153
setattr(cls, test_name, wrapped_test)
146154
elif test_type == TestType.DTYPE:
147155
for dtype in DTYPES:
148-
# wrapped_test = _make_wrapped_dtype_test(test_func, dtype, tester_factory)
149-
wrapped_test = _make_wrapped_test(test_func, dtype, tester_factory)
156+
wrapped_test = _make_wrapped_test(
157+
test_func,
158+
test_func.__name__,
159+
flow_name,
160+
tester_factory,
161+
{"dtype": dtype},
162+
)
150163
dtype_name = str(dtype)[6:] # strip "torch."
151164
test_name = f"{test_func.__name__}_{dtype_name}_{flow_name}"
152165
setattr(cls, test_name, wrapped_test)
153166
else:
154167
raise NotImplementedError(f"Unknown test type {test_type}.")
155168

156169

170+
def load_tests(loader, suite, pattern):
171+
package_dir = os.path.dirname(__file__)
172+
discovered_suite = loader.discover(
173+
start_dir=package_dir, pattern=pattern or "test_*.py"
174+
)
175+
suite.addTests(discovered_suite)
176+
return suite
177+
178+
157179
class OperatorTest(unittest.TestCase):
158180
def _test_op(self, model, inputs, tester_factory):
159-
tester = (
160-
tester_factory(
161-
model,
162-
inputs,
163-
)
164-
.export()
165-
.to_edge_transform_and_lower()
181+
context = get_active_test_context()
182+
183+
# This should be set in the wrapped test. See _make_wrapped_test above.
184+
assert context is not None, "Missing test context."
185+
186+
run_summary = run_test(
187+
model,
188+
inputs,
189+
tester_factory,
190+
context.test_name,
191+
context.flow_name,
192+
context.params,
166193
)
167194

168-
is_delegated = any(
169-
n.target == torch._higher_order_ops.executorch_call_delegate
170-
for n in tester.stages[tester.cur].graph_module.graph.nodes
171-
if n.op == "call_function"
172-
)
195+
log_test_summary(run_summary)
196+
197+
if not run_summary.result.is_success():
198+
if run_summary.result.is_backend_failure():
199+
raise RuntimeError("Test failure.") from run_summary.error
200+
else:
201+
# Non-backend failure indicates a bad test. Mark as skipped.
202+
raise unittest.SkipTest(
203+
f"Test failed for reasons other than backend failure. Error: {run_summary.error}"
204+
)
205+
173206

174-
# Only run the runtime test if the op was delegated.
175-
if is_delegated:
176-
(tester.to_executorch().serialize().run_method_and_compare_outputs())
207+
if __name__ == "__main__":
208+
runner_main()

backends/test/suite/context.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Test run context management. This is used to determine the test context for reporting
2+
# purposes.
3+
class TestContext:
4+
def __init__(self, test_name: str, flow_name: str, params: dict | None):
5+
self.test_name = test_name
6+
self.flow_name = flow_name
7+
self.params = params
8+
9+
def __enter__(self):
10+
global _active_test_context
11+
import sys
12+
13+
if _active_test_context is not None:
14+
print(f"Active context: {_active_test_context.test_name}", file=sys.stderr)
15+
assert _active_test_context is None
16+
_active_test_context = self
17+
18+
def __exit__(self, exc_type, exc_value, traceback):
19+
global _active_test_context
20+
_active_test_context = None
21+
22+
23+
_active_test_context: TestContext | None = None
24+
25+
26+
def get_active_test_context() -> TestContext | None:
27+
global _active_test_context
28+
return _active_test_context

backends/test/suite/reporting.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
from collections import Counter
2+
from dataclasses import dataclass
3+
from enum import IntEnum, nonmember
4+
5+
6+
class TestResult(IntEnum):
7+
"""Represents the result of a test case run, indicating success or a specific failure reason."""
8+
9+
SUCCESS = 0
10+
""" The test succeeded with the backend delegate part or all of the graph. """
11+
12+
SUCCESS_UNDELEGATED = 1
13+
""" The test succeeded without the backend delegating anything. """
14+
15+
EAGER_FAIL = 2
16+
""" The test failed due to the model failing to run in eager mode. """
17+
18+
EXPORT_FAIL = 3
19+
""" The test failed due to the model failing to export. """
20+
21+
LOWER_FAIL = 4
22+
""" The test failed due to a failure in partitioning or lowering. """
23+
24+
PTE_LOAD_FAIL = 5
25+
""" The test failed due to the resulting PTE failing to load. """
26+
27+
PTE_RUN_FAIL = 6
28+
""" The test failed due to the resulting PTE failing to run. """
29+
30+
OUTPUT_MISMATCH_FAIL = 7
31+
""" The test failed due to a mismatch between runtime and reference outputs. """
32+
33+
UNKNOWN_FAIL = 8
34+
""" The test failed in an unknown or unexpected manner. """
35+
36+
@nonmember
37+
def is_success(self):
38+
return self in {TestResult.SUCCESS, TestResult.SUCCESS_UNDELEGATED}
39+
40+
@nonmember
41+
def is_non_backend_failure(self):
42+
return self in {TestResult.EAGER_FAIL, TestResult.EAGER_FAIL}
43+
44+
@nonmember
45+
def is_backend_failure(self):
46+
return not self.is_success() and not self.is_non_backend_failure()
47+
48+
@nonmember
49+
def display_name(self):
50+
if self == TestResult.SUCCESS:
51+
return "Success (Delegated)"
52+
elif self == TestResult.SUCCESS_UNDELEGATED:
53+
return "Success (Undelegated)"
54+
elif self == TestResult.EAGER_FAIL:
55+
return "Fail (Eager)"
56+
elif self == TestResult.EXPORT_FAIL:
57+
return "Fail (Export)"
58+
elif self == TestResult.LOWER_FAIL:
59+
return "Fail (Lowering)"
60+
elif self == TestResult.PTE_LOAD_FAIL:
61+
return "Fail (PTE Load)"
62+
elif self == TestResult.PTE_RUN_FAIL:
63+
return "Fail (PTE Run)"
64+
elif self == TestResult.OUTPUT_MISMATCH_FAIL:
65+
return "Fail (Output Mismatch)"
66+
elif self == TestResult.UNKNOWN_FAIL:
67+
return "Fail (Other)"
68+
else:
69+
raise ValueError(f"Invalid TestResult value: {self}.")
70+
71+
72+
@dataclass
73+
class TestCaseSummary:
74+
"""
75+
Contains summary results for the execution of a single test case.
76+
"""
77+
78+
name: str
79+
""" The qualified name of the test, not including the flow suffix. """
80+
81+
flow: str
82+
""" The backend-specific flow name. Corresponds to flows registered in backends/test/suite/__init__.py. """
83+
84+
params: dict | None
85+
""" Test-specific parameters, such as dtype. """
86+
87+
result: TestResult
88+
""" The top-level result, such as SUCCESS or LOWER_FAIL. """
89+
90+
error: Exception | None
91+
""" The Python exception object, if any. """
92+
93+
94+
class TestSessionState:
95+
test_case_summaries: list[TestCaseSummary]
96+
97+
def __init__(self):
98+
self.test_case_summaries = []
99+
100+
101+
@dataclass
102+
class RunSummary:
103+
aggregated_results: dict[TestResult, int]
104+
num_test_cases: int
105+
test_case_summaries: list[TestCaseSummary]
106+
total_failed: int
107+
total_passed: int
108+
total_skipped: int
109+
110+
@staticmethod
111+
def from_session(session: TestSessionState) -> "RunSummary":
112+
# Total each outcome type.
113+
aggregated_results = dict(
114+
sorted(Counter(s.result for s in session.test_case_summaries).items())
115+
)
116+
117+
total_failed = 0
118+
total_passed = 0
119+
total_skipped = 0
120+
121+
for k, v in aggregated_results.items():
122+
if k.is_success():
123+
total_passed += v
124+
elif k.is_backend_failure():
125+
total_failed += v
126+
else:
127+
total_skipped += v
128+
129+
return RunSummary(
130+
aggregated_results=aggregated_results,
131+
num_test_cases=len(session.test_case_summaries),
132+
test_case_summaries=session.test_case_summaries,
133+
total_failed=total_failed,
134+
total_passed=total_passed,
135+
total_skipped=total_skipped,
136+
)
137+
138+
139+
_active_session: TestSessionState | None = None
140+
141+
142+
def begin_test_session():
143+
global _active_session
144+
145+
assert _active_session is None, "A test session is already active."
146+
_active_session = TestSessionState()
147+
148+
149+
def log_test_summary(summary: TestCaseSummary):
150+
global _active_session
151+
152+
if _active_session is not None:
153+
_active_session.test_case_summaries.append(summary)
154+
155+
156+
def complete_test_session() -> RunSummary:
157+
global _active_session
158+
159+
assert _active_session is not None, "No test session is active."
160+
summary = RunSummary.from_session(_active_session)
161+
_active_session = None
162+
163+
return summary

0 commit comments

Comments
 (0)