Skip to content

Commit ce6c435

Browse files
committed
[Backend Tester] Write report progressively
ghstack-source-id: 47a94ec ghstack-comment-id: 3177579936 Pull-Request: #13308
1 parent cd94169 commit ce6c435

File tree

3 files changed

+121
-108
lines changed

3 files changed

+121
-108
lines changed

backends/test/suite/reporting.py

Lines changed: 111 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import csv
22

33
from collections import Counter
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from datetime import timedelta
66
from enum import IntEnum
77
from functools import reduce
@@ -11,6 +11,40 @@
1111
from torch.export import ExportedProgram
1212

1313

14+
# The maximum number of model output tensors to log statistics for. Most model tests will
15+
# only have one output, but some may return more than one tensor. This upper bound is needed
16+
# upfront since the file is written progressively. Any outputs beyond these will not have stats logged.
17+
MAX_LOGGED_MODEL_OUTPUTS = 2
18+
19+
20+
# Field names for the CSV report.
21+
CSV_FIELD_NAMES = [
22+
"Test ID",
23+
"Test Case",
24+
"Flow",
25+
"Params",
26+
"Result",
27+
"Result Detail",
28+
"Delegated",
29+
"Quantize Time (s)",
30+
"Lower Time (s)",
31+
"Delegated Nodes",
32+
"Undelegated Nodes",
33+
"Delegated Ops",
34+
"Undelegated Ops",
35+
"PTE Size (Kb)",
36+
]
37+
38+
for i in range(MAX_LOGGED_MODEL_OUTPUTS):
39+
CSV_FIELD_NAMES.extend(
40+
[
41+
f"Output {i} Error Max",
42+
f"Output {i} Error MAE",
43+
f"Output {i} SNR",
44+
]
45+
)
46+
47+
1448
# Operators that are excluded from the counts returned by count_ops. These are used to
1549
# exclude operations that are not logically relevant or delegatable to backends.
1650
OP_COUNT_IGNORED_OPS = {
@@ -57,15 +91,15 @@ def is_non_backend_failure(self):
5791

5892
def is_backend_failure(self):
5993
return not self.is_success() and not self.is_non_backend_failure()
60-
94+
6195
def to_short_str(self):
62-
if self in { TestResult.SUCCESS, TestResult.SUCCESS_UNDELEGATED }:
96+
if self in {TestResult.SUCCESS, TestResult.SUCCESS_UNDELEGATED}:
6397
return "Pass"
6498
elif self == TestResult.SKIPPED:
6599
return "Skip"
66100
else:
67101
return "Fail"
68-
102+
69103
def to_detail_str(self):
70104
if self == TestResult.SUCCESS:
71105
return ""
@@ -160,14 +194,22 @@ class TestCaseSummary:
160194
""" The size of the PTE file in bytes. """
161195

162196
def is_delegated(self):
163-
return any(v > 0 for v in self.delegated_op_counts.values()) if self.delegated_op_counts else False
197+
return (
198+
any(v > 0 for v in self.delegated_op_counts.values())
199+
if self.delegated_op_counts
200+
else False
201+
)
164202

165203

204+
@dataclass
166205
class TestSessionState:
167-
test_case_summaries: list[TestCaseSummary]
206+
# True if the CSV header has been written to report_path.
207+
has_written_report_header: bool = False
168208

169-
def __init__(self):
170-
self.test_case_summaries = []
209+
# The file path to write the detail report to, if enabled.
210+
report_path: str | None = None
211+
212+
test_case_summaries: list[TestCaseSummary] = field(default_factory=list)
171213

172214

173215
@dataclass
@@ -245,11 +287,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
245287
)
246288

247289

248-
def begin_test_session():
290+
def begin_test_session(report_path: str | None):
249291
global _active_session
250292

251293
assert _active_session is None, "A test session is already active."
252-
_active_session = TestSessionState()
294+
_active_session = TestSessionState(report_path=report_path)
253295

254296

255297
def log_test_summary(summary: TestCaseSummary):
@@ -258,6 +300,15 @@ def log_test_summary(summary: TestCaseSummary):
258300
if _active_session is not None:
259301
_active_session.test_case_summaries.append(summary)
260302

303+
if _active_session.report_path is not None:
304+
file_mode = "a" if _active_session.has_written_report_header else "w"
305+
with open(_active_session.report_path, file_mode) as f:
306+
if not _active_session.has_written_report_header:
307+
write_csv_header(f)
308+
_active_session.has_written_report_header = True
309+
310+
write_csv_row(summary, f)
311+
261312

262313
def complete_test_session() -> RunSummary:
263314
global _active_session
@@ -276,6 +327,13 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
276327
return sum(counter.values()) if counter is not None else None
277328

278329

330+
def _serialize_params(params: dict[str, Any] | None) -> str:
331+
if params is not None:
332+
return str(dict(sorted(params.items())))
333+
else:
334+
return ""
335+
336+
279337
def _serialize_op_counts(counter: Counter | None) -> str:
280338
"""
281339
A utility function to serialize op counts to a string, for the purpose of including
@@ -287,87 +345,49 @@ def _serialize_op_counts(counter: Counter | None) -> str:
287345
return ""
288346

289347

290-
def generate_csv_report(summary: RunSummary, output: TextIO):
291-
"""Write a run summary report to a file in CSV format."""
292-
293-
field_names = [
294-
"Test ID",
295-
"Test Case",
296-
"Flow",
297-
"Result",
298-
"Result Detail",
299-
"Delegated",
300-
"Quantize Time (s)",
301-
"Lower Time (s)",
302-
]
303-
304-
# Tests can have custom parameters. We'll want to report them here, so we need
305-
# a list of all unique parameter names.
306-
param_names = reduce(
307-
lambda a, b: a.union(b),
308-
(
309-
set(s.params.keys())
310-
for s in summary.test_case_summaries
311-
if s.params is not None
312-
),
313-
set(),
314-
)
315-
field_names += (s.capitalize() for s in param_names)
316-
317-
# Add tensor error statistic field names for each output index.
318-
max_outputs = max(
319-
len(s.tensor_error_statistics) for s in summary.test_case_summaries
320-
)
321-
for i in range(max_outputs):
322-
field_names.extend(
323-
[
324-
f"Output {i} Error Max",
325-
f"Output {i} Error MAE",
326-
f"Output {i} SNR",
327-
]
328-
)
329-
field_names.extend(
330-
[
331-
"Delegated Nodes",
332-
"Undelegated Nodes",
333-
"Delegated Ops",
334-
"Undelegated Ops",
335-
"PTE Size (Kb)",
336-
]
337-
)
338-
339-
writer = csv.DictWriter(output, field_names)
348+
def write_csv_header(output: TextIO):
349+
writer = csv.DictWriter(output, CSV_FIELD_NAMES)
340350
writer.writeheader()
341351

342-
for record in summary.test_case_summaries:
343-
row = {
344-
"Test ID": record.name,
345-
"Test Case": record.base_name,
346-
"Flow": record.flow,
347-
"Result": record.result.to_short_str(),
348-
"Result Detail": record.result.to_detail_str(),
349-
"Delegated": "True" if record.is_delegated() else "False",
350-
"Quantize Time (s)": (
351-
f"{record.quantize_time.total_seconds():.3f}" if record.quantize_time else None
352-
),
353-
"Lower Time (s)": (
354-
f"{record.lower_time.total_seconds():.3f}" if record.lower_time else None
355-
),
356-
}
357-
if record.params is not None:
358-
row.update({k.capitalize(): v for k, v in record.params.items()})
359-
360-
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
361-
row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
362-
row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
363-
row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
364-
365-
row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
366-
row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
367-
row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
368-
row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
369-
row["PTE Size (Kb)"] = (
370-
f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
371-
)
372352

373-
writer.writerow(row)
353+
def write_csv_row(record: TestCaseSummary, output: TextIO):
354+
writer = csv.DictWriter(output, CSV_FIELD_NAMES)
355+
356+
row = {
357+
"Test ID": record.name,
358+
"Test Case": record.base_name,
359+
"Flow": record.flow,
360+
"Params": _serialize_params(record.params),
361+
"Result": record.result.to_short_str(),
362+
"Result Detail": record.result.to_detail_str(),
363+
"Delegated": "True" if record.is_delegated() else "False",
364+
"Quantize Time (s)": (
365+
f"{record.quantize_time.total_seconds():.3f}"
366+
if record.quantize_time
367+
else None
368+
),
369+
"Lower Time (s)": (
370+
f"{record.lower_time.total_seconds():.3f}" if record.lower_time else None
371+
),
372+
}
373+
374+
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
375+
if output_idx >= MAX_LOGGED_MODEL_OUTPUTS:
376+
print(
377+
f"Model output stats are truncated as model has more than {MAX_LOGGED_MODEL_OUTPUTS} outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
378+
)
379+
break
380+
381+
row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
382+
row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
383+
row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
384+
385+
row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
386+
row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
387+
row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
388+
row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
389+
row["PTE Size (Kb)"] = (
390+
f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
391+
)
392+
393+
writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
begin_test_session,
2626
complete_test_session,
2727
count_ops,
28-
generate_csv_report,
2928
RunSummary,
3029
TestCaseSummary,
3130
TestResult,
@@ -248,7 +247,7 @@ def build_test_filter(args: argparse.Namespace) -> TestFilter:
248247
def runner_main():
249248
args = parse_args()
250249

251-
begin_test_session()
250+
begin_test_session(args.report)
252251

253252
if len(args.suite) > 1:
254253
raise NotImplementedError("TODO Support multiple suites.")
@@ -263,11 +262,6 @@ def runner_main():
263262
summary = complete_test_session()
264263
print_summary(summary)
265264

266-
if args.report is not None:
267-
with open(args.report, "w") as f:
268-
print(f"Writing CSV report to {args.report}.")
269-
generate_csv_report(summary, f)
270-
271265

272266
if __name__ == "__main__":
273267
runner_main()

backends/test/suite/tests/test_reporting.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99

1010
from ..reporting import (
1111
count_ops,
12-
generate_csv_report,
1312
RunSummary,
1413
TestCaseSummary,
1514
TestResult,
1615
TestSessionState,
16+
write_csv_header,
17+
write_csv_row,
1718
)
1819

1920
# Test data for simulated test results.
@@ -69,7 +70,9 @@ def test_csv_report_simple(self):
6970
run_summary = RunSummary.from_session(session_state)
7071

7172
strio = StringIO()
72-
generate_csv_report(run_summary, strio)
73+
write_csv_header(strio)
74+
for case_summary in run_summary.test_case_summaries:
75+
write_csv_row(case_summary, strio)
7376

7477
# Attempt to deserialize and validate the CSV report.
7578
report = DictReader(StringIO(strio.getvalue()))
@@ -81,32 +84,28 @@ def test_csv_report_simple(self):
8184
self.assertEqual(records[0]["Test Case"], "test1")
8285
self.assertEqual(records[0]["Flow"], "flow1")
8386
self.assertEqual(records[0]["Result"], "Pass")
84-
self.assertEqual(records[0]["Dtype"], "")
85-
self.assertEqual(records[0]["Use_dynamic_shapes"], "")
87+
self.assertEqual(records[0]["Params"], "")
8688

8789
# Validate second record: test1, backend2, LOWER_FAIL
8890
self.assertEqual(records[1]["Test ID"], "test1_backend2_flow1")
8991
self.assertEqual(records[1]["Test Case"], "test1")
9092
self.assertEqual(records[1]["Flow"], "flow1")
9193
self.assertEqual(records[1]["Result"], "Fail")
92-
self.assertEqual(records[1]["Dtype"], "")
93-
self.assertEqual(records[1]["Use_dynamic_shapes"], "")
94+
self.assertEqual(records[1]["Params"], "")
9495

9596
# Validate third record: test2, backend1, SUCCESS_UNDELEGATED with dtype param
9697
self.assertEqual(records[2]["Test ID"], "test2_backend1_flow1")
9798
self.assertEqual(records[2]["Test Case"], "test2")
9899
self.assertEqual(records[2]["Flow"], "flow1")
99100
self.assertEqual(records[2]["Result"], "Pass")
100-
self.assertEqual(records[2]["Dtype"], str(torch.float32))
101-
self.assertEqual(records[2]["Use_dynamic_shapes"], "")
101+
self.assertEqual(records[2]["Params"], str({"dtype": torch.float32}))
102102

103103
# Validate fourth record: test2, backend2, EXPORT_FAIL with use_dynamic_shapes param
104104
self.assertEqual(records[3]["Test ID"], "test2_backend2_flow1")
105105
self.assertEqual(records[3]["Test Case"], "test2")
106106
self.assertEqual(records[3]["Flow"], "flow1")
107107
self.assertEqual(records[3]["Result"], "Skip")
108-
self.assertEqual(records[3]["Dtype"], "")
109-
self.assertEqual(records[3]["Use_dynamic_shapes"], "True")
108+
self.assertEqual(records[3]["Params"], str({"use_dynamic_shapes": True}))
110109

111110
def test_count_ops(self):
112111
"""

0 commit comments

Comments
 (0)