Commit cfbc408 (merge of 2 parents: 375b076 + 285e058)

Update

[ghstack-poisoned]

3 files changed: +97 -61 lines

backends/test/suite/reporting.py (+81, -52)
@@ -1,7 +1,7 @@
 import csv
 
 from collections import Counter
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from datetime import timedelta
 from enum import IntEnum
 from functools import reduce
@@ -205,15 +205,11 @@ def is_delegated(self):
     )
 
 
-@dataclass
 class TestSessionState:
-    # True if the CSV header has been written to report_path.
-    has_written_report_header: bool = False
-
-    # The file path to write the detail report to, if enabled.
-    report_path: str | None = None
+    test_case_summaries: list[TestCaseSummary]
 
-    test_case_summaries: list[TestCaseSummary] = field(default_factory=list)
+    def __init__(self):
+        self.test_case_summaries = []
 
 
 @dataclass
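A side note on the TestSessionState hunk: with the @dataclass decorator dropped, the field(default_factory=list) workaround for mutable defaults is no longer needed; a plain __init__ builds a fresh list per instance. A minimal sketch of the two patterns, using hypothetical class names:

    from dataclasses import dataclass, field


    @dataclass
    class WithFactory:
        # @dataclass rejects a bare mutable default (items: list = []),
        # since one shared list would leak state across instances.
        items: list[str] = field(default_factory=list)


    class WithInit:
        def __init__(self):
            # A plain __init__ creates a fresh list for each instance.
            self.items: list[str] = []


    a, b = WithInit(), WithInit()
    a.items.append("x")
    assert b.items == []  # no state shared between session objects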
@@ -291,11 +287,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
     )
 
 
-def begin_test_session(report_path: str | None):
+def begin_test_session():
     global _active_session
 
     assert _active_session is None, "A test session is already active."
-    _active_session = TestSessionState(report_path=report_path)
+    _active_session = TestSessionState()
 
 
 def log_test_summary(summary: TestCaseSummary):
@@ -304,15 +300,6 @@ def log_test_summary(summary: TestCaseSummary):
     if _active_session is not None:
         _active_session.test_case_summaries.append(summary)
 
-        if _active_session.report_path is not None:
-            file_mode = "a" if _active_session.has_written_report_header else "w"
-            with open(_active_session.report_path, file_mode) as f:
-                if not _active_session.has_written_report_header:
-                    write_csv_header(f)
-                    _active_session.has_written_report_header = True
-
-                write_csv_row(summary, f)
-
 
 def complete_test_session() -> RunSummary:
     global _active_session
@@ -331,13 +318,6 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
     return sum(counter.values()) if counter is not None else None
 
 
-def _serialize_params(params: dict[str, Any] | None) -> str:
-    if params is not None:
-        return str(dict(sorted(params.items())))
-    else:
-        return ""
-
-
 def _serialize_op_counts(counter: Counter | None) -> str:
     """
     A utility function to serialize op counts to a string, for the purpose of including
@@ -349,10 +329,19 @@ def _serialize_op_counts(counter: Counter | None) -> str:
     return ""
 
 
-def write_csv_header(output: TextIO):
-    writer = csv.DictWriter(output, CSV_FIELD_NAMES)
-    writer.writeheader()
+def generate_csv_report(summary: RunSummary, output: TextIO):
+    """Write a run summary report to a file in CSV format."""
 
+    field_names = [
+        "Test ID",
+        "Test Case",
+        "Flow",
+        "Result",
+        "Result Detail",
+        "Delegated",
+        "Quantize Time (s)",
+        "Lower Time (s)",
+    ]
 
 def write_csv_row(record: TestCaseSummary, output: TextIO):
     writer = csv.DictWriter(output, CSV_FIELD_NAMES)
@@ -371,28 +360,68 @@ def write_csv_row(record: TestCaseSummary, output: TextIO):
             if record.quantize_time
             else None
         ),
-        "Lower Time (s)": (
-            f"{record.lower_time.total_seconds():.3f}" if record.lower_time else None
-        ),
-    }
-
-    for output_idx, error_stats in enumerate(record.tensor_error_statistics):
-        if output_idx >= MAX_LOGGED_MODEL_OUTPUTS:
-            print(
-                f"Model output stats are truncated as model has more than {MAX_LOGGED_MODEL_OUTPUTS} outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
-            )
-            break
-
-        row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
-        row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
-        row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
-
-    row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
-    row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
-    row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
-    row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
-    row["PTE Size (Kb)"] = (
-        f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
+        set(),
+    )
+    field_names += (s.capitalize() for s in param_names)
+
+    # Add tensor error statistic field names for each output index.
+    max_outputs = max(
+        len(s.tensor_error_statistics) for s in summary.test_case_summaries
     )
+    for i in range(max_outputs):
+        field_names.extend(
+            [
+                f"Output {i} Error Max",
+                f"Output {i} Error MAE",
+                f"Output {i} SNR",
+            ]
+        )
+    field_names.extend(
+        [
+            "Delegated Nodes",
+            "Undelegated Nodes",
+            "Delegated Ops",
+            "Undelegated Ops",
+            "PTE Size (Kb)",
+        ]
+    )
+
+    writer = csv.DictWriter(output, field_names)
+    writer.writeheader()
+
+    for record in summary.test_case_summaries:
+        row = {
+            "Test ID": record.name,
+            "Test Case": record.base_name,
+            "Flow": record.flow,
+            "Result": record.result.to_short_str(),
+            "Result Detail": record.result.to_detail_str(),
+            "Delegated": "True" if record.is_delegated() else "False",
+            "Quantize Time (s)": (
+                f"{record.quantize_time.total_seconds():.3f}"
+                if record.quantize_time
+                else None
+            ),
+            "Lower Time (s)": (
+                f"{record.lower_time.total_seconds():.3f}"
+                if record.lower_time
+                else None
+            ),
+        }
+        if record.params is not None:
+            row.update({k.capitalize(): v for k, v in record.params.items()})
+
+        for output_idx, error_stats in enumerate(record.tensor_error_statistics):
+            row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
+            row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
+            row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
+
+        row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
+        row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
+        row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
+        row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
+        row["PTE Size (Kb)"] = (
+            f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
+        )
 
-    writer.writerow(row)
+        writer.writerow(row)
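Nothing is written to disk while tests run any more; generate_csv_report makes a single pass over the finished RunSummary. Because the writer now sees every record before emitting the header, the column set can adapt to the whole run: one column per distinct (capitalized) param name, and exactly as many Output-N stat columns as the widest test needs, replacing the old MAX_LOGGED_MODEL_OUTPUTS truncation. A minimal usage sketch; the executorch.backends.test.suite.reporting import path is assumed from the repo layout:

    from io import StringIO

    # Import path assumed from the repo layout.
    from executorch.backends.test.suite.reporting import (
        begin_test_session,
        complete_test_session,
        generate_csv_report,
    )

    begin_test_session()  # note: no report_path argument any more
    # ... run tests; each one logs a TestCaseSummary via log_test_summary() ...
    summary = complete_test_session()

    buf = StringIO()
    generate_csv_report(summary, buf)  # header plus one row per test case
    print(buf.getvalue())

The trade-off is that a run that crashes midway no longer leaves a partial CSV behind, as the old incremental write_csv_row path did.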

backends/test/suite/runner.py (+6, -0)
@@ -26,6 +26,7 @@
     begin_test_session,
     complete_test_session,
     count_ops,
+    generate_csv_report,
     RunSummary,
     TestCaseSummary,
     TestResult,
@@ -269,6 +270,11 @@ def runner_main():
     summary = complete_test_session()
     print_summary(summary)
 
+    if args.report is not None:
+        with open(args.report, "w") as f:
+            print(f"Writing CSV report to {args.report}.")
+            generate_csv_report(summary, f)
+
 
 if __name__ == "__main__":
     runner_main()
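The runner writes the report exactly once, after the session completes, and opens the file in "w" mode, so each run rewrites the report from scratch rather than appending. A self-contained sketch of the wiring; the --report flag name is an assumption, since only args.report appears in the diff:

    import argparse

    # Import path assumed from the repo layout.
    from executorch.backends.test.suite.reporting import (
        complete_test_session,
        generate_csv_report,
    )

    parser = argparse.ArgumentParser()
    # "--report" is an assumed flag name; the diff only shows args.report.
    parser.add_argument("--report", default=None, help="Path for the CSV report.")
    args = parser.parse_args()

    summary = complete_test_session()  # assumes begin_test_session() ran earlier

    if args.report is not None:
        # Mode "w": the report is rewritten from scratch on every run.
        with open(args.report, "w") as f:
            print(f"Writing CSV report to {args.report}.")
            generate_csv_report(summary, f)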

backends/test/suite/tests/test_reporting.py (+10, -9)
@@ -9,12 +9,11 @@
 
 from ..reporting import (
     count_ops,
+    generate_csv_report,
     RunSummary,
     TestCaseSummary,
     TestResult,
     TestSessionState,
-    write_csv_header,
-    write_csv_row,
 )
 
 # Test data for simulated test results.
@@ -70,9 +69,7 @@ def test_csv_report_simple(self):
         run_summary = RunSummary.from_session(session_state)
 
         strio = StringIO()
-        write_csv_header(strio)
-        for case_summary in run_summary.test_case_summaries:
-            write_csv_row(case_summary, strio)
+        generate_csv_report(run_summary, strio)
 
         # Attempt to deserialize and validate the CSV report.
         report = DictReader(StringIO(strio.getvalue()))
@@ -84,28 +81,32 @@ def test_csv_report_simple(self):
         self.assertEqual(records[0]["Test Case"], "test1")
         self.assertEqual(records[0]["Flow"], "flow1")
         self.assertEqual(records[0]["Result"], "Pass")
-        self.assertEqual(records[0]["Params"], "")
+        self.assertEqual(records[0]["Dtype"], "")
+        self.assertEqual(records[0]["Use_dynamic_shapes"], "")
 
         # Validate second record: test1, backend2, LOWER_FAIL
         self.assertEqual(records[1]["Test ID"], "test1_backend2_flow1")
         self.assertEqual(records[1]["Test Case"], "test1")
         self.assertEqual(records[1]["Flow"], "flow1")
         self.assertEqual(records[1]["Result"], "Fail")
-        self.assertEqual(records[1]["Params"], "")
+        self.assertEqual(records[1]["Dtype"], "")
+        self.assertEqual(records[1]["Use_dynamic_shapes"], "")
 
         # Validate third record: test2, backend1, SUCCESS_UNDELEGATED with dtype param
         self.assertEqual(records[2]["Test ID"], "test2_backend1_flow1")
         self.assertEqual(records[2]["Test Case"], "test2")
         self.assertEqual(records[2]["Flow"], "flow1")
         self.assertEqual(records[2]["Result"], "Pass")
-        self.assertEqual(records[2]["Params"], str({"dtype": torch.float32}))
+        self.assertEqual(records[2]["Dtype"], str(torch.float32))
+        self.assertEqual(records[2]["Use_dynamic_shapes"], "")
 
         # Validate fourth record: test2, backend2, EXPORT_FAIL with use_dynamic_shapes param
         self.assertEqual(records[3]["Test ID"], "test2_backend2_flow1")
         self.assertEqual(records[3]["Test Case"], "test2")
         self.assertEqual(records[3]["Flow"], "flow1")
         self.assertEqual(records[3]["Result"], "Skip")
-        self.assertEqual(records[3]["Params"], str({"use_dynamic_shapes": True}))
+        self.assertEqual(records[3]["Dtype"], "")
+        self.assertEqual(records[3]["Use_dynamic_shapes"], "True")
 
     def test_count_ops(self):
         """
