Skip to content

Commit b32eca7

Browse files
committed
[Backend Tester] Write report progressively
ghstack-source-id: 3cf5663 ghstack-comment-id: 3177579936 Pull-Request: #13308
1 parent 6279cbd commit b32eca7

File tree

3 files changed

+113
-108
lines changed

3 files changed

+113
-108
lines changed

backends/test/suite/reporting.py

Lines changed: 103 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import csv
22

33
from collections import Counter
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from datetime import timedelta
66
from enum import IntEnum
77
from functools import reduce
@@ -11,6 +11,40 @@
1111
from torch.export import ExportedProgram
1212

1313

14+
# The maximum number of model output tensors to log statistics for. Most model tests will
15+
# only have one output, but some may return more than one tensor. This upper bound is needed
16+
# upfront since the file is written progressively. Any outputs beyond these will not have stats logged.
17+
MAX_LOGGED_MODEL_OUTPUTS = 2
18+
19+
20+
# Field names for the CSV report.
21+
CSV_FIELD_NAMES = [
22+
"Test ID",
23+
"Test Case",
24+
"Flow",
25+
"Params",
26+
"Result",
27+
"Result Detail",
28+
"Delegated",
29+
"Quantize Time (s)",
30+
"Lower Time (s)",
31+
"Delegated Nodes",
32+
"Undelegated Nodes",
33+
"Delegated Ops",
34+
"Undelegated Ops",
35+
"PTE Size (Kb)",
36+
]
37+
38+
for i in range(MAX_LOGGED_MODEL_OUTPUTS):
39+
CSV_FIELD_NAMES.extend(
40+
[
41+
f"Output {i} Error Max",
42+
f"Output {i} Error MAE",
43+
f"Output {i} SNR",
44+
]
45+
)
46+
47+
1448
# Operators that are excluded from the counts returned by count_ops. These are used to
1549
# exclude operations that are not logically relevant or delegatable to backends.
1650
OP_COUNT_IGNORED_OPS = {
@@ -167,11 +201,15 @@ def is_delegated(self):
167201
)
168202

169203

204+
@dataclass
170205
class TestSessionState:
171-
test_case_summaries: list[TestCaseSummary]
206+
# True if the CSV header has been written to report_path.
207+
has_written_report_header: bool = False
172208

173-
def __init__(self):
174-
self.test_case_summaries = []
209+
# The file path to write the detail report to, if enabled.
210+
report_path: str | None = None
211+
212+
test_case_summaries: list[TestCaseSummary] = field(default_factory=list)
175213

176214

177215
@dataclass
@@ -249,11 +287,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
249287
)
250288

251289

252-
def begin_test_session():
290+
def begin_test_session(report_path: str | None):
253291
global _active_session
254292

255293
assert _active_session is None, "A test session is already active."
256-
_active_session = TestSessionState()
294+
_active_session = TestSessionState(report_path=report_path)
257295

258296

259297
def log_test_summary(summary: TestCaseSummary):
@@ -262,6 +300,15 @@ def log_test_summary(summary: TestCaseSummary):
262300
if _active_session is not None:
263301
_active_session.test_case_summaries.append(summary)
264302

303+
if _active_session.report_path is not None:
304+
file_mode = "a" if _active_session.has_written_report_header else "w"
305+
with open(_active_session.report_path, file_mode) as f:
306+
if not _active_session.has_written_report_header:
307+
write_csv_header(f)
308+
_active_session.has_written_report_header = True
309+
310+
write_csv_row(summary, f)
311+
265312

266313
def complete_test_session() -> RunSummary:
267314
global _active_session
@@ -280,6 +327,13 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
280327
return sum(counter.values()) if counter is not None else None
281328

282329

330+
def _serialize_params(params: dict[str, Any] | None) -> str:
331+
if params is not None:
332+
return str(dict(sorted(params.items())))
333+
else:
334+
return ""
335+
336+
283337
def _serialize_op_counts(counter: Counter | None) -> str:
284338
"""
285339
A utility function to serialize op counts to a string, for the purpose of including
@@ -291,91 +345,49 @@ def _serialize_op_counts(counter: Counter | None) -> str:
291345
return ""
292346

293347

294-
def generate_csv_report(summary: RunSummary, output: TextIO):
295-
"""Write a run summary report to a file in CSV format."""
296-
297-
field_names = [
298-
"Test ID",
299-
"Test Case",
300-
"Flow",
301-
"Result",
302-
"Result Detail",
303-
"Delegated",
304-
"Quantize Time (s)",
305-
"Lower Time (s)",
306-
]
307-
308-
# Tests can have custom parameters. We'll want to report them here, so we need
309-
# a list of all unique parameter names.
310-
param_names = reduce(
311-
lambda a, b: a.union(b),
312-
(
313-
set(s.params.keys())
314-
for s in summary.test_case_summaries
315-
if s.params is not None
316-
),
317-
set(),
318-
)
319-
field_names += (s.capitalize() for s in param_names)
320-
321-
# Add tensor error statistic field names for each output index.
322-
max_outputs = max(
323-
len(s.tensor_error_statistics) for s in summary.test_case_summaries
324-
)
325-
for i in range(max_outputs):
326-
field_names.extend(
327-
[
328-
f"Output {i} Error Max",
329-
f"Output {i} Error MAE",
330-
f"Output {i} SNR",
331-
]
332-
)
333-
field_names.extend(
334-
[
335-
"Delegated Nodes",
336-
"Undelegated Nodes",
337-
"Delegated Ops",
338-
"Undelegated Ops",
339-
"PTE Size (Kb)",
340-
]
341-
)
342-
343-
writer = csv.DictWriter(output, field_names)
348+
def write_csv_header(output: TextIO):
349+
writer = csv.DictWriter(output, CSV_FIELD_NAMES)
344350
writer.writeheader()
345351

346-
for record in summary.test_case_summaries:
347-
row = {
348-
"Test ID": record.name,
349-
"Test Case": record.base_name,
350-
"Flow": record.flow,
351-
"Result": record.result.to_short_str(),
352-
"Result Detail": record.result.to_detail_str(),
353-
"Delegated": "True" if record.is_delegated() else "False",
354-
"Quantize Time (s)": (
355-
f"{record.quantize_time.total_seconds():.3f}"
356-
if record.quantize_time
357-
else None
358-
),
359-
"Lower Time (s)": (
360-
f"{record.lower_time.total_seconds():.3f}"
361-
if record.lower_time
362-
else None
363-
),
364-
}
365-
if record.params is not None:
366-
row.update({k.capitalize(): v for k, v in record.params.items()})
367-
368-
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
369-
row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
370-
row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
371-
row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
372-
373-
row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
374-
row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
375-
row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
376-
row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
377-
row["PTE Size (Kb)"] = (
378-
f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
379-
)
380352

381-
writer.writerow(row)
353+
def write_csv_row(record: TestCaseSummary, output: TextIO):
354+
writer = csv.DictWriter(output, CSV_FIELD_NAMES)
355+
356+
row = {
357+
"Test ID": record.name,
358+
"Test Case": record.base_name,
359+
"Flow": record.flow,
360+
"Params": _serialize_params(record.params),
361+
"Result": record.result.to_short_str(),
362+
"Result Detail": record.result.to_detail_str(),
363+
"Delegated": "True" if record.is_delegated() else "False",
364+
"Quantize Time (s)": (
365+
f"{record.quantize_time.total_seconds():.3f}"
366+
if record.quantize_time
367+
else None
368+
),
369+
"Lower Time (s)": (
370+
f"{record.lower_time.total_seconds():.3f}" if record.lower_time else None
371+
),
372+
}
373+
374+
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
375+
if output_idx >= MAX_LOGGED_MODEL_OUTPUTS:
376+
print(
377+
f"Model output stats are truncated as model has more than {MAX_LOGGED_MODEL_OUTPUTS} outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
378+
)
379+
break
380+
381+
row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
382+
row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
383+
row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
384+
385+
row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
386+
row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
387+
row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
388+
row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
389+
row["PTE Size (Kb)"] = (
390+
f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
391+
)
392+
393+
writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
begin_test_session,
2626
complete_test_session,
2727
count_ops,
28-
generate_csv_report,
2928
RunSummary,
3029
TestCaseSummary,
3130
TestResult,
@@ -248,7 +247,7 @@ def build_test_filter(args: argparse.Namespace) -> TestFilter:
248247
def runner_main():
249248
args = parse_args()
250249

251-
begin_test_session()
250+
begin_test_session(args.report)
252251

253252
if len(args.suite) > 1:
254253
raise NotImplementedError("TODO Support multiple suites.")
@@ -263,11 +262,6 @@ def runner_main():
263262
summary = complete_test_session()
264263
print_summary(summary)
265264

266-
if args.report is not None:
267-
with open(args.report, "w") as f:
268-
print(f"Writing CSV report to {args.report}.")
269-
generate_csv_report(summary, f)
270-
271265

272266
if __name__ == "__main__":
273267
runner_main()

backends/test/suite/tests/test_reporting.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99

1010
from ..reporting import (
1111
count_ops,
12-
generate_csv_report,
1312
RunSummary,
1413
TestCaseSummary,
1514
TestResult,
1615
TestSessionState,
16+
write_csv_header,
17+
write_csv_row,
1718
)
1819

1920
# Test data for simulated test results.
@@ -69,7 +70,9 @@ def test_csv_report_simple(self):
6970
run_summary = RunSummary.from_session(session_state)
7071

7172
strio = StringIO()
72-
generate_csv_report(run_summary, strio)
73+
write_csv_header(strio)
74+
for case_summary in run_summary.test_case_summaries:
75+
write_csv_row(case_summary, strio)
7376

7477
# Attempt to deserialize and validate the CSV report.
7578
report = DictReader(StringIO(strio.getvalue()))
@@ -81,32 +84,28 @@ def test_csv_report_simple(self):
8184
self.assertEqual(records[0]["Test Case"], "test1")
8285
self.assertEqual(records[0]["Flow"], "flow1")
8386
self.assertEqual(records[0]["Result"], "Pass")
84-
self.assertEqual(records[0]["Dtype"], "")
85-
self.assertEqual(records[0]["Use_dynamic_shapes"], "")
87+
self.assertEqual(records[0]["Params"], "")
8688

8789
# Validate second record: test1, backend2, LOWER_FAIL
8890
self.assertEqual(records[1]["Test ID"], "test1_backend2_flow1")
8991
self.assertEqual(records[1]["Test Case"], "test1")
9092
self.assertEqual(records[1]["Flow"], "flow1")
9193
self.assertEqual(records[1]["Result"], "Fail")
92-
self.assertEqual(records[1]["Dtype"], "")
93-
self.assertEqual(records[1]["Use_dynamic_shapes"], "")
94+
self.assertEqual(records[1]["Params"], "")
9495

9596
# Validate third record: test2, backend1, SUCCESS_UNDELEGATED with dtype param
9697
self.assertEqual(records[2]["Test ID"], "test2_backend1_flow1")
9798
self.assertEqual(records[2]["Test Case"], "test2")
9899
self.assertEqual(records[2]["Flow"], "flow1")
99100
self.assertEqual(records[2]["Result"], "Pass")
100-
self.assertEqual(records[2]["Dtype"], str(torch.float32))
101-
self.assertEqual(records[2]["Use_dynamic_shapes"], "")
101+
self.assertEqual(records[2]["Params"], str({"dtype": torch.float32}))
102102

103103
# Validate fourth record: test2, backend2, EXPORT_FAIL with use_dynamic_shapes param
104104
self.assertEqual(records[3]["Test ID"], "test2_backend2_flow1")
105105
self.assertEqual(records[3]["Test Case"], "test2")
106106
self.assertEqual(records[3]["Flow"], "flow1")
107107
self.assertEqual(records[3]["Result"], "Skip")
108-
self.assertEqual(records[3]["Dtype"], "")
109-
self.assertEqual(records[3]["Use_dynamic_shapes"], "True")
108+
self.assertEqual(records[3]["Params"], str({"use_dynamic_shapes": True}))
110109

111110
def test_count_ops(self):
112111
"""

0 commit comments

Comments
 (0)