Commit fc1c2a7

Update
[ghstack-poisoned]
2 parents: 2c4488f + 78086b4

4 files changed (+109, -114 lines)

backends/test/suite/flow.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import logging
 
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from typing import Callable
 
 from executorch.backends.test.harness import Tester

backends/test/suite/reporting.py

Lines changed: 91 additions & 103 deletions
@@ -1,7 +1,7 @@
 import csv
 
 from collections import Counter
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from datetime import timedelta
 from enum import IntEnum
 from functools import reduce
@@ -11,40 +11,6 @@
 from torch.export import ExportedProgram
 
 
-# The maximum number of model output tensors to log statistics for. Most model tests will
-# only have one output, but some may return more than one tensor. This upper bound is needed
-# upfront since the file is written progressively. Any outputs beyond these will not have stats logged.
-MAX_LOGGED_MODEL_OUTPUTS = 2
-
-
-# Field names for the CSV report.
-CSV_FIELD_NAMES = [
-    "Test ID",
-    "Test Case",
-    "Flow",
-    "Params",
-    "Result",
-    "Result Detail",
-    "Delegated",
-    "Quantize Time (s)",
-    "Lower Time (s)",
-    "Delegated Nodes",
-    "Undelegated Nodes",
-    "Delegated Ops",
-    "Undelegated Ops",
-    "PTE Size (Kb)",
-]
-
-for i in range(MAX_LOGGED_MODEL_OUTPUTS):
-    CSV_FIELD_NAMES.extend(
-        [
-            f"Output {i} Error Max",
-            f"Output {i} Error MAE",
-            f"Output {i} SNR",
-        ]
-    )
-
-
 # Operators that are excluded from the counts returned by count_ops. These are used to
 # exclude operatations that are not logically relevant or delegatable to backends.
 OP_COUNT_IGNORED_OPS = {
@@ -201,15 +167,11 @@ def is_delegated(self):
         )
 
 
-@dataclass
 class TestSessionState:
-    # True if the CSV header has been written to report__path.
-    has_written_report_header: bool = False
-
-    # The file path to write the detail report to, if enabled.
-    report_path: str | None = None
+    test_case_summaries: list[TestCaseSummary]
 
-    test_case_summaries: list[TestCaseSummary] = field(default_factory=list)
+    def __init__(self):
+        self.test_case_summaries = []
 
 
 @dataclass
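
TestSessionState stops being a dataclass, so its list default no longer needs field(default_factory=...), and the now-unused field imports are dropped from flow.py and reporting.py above. A standalone sketch (not part of the commit) contrasting the two equivalent forms:

from dataclasses import dataclass, field


@dataclass
class WithDataclass:
    # A dataclass needs field(default_factory=list) for a mutable default;
    # a bare `items: list = []` would be rejected by @dataclass.
    items: list = field(default_factory=list)


class WithInit:
    # The plain-class version used by this commit: initialize in __init__.
    def __init__(self):
        self.items = []


assert WithDataclass().items == [] and WithInit().items == []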
@@ -287,11 +249,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
     )
 
 
-def begin_test_session(report_path: str | None):
+def begin_test_session():
     global _active_session
 
     assert _active_session is None, "A test session is already active."
-    _active_session = TestSessionState(report_path=report_path)
+    _active_session = TestSessionState()
 
 
 def log_test_summary(summary: TestCaseSummary):
@@ -300,15 +262,6 @@ def log_test_summary(summary: TestCaseSummary):
     if _active_session is not None:
         _active_session.test_case_summaries.append(summary)
 
-        if _active_session.report_path is not None:
-            file_mode = "a" if _active_session.has_written_report_header else "w"
-            with open(_active_session.report_path, file_mode) as f:
-                if not _active_session.has_written_report_header:
-                    write_csv_header(f)
-                    _active_session.has_written_report_header = True
-
-                write_csv_row(summary, f)
-
 
 def complete_test_session() -> RunSummary:
     global _active_session
@@ -327,13 +280,6 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
     return sum(counter.values()) if counter is not None else None
 
 
-def _serialize_params(params: dict[str, Any] | None) -> str:
-    if params is not None:
-        return str(dict(sorted(params.items())))
-    else:
-        return ""
-
-
 def _serialize_op_counts(counter: Counter | None) -> str:
     """
     A utility function to serialize op counts to a string, for the purpose of including
@@ -345,49 +291,91 @@ def _serialize_op_counts(counter: Counter | None) -> str:
         return ""
 
 
-def write_csv_header(output: TextIO):
-    writer = csv.DictWriter(output, CSV_FIELD_NAMES)
-    writer.writeheader()
-
-
-def write_csv_row(record: TestCaseSummary, output: TextIO):
-    writer = csv.DictWriter(output, CSV_FIELD_NAMES)
-
-    row = {
-        "Test ID": record.name,
-        "Test Case": record.base_name,
-        "Flow": record.flow,
-        "Params": _serialize_params(record.params),
-        "Result": record.result.to_short_str(),
-        "Result Detail": record.result.to_detail_str(),
-        "Delegated": "True" if record.is_delegated() else "False",
-        "Quantize Time (s)": (
-            f"{record.quantize_time.total_seconds():.3f}"
-            if record.quantize_time
-            else None
-        ),
-        "Lower Time (s)": (
-            f"{record.lower_time.total_seconds():.3f}" if record.lower_time else None
+def generate_csv_report(summary: RunSummary, output: TextIO):
+    """Write a run summary report to a file in CSV format."""
+
+    field_names = [
+        "Test ID",
+        "Test Case",
+        "Flow",
+        "Result",
+        "Result Detail",
+        "Delegated",
+        "Quantize Time (s)",
+        "Lower Time (s)",
+    ]
+
+    # Tests can have custom parameters. We'll want to report them here, so we need
+    # a list of all unique parameter names.
+    param_names = reduce(
+        lambda a, b: a.union(b),
+        (
+            set(s.params.keys())
+            for s in summary.test_case_summaries
+            if s.params is not None
         ),
-    }
-
-    for output_idx, error_stats in enumerate(record.tensor_error_statistics):
-        if output_idx >= MAX_LOGGED_MODEL_OUTPUTS:
-            print(
-                f"Model output stats are truncated as model has more than {MAX_LOGGED_MODEL_OUTPUTS} outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
-            )
-            break
-
-        row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
-        row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
-        row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
-
-    row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
-    row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
-    row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
-    row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
-    row["PTE Size (Kb)"] = (
-        f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
+        set(),
+    )
+    field_names += (s.capitalize() for s in param_names)
+
+    # Add tensor error statistic field names for each output index.
+    max_outputs = max(
+        len(s.tensor_error_statistics) for s in summary.test_case_summaries
+    )
+    for i in range(max_outputs):
+        field_names.extend(
+            [
+                f"Output {i} Error Max",
+                f"Output {i} Error MAE",
+                f"Output {i} SNR",
+            ]
+        )
+    field_names.extend(
+        [
+            "Delegated Nodes",
+            "Undelegated Nodes",
+            "Delegated Ops",
+            "Undelegated Ops",
+            "PTE Size (Kb)",
+        ]
     )
 
-    writer.writerow(row)
+    writer = csv.DictWriter(output, field_names)
+    writer.writeheader()
+
+    for record in summary.test_case_summaries:
+        row = {
+            "Test ID": record.name,
+            "Test Case": record.base_name,
+            "Flow": record.flow,
+            "Result": record.result.to_short_str(),
+            "Result Detail": record.result.to_detail_str(),
+            "Delegated": "True" if record.is_delegated() else "False",
+            "Quantize Time (s)": (
+                f"{record.quantize_time.total_seconds():.3f}"
+                if record.quantize_time
+                else None
+            ),
+            "Lower Time (s)": (
+                f"{record.lower_time.total_seconds():.3f}"
+                if record.lower_time
+                else None
+            ),
+        }
+        if record.params is not None:
+            row.update({k.capitalize(): v for k, v in record.params.items()})
+
+        for output_idx, error_stats in enumerate(record.tensor_error_statistics):
+            row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
+            row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
+            row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
+
+        row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
+        row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
+        row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
+        row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
+        row["PTE Size (Kb)"] = (
+            f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
+        )
+
+        writer.writerow(row)
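
The replacement builds the header from the run itself: one column per distinct parameter name (capitalized) and one group of error columns per model output, so the fixed Params blob and the output cap are no longer needed. A standalone sketch (not from the commit) of the parameter-column derivation, using hypothetical param dicts:

from functools import reduce

# Hypothetical per-test param dicts, mirroring the shape of TestCaseSummary.params.
param_dicts = [{"dtype": "float32"}, {"use_dynamic_shapes": True}, None]

param_names = reduce(
    lambda a, b: a.union(b),
    (set(d.keys()) for d in param_dicts if d is not None),
    set(),
)
print(sorted(p.capitalize() for p in param_names))  # ['Dtype', 'Use_dynamic_shapes']

str.capitalize() is what turns use_dynamic_shapes into the Use_dynamic_shapes column asserted in the updated test below.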

backends/test/suite/runner.py

Lines changed: 7 additions & 1 deletion
@@ -25,6 +25,7 @@
     begin_test_session,
     complete_test_session,
     count_ops,
+    generate_csv_report,
     RunSummary,
     TestCaseSummary,
     TestResult,
@@ -247,7 +248,7 @@ def build_test_filter(args: argparse.Namespace) -> TestFilter:
 def runner_main():
     args = parse_args()
 
-    begin_test_session(args.report)
+    begin_test_session()
 
     if len(args.suite) > 1:
         raise NotImplementedError("TODO Support multiple suites.")
@@ -262,6 +263,11 @@ def runner_main():
     summary = complete_test_session()
     print_summary(summary)
 
+    if args.report is not None:
+        with open(args.report, "w") as f:
+            print(f"Writing CSV report to {args.report}.")
+            generate_csv_report(summary, f)
+
 
 if __name__ == "__main__":
     runner_main()
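
Since begin_test_session() no longer takes a report path, the runner now writes the CSV once, after the session completes and the full RunSummary is available. A minimal sketch of the resulting flow (the module path is assumed from the repo layout; not part of the commit):

from executorch.backends.test.suite.reporting import (
    begin_test_session,
    complete_test_session,
    generate_csv_report,
)

begin_test_session()
# ... run the suite; each test calls log_test_summary(...) as it finishes ...
summary = complete_test_session()

with open("report.csv", "w") as f:
    generate_csv_report(summary, f)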

backends/test/suite/tests/test_reporting.py

Lines changed: 10 additions & 9 deletions
@@ -9,12 +9,11 @@
 
 from ..reporting import (
     count_ops,
+    generate_csv_report,
     RunSummary,
     TestCaseSummary,
     TestResult,
     TestSessionState,
-    write_csv_header,
-    write_csv_row,
 )
 
 # Test data for simulated test results.
@@ -70,9 +69,7 @@ def test_csv_report_simple(self):
         run_summary = RunSummary.from_session(session_state)
 
         strio = StringIO()
-        write_csv_header(strio)
-        for case_summary in run_summary.test_case_summaries:
-            write_csv_row(case_summary, strio)
+        generate_csv_report(run_summary, strio)
 
         # Attempt to deserialize and validate the CSV report.
         report = DictReader(StringIO(strio.getvalue()))
@@ -84,28 +81,32 @@ def test_csv_report_simple(self):
         self.assertEqual(records[0]["Test Case"], "test1")
         self.assertEqual(records[0]["Flow"], "flow1")
         self.assertEqual(records[0]["Result"], "Pass")
-        self.assertEqual(records[0]["Params"], "")
+        self.assertEqual(records[0]["Dtype"], "")
+        self.assertEqual(records[0]["Use_dynamic_shapes"], "")
 
         # Validate second record: test1, backend2, LOWER_FAIL
         self.assertEqual(records[1]["Test ID"], "test1_backend2_flow1")
         self.assertEqual(records[1]["Test Case"], "test1")
         self.assertEqual(records[1]["Flow"], "flow1")
         self.assertEqual(records[1]["Result"], "Fail")
-        self.assertEqual(records[1]["Params"], "")
+        self.assertEqual(records[1]["Dtype"], "")
+        self.assertEqual(records[1]["Use_dynamic_shapes"], "")
 
         # Validate third record: test2, backend1, SUCCESS_UNDELEGATED with dtype param
         self.assertEqual(records[2]["Test ID"], "test2_backend1_flow1")
         self.assertEqual(records[2]["Test Case"], "test2")
         self.assertEqual(records[2]["Flow"], "flow1")
         self.assertEqual(records[2]["Result"], "Pass")
-        self.assertEqual(records[2]["Params"], str({"dtype": torch.float32}))
+        self.assertEqual(records[2]["Dtype"], str(torch.float32))
+        self.assertEqual(records[2]["Use_dynamic_shapes"], "")
 
         # Validate fourth record: test2, backend2, EXPORT_FAIL with use_dynamic_shapes param
        self.assertEqual(records[3]["Test ID"], "test2_backend2_flow1")
         self.assertEqual(records[3]["Test Case"], "test2")
         self.assertEqual(records[3]["Flow"], "flow1")
         self.assertEqual(records[3]["Result"], "Skip")
-        self.assertEqual(records[3]["Params"], str({"use_dynamic_shapes": True}))
+        self.assertEqual(records[3]["Dtype"], "")
+        self.assertEqual(records[3]["Use_dynamic_shapes"], "True")
 
     def test_count_ops(self):
         """
