11import csv
22
33from collections import Counter
4- from dataclasses import dataclass
4+ from dataclasses import dataclass , field
55from datetime import timedelta
66from enum import IntEnum
77from functools import reduce
1111from torch .export import ExportedProgram
1212
1313
# The maximum number of model output tensors to log statistics for. Most model tests will
# only have one output, but some may return more than one tensor. This upper bound is needed
# upfront since the file is written progressively. Any outputs beyond these will not have stats logged.
MAX_LOGGED_MODEL_OUTPUTS = 2


# Field names for the CSV report: the fixed columns, followed by one group of
# error-statistic columns per logged model output index.
CSV_FIELD_NAMES = [
    "Test ID",
    "Test Case",
    "Flow",
    "Params",
    "Result",
    "Result Detail",
    "Delegated",
    "Quantize Time (s)",
    "Lower Time (s)",
    "Delegated Nodes",
    "Undelegated Nodes",
    "Delegated Ops",
    "Undelegated Ops",
    "PTE Size (Kb)",
] + [
    f"Output {i} {stat}"
    for i in range(MAX_LOGGED_MODEL_OUTPUTS)
    for stat in ("Error Max", "Error MAE", "SNR")
]
46+
47+
# Operators that are excluded from the counts returned by count_ops. These are used to
# exclude operations that are not logically relevant or delegatable to backends.
1650OP_COUNT_IGNORED_OPS = {
@@ -57,15 +91,15 @@ def is_non_backend_failure(self):
5791
def is_backend_failure(self):
    """A failure that is neither a success nor a known non-backend failure."""
    return not (self.is_success() or self.is_non_backend_failure())
60-
94+
def to_short_str(self):
    """Map this result to a compact label: "Pass", "Skip", or "Fail"."""
    if self == TestResult.SKIPPED:
        return "Skip"
    if self in (TestResult.SUCCESS, TestResult.SUCCESS_UNDELEGATED):
        return "Pass"
    return "Fail"
68-
102+
69103 def to_detail_str (self ):
70104 if self == TestResult .SUCCESS :
71105 return ""
@@ -160,14 +194,22 @@ class TestCaseSummary:
160194 """ The size of the PTE file in bytes. """
161195
def is_delegated(self):
    """True if at least one op was delegated (any positive delegated-op count)."""
    counts = self.delegated_op_counts
    if not counts:
        return False
    return any(n > 0 for n in counts.values())
164202
165203
@dataclass
class TestSessionState:
    # Whether the CSV header row has already been written to report_path.
    has_written_report_header: bool = False

    # Destination file for the progressively-written detail report, or None when disabled.
    report_path: str | None = None

    # Summaries accumulated for every test case logged in this session.
    test_case_summaries: list[TestCaseSummary] = field(default_factory=list)
172214
173215@dataclass
@@ -245,11 +287,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
245287 )
246288
247289
def begin_test_session(report_path: str | None):
    """Begin a new global test session.

    Args:
        report_path: File path to progressively write the CSV detail report to,
            or None to disable report generation.

    Raises:
        RuntimeError: If a test session is already active.
    """
    global _active_session

    # Explicit raise instead of assert so the guard is not stripped under `python -O`.
    if _active_session is not None:
        raise RuntimeError("A test session is already active.")
    _active_session = TestSessionState(report_path=report_path)
253295
254296
255297def log_test_summary (summary : TestCaseSummary ):
@@ -258,6 +300,15 @@ def log_test_summary(summary: TestCaseSummary):
258300 if _active_session is not None :
259301 _active_session .test_case_summaries .append (summary )
260302
303+ if _active_session .report_path is not None :
304+ file_mode = "a" if _active_session .has_written_report_header else "w"
305+ with open (_active_session .report_path , file_mode ) as f :
306+ if not _active_session .has_written_report_header :
307+ write_csv_header (f )
308+ _active_session .has_written_report_header = True
309+
310+ write_csv_row (summary , f )
311+
261312
262313def complete_test_session () -> RunSummary :
263314 global _active_session
@@ -276,6 +327,13 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
276327 return sum (counter .values ()) if counter is not None else None
277328
278329
330+ def _serialize_params (params : dict [str , Any ] | None ) -> str :
331+ if params is not None :
332+ return str (dict (sorted (params .items ())))
333+ else :
334+ return ""
335+
336+
279337def _serialize_op_counts (counter : Counter | None ) -> str :
280338 """
281339 A utility function to serialize op counts to a string, for the purpose of including
@@ -287,87 +345,49 @@ def _serialize_op_counts(counter: Counter | None) -> str:
287345 return ""
288346
289347
290- def generate_csv_report (summary : RunSummary , output : TextIO ):
291- """Write a run summary report to a file in CSV format."""
292-
293- field_names = [
294- "Test ID" ,
295- "Test Case" ,
296- "Flow" ,
297- "Result" ,
298- "Result Detail" ,
299- "Delegated" ,
300- "Quantize Time (s)" ,
301- "Lower Time (s)" ,
302- ]
303-
304- # Tests can have custom parameters. We'll want to report them here, so we need
305- # a list of all unique parameter names.
306- param_names = reduce (
307- lambda a , b : a .union (b ),
308- (
309- set (s .params .keys ())
310- for s in summary .test_case_summaries
311- if s .params is not None
312- ),
313- set (),
314- )
315- field_names += (s .capitalize () for s in param_names )
316-
317- # Add tensor error statistic field names for each output index.
318- max_outputs = max (
319- len (s .tensor_error_statistics ) for s in summary .test_case_summaries
320- )
321- for i in range (max_outputs ):
322- field_names .extend (
323- [
324- f"Output { i } Error Max" ,
325- f"Output { i } Error MAE" ,
326- f"Output { i } SNR" ,
327- ]
328- )
329- field_names .extend (
330- [
331- "Delegated Nodes" ,
332- "Undelegated Nodes" ,
333- "Delegated Ops" ,
334- "Undelegated Ops" ,
335- "PTE Size (Kb)" ,
336- ]
337- )
338-
339- writer = csv .DictWriter (output , field_names )
def write_csv_header(output: TextIO):
    """Write the CSV header row (CSV_FIELD_NAMES) to the report stream."""
    csv.DictWriter(output, CSV_FIELD_NAMES).writeheader()
342- for record in summary .test_case_summaries :
343- row = {
344- "Test ID" : record .name ,
345- "Test Case" : record .base_name ,
346- "Flow" : record .flow ,
347- "Result" : record .result .to_short_str (),
348- "Result Detail" : record .result .to_detail_str (),
349- "Delegated" : "True" if record .is_delegated () else "False" ,
350- "Quantize Time (s)" : (
351- f"{ record .quantize_time .total_seconds ():.3f} " if record .quantize_time else None
352- ),
353- "Lower Time (s)" : (
354- f"{ record .lower_time .total_seconds ():.3f} " if record .lower_time else None
355- ),
356- }
357- if record .params is not None :
358- row .update ({k .capitalize (): v for k , v in record .params .items ()})
359-
360- for output_idx , error_stats in enumerate (record .tensor_error_statistics ):
361- row [f"Output { output_idx } Error Max" ] = f"{ error_stats .error_max :.3f} "
362- row [f"Output { output_idx } Error MAE" ] = f"{ error_stats .error_mae :.3f} "
363- row [f"Output { output_idx } SNR" ] = f"{ error_stats .sqnr :.3f} "
364-
365- row ["Delegated Nodes" ] = _sum_op_counts (record .delegated_op_counts )
366- row ["Undelegated Nodes" ] = _sum_op_counts (record .undelegated_op_counts )
367- row ["Delegated Ops" ] = _serialize_op_counts (record .delegated_op_counts )
368- row ["Undelegated Ops" ] = _serialize_op_counts (record .undelegated_op_counts )
369- row ["PTE Size (Kb)" ] = (
370- f"{ record .pte_size_bytes / 1000.0 :.3f} " if record .pte_size_bytes else ""
371- )
372352
373- writer .writerow (row )
def write_csv_row(record: TestCaseSummary, output: TextIO):
    """Append one CSV row for a single test-case summary to the report stream.

    At most MAX_LOGGED_MODEL_OUTPUTS outputs get error-statistic columns; any
    extra outputs are skipped with a console warning.
    """

    def _seconds(delta):
        # Format a timedelta as fractional seconds; falsy (None/zero) stays None.
        return f"{delta.total_seconds():.3f}" if delta else None

    row = {
        "Test ID": record.name,
        "Test Case": record.base_name,
        "Flow": record.flow,
        "Params": _serialize_params(record.params),
        "Result": record.result.to_short_str(),
        "Result Detail": record.result.to_detail_str(),
        "Delegated": "True" if record.is_delegated() else "False",
        "Quantize Time (s)": _seconds(record.quantize_time),
        "Lower Time (s)": _seconds(record.lower_time),
        "Delegated Nodes": _sum_op_counts(record.delegated_op_counts),
        "Undelegated Nodes": _sum_op_counts(record.undelegated_op_counts),
        "Delegated Ops": _serialize_op_counts(record.delegated_op_counts),
        "Undelegated Ops": _serialize_op_counts(record.undelegated_op_counts),
        "PTE Size (Kb)": (
            f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
        ),
    }

    for output_idx, error_stats in enumerate(record.tensor_error_statistics):
        if output_idx >= MAX_LOGGED_MODEL_OUTPUTS:
            print(
                f"Model output stats are truncated as model has more than {MAX_LOGGED_MODEL_OUTPUTS} outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
            )
            break

        row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
        row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
        row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"

    # Column order is dictated by CSV_FIELD_NAMES, not by row-dict insertion order.
    csv.DictWriter(output, CSV_FIELD_NAMES).writerow(row)
0 commit comments