Skip to content

Commit 657876d

Browse files
authored
[Backend Tester] Write report progressively (#13308)
Append to the report file line by line after each test, rather than all at the end. This ensures that report data is available if the test run is aborted or hit with an unrecoverable native crash.
1 parent 8b04295 commit 657876d

File tree

3 files changed

+113
-108
lines changed

3 files changed

+113
-108
lines changed

backends/test/suite/reporting.py

Lines changed: 103 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import csv
22

33
from collections import Counter
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from datetime import timedelta
66
from enum import IntEnum
77
from functools import reduce
@@ -11,6 +11,40 @@
1111
from torch.export import ExportedProgram
1212

1313

14+
# The maximum number of model output tensors to log statistics for. Most model tests will
15+
# only have one output, but some may return more than one tensor. This upper bound is needed
16+
# upfront since the file is written progressively. Any outputs beyond these will not have stats logged.
17+
MAX_LOGGED_MODEL_OUTPUTS = 2
18+
19+
20+
# Field names for the CSV report.
21+
CSV_FIELD_NAMES = [
22+
"Test ID",
23+
"Test Case",
24+
"Flow",
25+
"Params",
26+
"Result",
27+
"Result Detail",
28+
"Delegated",
29+
"Quantize Time (s)",
30+
"Lower Time (s)",
31+
"Delegated Nodes",
32+
"Undelegated Nodes",
33+
"Delegated Ops",
34+
"Undelegated Ops",
35+
"PTE Size (Kb)",
36+
]
37+
38+
for i in range(MAX_LOGGED_MODEL_OUTPUTS):
39+
CSV_FIELD_NAMES.extend(
40+
[
41+
f"Output {i} Error Max",
42+
f"Output {i} Error MAE",
43+
f"Output {i} SNR",
44+
]
45+
)
46+
47+
1448
# Operators that are excluded from the counts returned by count_ops. These are used to
1549
# exclude operations that are not logically relevant or delegatable to backends.
1650
OP_COUNT_IGNORED_OPS = {
@@ -167,11 +201,15 @@ def is_delegated(self):
167201
)
168202

169203

204+
@dataclass
170205
class TestSessionState:
171-
test_case_summaries: list[TestCaseSummary]
206+
# True if the CSV header has been written to report_path.
207+
has_written_report_header: bool = False
172208

173-
def __init__(self):
174-
self.test_case_summaries = []
209+
# The file path to write the detail report to, if enabled.
210+
report_path: str | None = None
211+
212+
test_case_summaries: list[TestCaseSummary] = field(default_factory=list)
175213

176214

177215
@dataclass
@@ -249,11 +287,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
249287
)
250288

251289

252-
def begin_test_session():
290+
def begin_test_session(report_path: str | None):
253291
global _active_session
254292

255293
assert _active_session is None, "A test session is already active."
256-
_active_session = TestSessionState()
294+
_active_session = TestSessionState(report_path=report_path)
257295

258296

259297
def log_test_summary(summary: TestCaseSummary):
@@ -262,6 +300,15 @@ def log_test_summary(summary: TestCaseSummary):
262300
if _active_session is not None:
263301
_active_session.test_case_summaries.append(summary)
264302

303+
if _active_session.report_path is not None:
304+
file_mode = "a" if _active_session.has_written_report_header else "w"
305+
with open(_active_session.report_path, file_mode) as f:
306+
if not _active_session.has_written_report_header:
307+
write_csv_header(f)
308+
_active_session.has_written_report_header = True
309+
310+
write_csv_row(summary, f)
311+
265312

266313
def complete_test_session() -> RunSummary:
267314
global _active_session
@@ -280,6 +327,13 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
280327
return sum(counter.values()) if counter is not None else None
281328

282329

330+
def _serialize_params(params: dict[str, Any] | None) -> str:
331+
if params is not None:
332+
return str(dict(sorted(params.items())))
333+
else:
334+
return ""
335+
336+
283337
def _serialize_op_counts(counter: Counter | None) -> str:
284338
"""
285339
A utility function to serialize op counts to a string, for the purpose of including
@@ -291,91 +345,49 @@ def _serialize_op_counts(counter: Counter | None) -> str:
291345
return ""
292346

293347

294-
def generate_csv_report(summary: RunSummary, output: TextIO):
295-
"""Write a run summary report to a file in CSV format."""
296-
297-
field_names = [
298-
"Test ID",
299-
"Test Case",
300-
"Flow",
301-
"Result",
302-
"Result Detail",
303-
"Delegated",
304-
"Quantize Time (s)",
305-
"Lower Time (s)",
306-
]
307-
308-
# Tests can have custom parameters. We'll want to report them here, so we need
309-
# a list of all unique parameter names.
310-
param_names = reduce(
311-
lambda a, b: a.union(b),
312-
(
313-
set(s.params.keys())
314-
for s in summary.test_case_summaries
315-
if s.params is not None
316-
),
317-
set(),
318-
)
319-
field_names += (s.capitalize() for s in param_names)
320-
321-
# Add tensor error statistic field names for each output index.
322-
max_outputs = max(
323-
len(s.tensor_error_statistics) for s in summary.test_case_summaries
324-
)
325-
for i in range(max_outputs):
326-
field_names.extend(
327-
[
328-
f"Output {i} Error Max",
329-
f"Output {i} Error MAE",
330-
f"Output {i} SNR",
331-
]
332-
)
333-
field_names.extend(
334-
[
335-
"Delegated Nodes",
336-
"Undelegated Nodes",
337-
"Delegated Ops",
338-
"Undelegated Ops",
339-
"PTE Size (Kb)",
340-
]
341-
)
342-
343-
writer = csv.DictWriter(output, field_names)
348+
def write_csv_header(output: TextIO):
349+
writer = csv.DictWriter(output, CSV_FIELD_NAMES)
344350
writer.writeheader()
345351

346-
for record in summary.test_case_summaries:
347-
row = {
348-
"Test ID": record.name,
349-
"Test Case": record.base_name,
350-
"Flow": record.flow,
351-
"Result": record.result.to_short_str(),
352-
"Result Detail": record.result.to_detail_str(),
353-
"Delegated": "True" if record.is_delegated() else "False",
354-
"Quantize Time (s)": (
355-
f"{record.quantize_time.total_seconds():.3f}"
356-
if record.quantize_time
357-
else None
358-
),
359-
"Lower Time (s)": (
360-
f"{record.lower_time.total_seconds():.3f}"
361-
if record.lower_time
362-
else None
363-
),
364-
}
365-
if record.params is not None:
366-
row.update({k.capitalize(): v for k, v in record.params.items()})
367-
368-
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
369-
row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
370-
row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
371-
row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
372-
373-
row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
374-
row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
375-
row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
376-
row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
377-
row["PTE Size (Kb)"] = (
378-
f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
379-
)
380352

381-
writer.writerow(row)
353+
def write_csv_row(record: TestCaseSummary, output: TextIO):
354+
writer = csv.DictWriter(output, CSV_FIELD_NAMES)
355+
356+
row = {
357+
"Test ID": record.name,
358+
"Test Case": record.base_name,
359+
"Flow": record.flow,
360+
"Params": _serialize_params(record.params),
361+
"Result": record.result.to_short_str(),
362+
"Result Detail": record.result.to_detail_str(),
363+
"Delegated": "True" if record.is_delegated() else "False",
364+
"Quantize Time (s)": (
365+
f"{record.quantize_time.total_seconds():.3f}"
366+
if record.quantize_time
367+
else None
368+
),
369+
"Lower Time (s)": (
370+
f"{record.lower_time.total_seconds():.3f}" if record.lower_time else None
371+
),
372+
}
373+
374+
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
375+
if output_idx >= MAX_LOGGED_MODEL_OUTPUTS:
376+
print(
377+
f"Model output stats are truncated as model has more than {MAX_LOGGED_MODEL_OUTPUTS} outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
378+
)
379+
break
380+
381+
row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
382+
row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
383+
row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
384+
385+
row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
386+
row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
387+
row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
388+
row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
389+
row["PTE Size (Kb)"] = (
390+
f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
391+
)
392+
393+
writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
begin_test_session,
2626
complete_test_session,
2727
count_ops,
28-
generate_csv_report,
2928
RunSummary,
3029
TestCaseSummary,
3130
TestResult,
@@ -248,7 +247,7 @@ def build_test_filter(args: argparse.Namespace) -> TestFilter:
248247
def runner_main():
249248
args = parse_args()
250249

251-
begin_test_session()
250+
begin_test_session(args.report)
252251

253252
if len(args.suite) > 1:
254253
raise NotImplementedError("TODO Support multiple suites.")
@@ -263,11 +262,6 @@ def runner_main():
263262
summary = complete_test_session()
264263
print_summary(summary)
265264

266-
if args.report is not None:
267-
with open(args.report, "w") as f:
268-
print(f"Writing CSV report to {args.report}.")
269-
generate_csv_report(summary, f)
270-
271265

272266
if __name__ == "__main__":
273267
runner_main()

backends/test/suite/tests/test_reporting.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99

1010
from ..reporting import (
1111
count_ops,
12-
generate_csv_report,
1312
RunSummary,
1413
TestCaseSummary,
1514
TestResult,
1615
TestSessionState,
16+
write_csv_header,
17+
write_csv_row,
1718
)
1819

1920
# Test data for simulated test results.
@@ -69,7 +70,9 @@ def test_csv_report_simple(self):
6970
run_summary = RunSummary.from_session(session_state)
7071

7172
strio = StringIO()
72-
generate_csv_report(run_summary, strio)
73+
write_csv_header(strio)
74+
for case_summary in run_summary.test_case_summaries:
75+
write_csv_row(case_summary, strio)
7376

7477
# Attempt to deserialize and validate the CSV report.
7578
report = DictReader(StringIO(strio.getvalue()))
@@ -81,32 +84,28 @@ def test_csv_report_simple(self):
8184
self.assertEqual(records[0]["Test Case"], "test1")
8285
self.assertEqual(records[0]["Flow"], "flow1")
8386
self.assertEqual(records[0]["Result"], "Pass")
84-
self.assertEqual(records[0]["Dtype"], "")
85-
self.assertEqual(records[0]["Use_dynamic_shapes"], "")
87+
self.assertEqual(records[0]["Params"], "")
8688

8789
# Validate second record: test1, backend2, LOWER_FAIL
8890
self.assertEqual(records[1]["Test ID"], "test1_backend2_flow1")
8991
self.assertEqual(records[1]["Test Case"], "test1")
9092
self.assertEqual(records[1]["Flow"], "flow1")
9193
self.assertEqual(records[1]["Result"], "Fail")
92-
self.assertEqual(records[1]["Dtype"], "")
93-
self.assertEqual(records[1]["Use_dynamic_shapes"], "")
94+
self.assertEqual(records[1]["Params"], "")
9495

9596
# Validate third record: test2, backend1, SUCCESS_UNDELEGATED with dtype param
9697
self.assertEqual(records[2]["Test ID"], "test2_backend1_flow1")
9798
self.assertEqual(records[2]["Test Case"], "test2")
9899
self.assertEqual(records[2]["Flow"], "flow1")
99100
self.assertEqual(records[2]["Result"], "Pass")
100-
self.assertEqual(records[2]["Dtype"], str(torch.float32))
101-
self.assertEqual(records[2]["Use_dynamic_shapes"], "")
101+
self.assertEqual(records[2]["Params"], str({"dtype": torch.float32}))
102102

103103
# Validate fourth record: test2, backend2, EXPORT_FAIL with use_dynamic_shapes param
104104
self.assertEqual(records[3]["Test ID"], "test2_backend2_flow1")
105105
self.assertEqual(records[3]["Test Case"], "test2")
106106
self.assertEqual(records[3]["Flow"], "flow1")
107107
self.assertEqual(records[3]["Result"], "Skip")
108-
self.assertEqual(records[3]["Dtype"], "")
109-
self.assertEqual(records[3]["Use_dynamic_shapes"], "True")
108+
self.assertEqual(records[3]["Params"], str({"use_dynamic_shapes": True}))
110109

111110
def test_count_ops(self):
112111
"""

0 commit comments

Comments
 (0)