Skip to content

Commit 1b9315e

Browse files
committed
Update
[ghstack-poisoned]
2 parents 90ce443 + 94d89c4 commit 1b9315e

File tree

3 files changed

+62
-98
lines changed

3 files changed

+62
-98
lines changed

backends/test/suite/reporting.py

Lines changed: 52 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import csv
22

33
from collections import Counter
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from datetime import timedelta
66
from enum import IntEnum
77
from functools import reduce
@@ -205,11 +205,15 @@ def is_delegated(self):
205205
)
206206

207207

208+
@dataclass
208209
class TestSessionState:
209-
test_case_summaries: list[TestCaseSummary]
210+
# True if the CSV header has been written to report_path.
211+
has_written_report_header: bool = False
212+
213+
# The file path to write the detail report to, if enabled.
214+
report_path: str | None = None
210215

211-
def __init__(self):
212-
self.test_case_summaries = []
216+
test_case_summaries: list[TestCaseSummary] = field(default_factory=list)
213217

214218

215219
@dataclass
@@ -287,11 +291,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
287291
)
288292

289293

290-
def begin_test_session():
294+
def begin_test_session(report_path: str | None):
291295
global _active_session
292296

293297
assert _active_session is None, "A test session is already active."
294-
_active_session = TestSessionState()
298+
_active_session = TestSessionState(report_path=report_path)
295299

296300

297301
def log_test_summary(summary: TestCaseSummary):
@@ -300,6 +304,15 @@ def log_test_summary(summary: TestCaseSummary):
300304
if _active_session is not None:
301305
_active_session.test_case_summaries.append(summary)
302306

307+
if _active_session.report_path is not None:
308+
file_mode = "a" if _active_session.has_written_report_header else "w"
309+
with open(_active_session.report_path, file_mode) as f:
310+
if not _active_session.has_written_report_header:
311+
write_csv_header(f)
312+
_active_session.has_written_report_header = True
313+
314+
write_csv_row(summary, f)
315+
303316

304317
def complete_test_session() -> RunSummary:
305318
global _active_session
@@ -318,6 +331,13 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
318331
return sum(counter.values()) if counter is not None else None
319332

320333

334+
def _serialize_params(params: dict[str, Any] | None) -> str:
335+
if params is not None:
336+
return str(dict(sorted(params.items())))
337+
else:
338+
return ""
339+
340+
321341
def _serialize_op_counts(counter: Counter | None) -> str:
322342
"""
323343
A utility function to serialize op counts to a string, for the purpose of including
@@ -329,19 +349,10 @@ def _serialize_op_counts(counter: Counter | None) -> str:
329349
return ""
330350

331351

332-
def generate_csv_report(summary: RunSummary, output: TextIO):
333-
"""Write a run summary report to a file in CSV format."""
352+
def write_csv_header(output: TextIO):
353+
writer = csv.DictWriter(output, CSV_FIELD_NAMES)
354+
writer.writeheader()
334355

335-
field_names = [
336-
"Test ID",
337-
"Test Case",
338-
"Flow",
339-
"Result",
340-
"Result Detail",
341-
"Delegated",
342-
"Quantize Time (s)",
343-
"Lower Time (s)",
344-
]
345356

346357
def write_csv_row(record: TestCaseSummary, output: TextIO):
347358
writer = csv.DictWriter(output, CSV_FIELD_NAMES)
@@ -360,68 +371,28 @@ def write_csv_row(record: TestCaseSummary, output: TextIO):
360371
if record.quantize_time
361372
else None
362373
),
363-
set(),
364-
)
365-
field_names += (s.capitalize() for s in param_names)
366-
367-
# Add tensor error statistic field names for each output index.
368-
max_outputs = max(
369-
len(s.tensor_error_statistics) for s in summary.test_case_summaries
370-
)
371-
for i in range(max_outputs):
372-
field_names.extend(
373-
[
374-
f"Output {i} Error Max",
375-
f"Output {i} Error MAE",
376-
f"Output {i} SNR",
377-
]
378-
)
379-
field_names.extend(
380-
[
381-
"Delegated Nodes",
382-
"Undelegated Nodes",
383-
"Delegated Ops",
384-
"Undelegated Ops",
385-
"PTE Size (Kb)",
386-
]
374+
"Lower Time (s)": (
375+
f"{record.lower_time.total_seconds():.3f}" if record.lower_time else None
376+
),
377+
}
378+
379+
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
380+
if output_idx >= MAX_LOGGED_MODEL_OUTPUTS:
381+
print(
382+
f"Model output stats are truncated as model has more than {MAX_LOGGED_MODEL_OUTPUTS} outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
383+
)
384+
break
385+
386+
row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
387+
row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
388+
row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
389+
390+
row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
391+
row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
392+
row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
393+
row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
394+
row["PTE Size (Kb)"] = (
395+
f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
387396
)
388397

389-
writer = csv.DictWriter(output, field_names)
390-
writer.writeheader()
391-
392-
for record in summary.test_case_summaries:
393-
row = {
394-
"Test ID": record.name,
395-
"Test Case": record.base_name,
396-
"Flow": record.flow,
397-
"Result": record.result.to_short_str(),
398-
"Result Detail": record.result.to_detail_str(),
399-
"Delegated": "True" if record.is_delegated() else "False",
400-
"Quantize Time (s)": (
401-
f"{record.quantize_time.total_seconds():.3f}"
402-
if record.quantize_time
403-
else None
404-
),
405-
"Lower Time (s)": (
406-
f"{record.lower_time.total_seconds():.3f}"
407-
if record.lower_time
408-
else None
409-
),
410-
}
411-
if record.params is not None:
412-
row.update({k.capitalize(): v for k, v in record.params.items()})
413-
414-
for output_idx, error_stats in enumerate(record.tensor_error_statistics):
415-
row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
416-
row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
417-
row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
418-
419-
row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
420-
row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
421-
row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
422-
row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
423-
row["PTE Size (Kb)"] = (
424-
f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
425-
)
426-
427-
writer.writerow(row)
398+
writer.writerow(row)

backends/test/suite/runner.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
begin_test_session,
2626
complete_test_session,
2727
count_ops,
28-
generate_csv_report,
2928
RunSummary,
3029
TestCaseSummary,
3130
TestResult,
@@ -250,7 +249,7 @@ def build_test_filter(args: argparse.Namespace) -> TestFilter:
250249
def runner_main():
251250
args = parse_args()
252251

253-
begin_test_session()
252+
begin_test_session(args.report)
254253

255254
if len(args.suite) > 1:
256255
raise NotImplementedError("TODO Support multiple suites.")
@@ -265,11 +264,6 @@ def runner_main():
265264
summary = complete_test_session()
266265
print_summary(summary)
267266

268-
if args.report is not None:
269-
with open(args.report, "w") as f:
270-
print(f"Writing CSV report to {args.report}.")
271-
generate_csv_report(summary, f)
272-
273267

274268
if __name__ == "__main__":
275269
runner_main()

backends/test/suite/tests/test_reporting.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@
99

1010
from ..reporting import (
1111
count_ops,
12-
generate_csv_report,
1312
RunSummary,
1413
TestCaseSummary,
1514
TestResult,
1615
TestSessionState,
16+
write_csv_header,
17+
write_csv_row,
1718
)
1819

1920
# Test data for simulated test results.
@@ -69,7 +70,9 @@ def test_csv_report_simple(self):
6970
run_summary = RunSummary.from_session(session_state)
7071

7172
strio = StringIO()
72-
generate_csv_report(run_summary, strio)
73+
write_csv_header(strio)
74+
for case_summary in run_summary.test_case_summaries:
75+
write_csv_row(case_summary, strio)
7376

7477
# Attempt to deserialize and validate the CSV report.
7578
report = DictReader(StringIO(strio.getvalue()))
@@ -81,32 +84,28 @@ def test_csv_report_simple(self):
8184
self.assertEqual(records[0]["Test Case"], "test1")
8285
self.assertEqual(records[0]["Flow"], "flow1")
8386
self.assertEqual(records[0]["Result"], "Pass")
84-
self.assertEqual(records[0]["Dtype"], "")
85-
self.assertEqual(records[0]["Use_dynamic_shapes"], "")
87+
self.assertEqual(records[0]["Params"], "")
8688

8789
# Validate second record: test1, backend2, LOWER_FAIL
8890
self.assertEqual(records[1]["Test ID"], "test1_backend2_flow1")
8991
self.assertEqual(records[1]["Test Case"], "test1")
9092
self.assertEqual(records[1]["Flow"], "flow1")
9193
self.assertEqual(records[1]["Result"], "Fail")
92-
self.assertEqual(records[1]["Dtype"], "")
93-
self.assertEqual(records[1]["Use_dynamic_shapes"], "")
94+
self.assertEqual(records[1]["Params"], "")
9495

9596
# Validate third record: test2, backend1, SUCCESS_UNDELEGATED with dtype param
9697
self.assertEqual(records[2]["Test ID"], "test2_backend1_flow1")
9798
self.assertEqual(records[2]["Test Case"], "test2")
9899
self.assertEqual(records[2]["Flow"], "flow1")
99100
self.assertEqual(records[2]["Result"], "Pass")
100-
self.assertEqual(records[2]["Dtype"], str(torch.float32))
101-
self.assertEqual(records[2]["Use_dynamic_shapes"], "")
101+
self.assertEqual(records[2]["Params"], str({"dtype": torch.float32}))
102102

103103
# Validate fourth record: test2, backend2, EXPORT_FAIL with use_dynamic_shapes param
104104
self.assertEqual(records[3]["Test ID"], "test2_backend2_flow1")
105105
self.assertEqual(records[3]["Test Case"], "test2")
106106
self.assertEqual(records[3]["Flow"], "flow1")
107107
self.assertEqual(records[3]["Result"], "Skip")
108-
self.assertEqual(records[3]["Dtype"], "")
109-
self.assertEqual(records[3]["Use_dynamic_shapes"], "True")
108+
self.assertEqual(records[3]["Params"], str({"use_dynamic_shapes": True}))
110109

111110
def test_count_ops(self):
112111
"""

0 commit comments

Comments
 (0)