Skip to content

Commit f67c9ec

Browse files
committed
[Backend Tester] Add subtest index field
ghstack-source-id: ed7cbf3 ghstack-comment-id: 3177697272 Pull-Request: #13311
1 parent fda6f78 commit f67c9ec

File tree

4 files changed

+63
-8
lines changed

4 files changed

+63
-8
lines changed

backends/test/suite/context.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
# Test run context management. This is used to determine the test context for reporting
22
# purposes.
33
class TestContext:
4+
subtest_index: int
5+
46
def __init__(
57
self, test_name: str, test_base_name: str, flow_name: str, params: dict | None
68
):
79
self.test_name = test_name
810
self.test_base_name = test_base_name
911
self.flow_name = flow_name
1012
self.params = params
13+
self.subtest_index = 0
1114

1215
def __enter__(self):
1316
global _active_test_context

backends/test/suite/operators/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,16 @@ def _test_op(
152152
flow,
153153
context.test_name,
154154
context.test_base_name,
155+
context.subtest_index,
155156
context.params,
156157
generate_random_test_inputs=generate_random_test_inputs,
157158
)
158159

159160
log_test_summary(run_summary)
160161

162+
# This is reset when a new test is started - it creates the context per-test.
163+
context.subtest_index = context.subtest_index + 1
164+
161165
if not run_summary.result.is_success():
162166
if run_summary.result.is_backend_failure():
163167
raise RuntimeError("Test failure.") from run_summary.error

backends/test/suite/reporting.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,41 @@
1111
from torch.export import ExportedProgram
1212

1313

14+
# The maximum number of model output tensors to log statistics for. Most model tests will
15+
# only have one output, but some may return more than one tensor. This upper bound is needed
16+
# upfront since the file is written progressively. Any outputs beyond these will not have stats logged.
17+
MAX_LOGGED_MODEL_OUTPUTS = 2
18+
19+
20+
# Field names for the CSV report.
21+
CSV_FIELD_NAMES = [
22+
"Test ID",
23+
"Test Case",
24+
"Subtest",
25+
"Flow",
26+
"Params",
27+
"Result",
28+
"Result Detail",
29+
"Delegated",
30+
"Quantize Time (s)",
31+
"Lower Time (s)",
32+
"Delegated Nodes",
33+
"Undelegated Nodes",
34+
"Delegated Ops",
35+
"Undelegated Ops",
36+
"PTE Size (Kb)",
37+
]
38+
39+
for i in range(MAX_LOGGED_MODEL_OUTPUTS):
40+
CSV_FIELD_NAMES.extend(
41+
[
42+
f"Output {i} Error Max",
43+
f"Output {i} Error MAE",
44+
f"Output {i} SNR",
45+
]
46+
)
47+
48+
1449
# Operators that are excluded from the counts returned by count_ops. These are used to
1550
# exclude operatations that are not logically relevant or delegatable to backends.
1651
OP_COUNT_IGNORED_OPS = {
@@ -129,6 +164,9 @@ class TestCaseSummary:
129164
name: str
130165
""" The full name of test, including flow and parameter suffixes. """
131166

167+
subtest_index: int
168+
""" The subtest number. If a test case runs multiple tests, this field can be used to disambiguate. """
169+
132170
params: dict | None
133171
""" Test-specific parameters, such as dtype. """
134172

@@ -305,14 +343,22 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
305343
"Lower Time (s)",
306344
]
307345

308-
# Tests can have custom parameters. We'll want to report them here, so we need
309-
# a list of all unique parameter names.
310-
param_names = reduce(
311-
lambda a, b: a.union(b),
312-
(
313-
set(s.params.keys())
314-
for s in summary.test_case_summaries
315-
if s.params is not None
346+
def write_csv_row(record: TestCaseSummary, output: TextIO):
347+
writer = csv.DictWriter(output, CSV_FIELD_NAMES)
348+
349+
row = {
350+
"Test ID": record.name,
351+
"Test Case": record.base_name,
352+
"Subtest": record.subtest_index,
353+
"Flow": record.flow,
354+
"Params": _serialize_params(record.params),
355+
"Result": record.result.to_short_str(),
356+
"Result Detail": record.result.to_detail_str(),
357+
"Delegated": "True" if record.is_delegated() else "False",
358+
"Quantize Time (s)": (
359+
f"{record.quantize_time.total_seconds():.3f}"
360+
if record.quantize_time
361+
else None
316362
),
317363
set(),
318364
)

backends/test/suite/runner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def run_test( # noqa: C901
4646
flow: TestFlow,
4747
test_name: str,
4848
test_base_name: str,
49+
subtest_index: int,
4950
params: dict | None,
5051
dynamic_shapes: Any | None = None,
5152
generate_random_test_inputs: bool = True,
@@ -65,6 +66,7 @@ def build_result(
6566
return TestCaseSummary(
6667
backend=flow.backend,
6768
base_name=test_base_name,
69+
subtest_index=subtest_index,
6870
flow=flow.name,
6971
name=test_name,
7072
params=params,

0 commit comments

Comments
 (0)