 import csv

 from collections import Counter
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from datetime import timedelta
 from enum import IntEnum
 from functools import reduce
 from torch.export import ExportedProgram


+# The maximum number of model output tensors to log statistics for. Most model tests
+# have only one output, but some return more than one tensor. This upper bound is
+# needed up front because the report file is written progressively; outputs beyond
+# it will not have stats logged.
+MAX_LOGGED_MODEL_OUTPUTS = 2
+
+
+# Field names for the CSV report.
+CSV_FIELD_NAMES = [
+    "Test ID",
+    "Test Case",
+    "Flow",
+    "Params",
+    "Result",
+    "Result Detail",
+    "Delegated",
+    "Quantize Time (s)",
+    "Lower Time (s)",
+    "Delegated Nodes",
+    "Undelegated Nodes",
+    "Delegated Ops",
+    "Undelegated Ops",
+    "PTE Size (Kb)",
+]
+
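+# Add per-output accuracy columns (Error Max, Error MAE, SNR) for each output index,
+# up to the cap above.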
+for i in range(MAX_LOGGED_MODEL_OUTPUTS):
+    CSV_FIELD_NAMES.extend(
+        [
+            f"Output {i} Error Max",
+            f"Output {i} Error MAE",
+            f"Output {i} SNR",
+        ]
+    )
+
+
 # Operators that are excluded from the counts returned by count_ops. These are used to
 # exclude operations that are not logically relevant or delegatable to backends.
 OP_COUNT_IGNORED_OPS = {
@@ -167,11 +201,15 @@ def is_delegated(self):
         )


+@dataclass
 class TestSessionState:
-    test_case_summaries: list[TestCaseSummary]
+    # True if the CSV header has been written to report_path.
+    has_written_report_header: bool = False

-    def __init__(self):
-        self.test_case_summaries = []
+    # The file path to write the detail report to, if enabled.
+    report_path: str | None = None
+
+    test_case_summaries: list[TestCaseSummary] = field(default_factory=list)


 @dataclass
@@ -249,11 +287,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
         )


-def begin_test_session():
+def begin_test_session(report_path: str | None):
     global _active_session

     assert _active_session is None, "A test session is already active."
-    _active_session = TestSessionState()
+    _active_session = TestSessionState(report_path=report_path)


 def log_test_summary(summary: TestCaseSummary):
@@ -262,6 +300,15 @@ def log_test_summary(summary: TestCaseSummary):
     if _active_session is not None:
         _active_session.test_case_summaries.append(summary)

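+        # Write the row to the CSV report as results come in, emitting the header
+        # before the first row.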
+        if _active_session.report_path is not None:
+            file_mode = "a" if _active_session.has_written_report_header else "w"
+            with open(_active_session.report_path, file_mode) as f:
+                if not _active_session.has_written_report_header:
+                    write_csv_header(f)
+                    _active_session.has_written_report_header = True
+
+                write_csv_row(summary, f)
+

 def complete_test_session() -> RunSummary:
     global _active_session
@@ -280,6 +327,13 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
     return sum(counter.values()) if counter is not None else None


+def _serialize_params(params: dict[str, Any] | None) -> str:
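+    # Keys are sorted so that equivalent param dicts always serialize identically.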
+    if params is not None:
+        return str(dict(sorted(params.items())))
+    else:
+        return ""
+
+
 def _serialize_op_counts(counter: Counter | None) -> str:
     """
     A utility function to serialize op counts to a string, for the purpose of including
@@ -291,91 +345,49 @@ def _serialize_op_counts(counter: Counter | None) -> str:
     return ""


-def generate_csv_report(summary: RunSummary, output: TextIO):
-    """Write a run summary report to a file in CSV format."""
-
-    field_names = [
-        "Test ID",
-        "Test Case",
-        "Flow",
-        "Result",
-        "Result Detail",
-        "Delegated",
-        "Quantize Time (s)",
-        "Lower Time (s)",
-    ]
-
-    # Tests can have custom parameters. We'll want to report them here, so we need
-    # a list of all unique parameter names.
-    param_names = reduce(
-        lambda a, b: a.union(b),
-        (
-            set(s.params.keys())
-            for s in summary.test_case_summaries
-            if s.params is not None
-        ),
-        set(),
-    )
-    field_names += (s.capitalize() for s in param_names)
-
-    # Add tensor error statistic field names for each output index.
-    max_outputs = max(
-        len(s.tensor_error_statistics) for s in summary.test_case_summaries
-    )
-    for i in range(max_outputs):
-        field_names.extend(
-            [
-                f"Output {i} Error Max",
-                f"Output {i} Error MAE",
-                f"Output {i} SNR",
-            ]
-        )
-    field_names.extend(
-        [
-            "Delegated Nodes",
-            "Undelegated Nodes",
-            "Delegated Ops",
-            "Undelegated Ops",
-            "PTE Size (Kb)",
-        ]
-    )
-
-    writer = csv.DictWriter(output, field_names)
+def write_csv_header(output: TextIO):
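+    """Write the CSV header row. Called once, before the first data row."""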
+    writer = csv.DictWriter(output, CSV_FIELD_NAMES)
     writer.writeheader()

-    for record in summary.test_case_summaries:
-        row = {
-            "Test ID": record.name,
-            "Test Case": record.base_name,
-            "Flow": record.flow,
-            "Result": record.result.to_short_str(),
-            "Result Detail": record.result.to_detail_str(),
-            "Delegated": "True" if record.is_delegated() else "False",
-            "Quantize Time (s)": (
-                f"{record.quantize_time.total_seconds():.3f}"
-                if record.quantize_time
-                else None
-            ),
-            "Lower Time (s)": (
-                f"{record.lower_time.total_seconds():.3f}"
-                if record.lower_time
-                else None
-            ),
-        }
-        if record.params is not None:
-            row.update({k.capitalize(): v for k, v in record.params.items()})
-
-        for output_idx, error_stats in enumerate(record.tensor_error_statistics):
-            row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
-            row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
-            row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
-
-        row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
-        row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
-        row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
-        row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
-        row["PTE Size (Kb)"] = (
-            f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
-        )

-        writer.writerow(row)
+def write_csv_row(record: TestCaseSummary, output: TextIO):
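+    """Write a single test case summary as one row of the CSV report."""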
+    writer = csv.DictWriter(output, CSV_FIELD_NAMES)
+
+    row = {
+        "Test ID": record.name,
+        "Test Case": record.base_name,
+        "Flow": record.flow,
+        "Params": _serialize_params(record.params),
+        "Result": record.result.to_short_str(),
+        "Result Detail": record.result.to_detail_str(),
+        "Delegated": "True" if record.is_delegated() else "False",
+        "Quantize Time (s)": (
+            f"{record.quantize_time.total_seconds():.3f}"
+            if record.quantize_time
+            else None
+        ),
+        "Lower Time (s)": (
+            f"{record.lower_time.total_seconds():.3f}" if record.lower_time else None
+        ),
+    }
+
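+    # Log per-output accuracy stats, truncated to the same cap used to build the header.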
+    for output_idx, error_stats in enumerate(record.tensor_error_statistics):
+        if output_idx >= MAX_LOGGED_MODEL_OUTPUTS:
+            print(
+                f"Model output stats are truncated because the model has more than {MAX_LOGGED_MODEL_OUTPUTS} outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
+            )
+            break
+
+        row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
+        row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
+        row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
+
+    row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
+    row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
+    row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
+    row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
+    row["PTE Size (Kb)"] = (
+        f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
+    )
+
+    writer.writerow(row)