11import csv
22
33from collections import Counter
4- from dataclasses import dataclass , field
4+ from dataclasses import dataclass
55from datetime import timedelta
66from enum import IntEnum
77from functools import reduce
1111from torch .export import ExportedProgram
1212
1313
14- # The maximum number of model output tensors to log statistics for. Most model tests will
15- # only have one output, but some may return more than one tensor. This upper bound is needed
16- # upfront since the file is written progressively. Any outputs beyond these will not have stats logged.
17- MAX_LOGGED_MODEL_OUTPUTS = 2
18-
19-
20- # Field names for the CSV report.
21- CSV_FIELD_NAMES = [
22- "Test ID" ,
23- "Test Case" ,
24- "Flow" ,
25- "Params" ,
26- "Result" ,
27- "Result Detail" ,
28- "Delegated" ,
29- "Quantize Time (s)" ,
30- "Lower Time (s)" ,
31- "Delegated Nodes" ,
32- "Undelegated Nodes" ,
33- "Delegated Ops" ,
34- "Undelegated Ops" ,
35- "PTE Size (Kb)" ,
36- ]
37-
38- for i in range (MAX_LOGGED_MODEL_OUTPUTS ):
39- CSV_FIELD_NAMES .extend (
40- [
41- f"Output { i } Error Max" ,
42- f"Output { i } Error MAE" ,
43- f"Output { i } SNR" ,
44- ]
45- )
46-
47-
4814# Operators that are excluded from the counts returned by count_ops. These are used to
4915# exclude operatations that are not logically relevant or delegatable to backends.
5016OP_COUNT_IGNORED_OPS = {
@@ -201,15 +167,11 @@ def is_delegated(self):
201167 )
202168
203169
204- @dataclass
205170class TestSessionState :
206- # True if the CSV header has been written to report__path.
207- has_written_report_header : bool = False
208-
209- # The file path to write the detail report to, if enabled.
210- report_path : str | None = None
171+ test_case_summaries : list [TestCaseSummary ]
211172
212- test_case_summaries : list [TestCaseSummary ] = field (default_factory = list )
173+ def __init__ (self ):
174+ self .test_case_summaries = []
213175
214176
215177@dataclass
@@ -287,11 +249,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
287249 )
288250
289251
290- def begin_test_session (report_path : str | None ):
252+ def begin_test_session ():
291253 global _active_session
292254
293255 assert _active_session is None , "A test session is already active."
294- _active_session = TestSessionState (report_path = report_path )
256+ _active_session = TestSessionState ()
295257
296258
297259def log_test_summary (summary : TestCaseSummary ):
@@ -300,15 +262,6 @@ def log_test_summary(summary: TestCaseSummary):
300262 if _active_session is not None :
301263 _active_session .test_case_summaries .append (summary )
302264
303- if _active_session .report_path is not None :
304- file_mode = "a" if _active_session .has_written_report_header else "w"
305- with open (_active_session .report_path , file_mode ) as f :
306- if not _active_session .has_written_report_header :
307- write_csv_header (f )
308- _active_session .has_written_report_header = True
309-
310- write_csv_row (summary , f )
311-
312265
313266def complete_test_session () -> RunSummary :
314267 global _active_session
@@ -327,13 +280,6 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
327280 return sum (counter .values ()) if counter is not None else None
328281
329282
330- def _serialize_params (params : dict [str , Any ] | None ) -> str :
331- if params is not None :
332- return str (dict (sorted (params .items ())))
333- else :
334- return ""
335-
336-
337283def _serialize_op_counts (counter : Counter | None ) -> str :
338284 """
339285 A utility function to serialize op counts to a string, for the purpose of including
@@ -345,49 +291,91 @@ def _serialize_op_counts(counter: Counter | None) -> str:
345291 return ""
346292
347293
348- def write_csv_header (output : TextIO ):
349- writer = csv .DictWriter (output , CSV_FIELD_NAMES )
350- writer .writeheader ()
351-
352-
353- def write_csv_row (record : TestCaseSummary , output : TextIO ):
354- writer = csv .DictWriter (output , CSV_FIELD_NAMES )
355-
356- row = {
357- "Test ID" : record .name ,
358- "Test Case" : record .base_name ,
359- "Flow" : record .flow ,
360- "Params" : _serialize_params (record .params ),
361- "Result" : record .result .to_short_str (),
362- "Result Detail" : record .result .to_detail_str (),
363- "Delegated" : "True" if record .is_delegated () else "False" ,
364- "Quantize Time (s)" : (
365- f"{ record .quantize_time .total_seconds ():.3f} "
366- if record .quantize_time
367- else None
368- ),
369- "Lower Time (s)" : (
370- f"{ record .lower_time .total_seconds ():.3f} " if record .lower_time else None
294+ def generate_csv_report (summary : RunSummary , output : TextIO ):
295+ """Write a run summary report to a file in CSV format."""
296+
297+ field_names = [
298+ "Test ID" ,
299+ "Test Case" ,
300+ "Flow" ,
301+ "Result" ,
302+ "Result Detail" ,
303+ "Delegated" ,
304+ "Quantize Time (s)" ,
305+ "Lower Time (s)" ,
306+ ]
307+
308+ # Tests can have custom parameters. We'll want to report them here, so we need
309+ # a list of all unique parameter names.
310+ param_names = reduce (
311+ lambda a , b : a .union (b ),
312+ (
313+ set (s .params .keys ())
314+ for s in summary .test_case_summaries
315+ if s .params is not None
371316 ),
372- }
373-
374- for output_idx , error_stats in enumerate (record .tensor_error_statistics ):
375- if output_idx >= MAX_LOGGED_MODEL_OUTPUTS :
376- print (
377- f"Model output stats are truncated as model has more than { MAX_LOGGED_MODEL_OUTPUTS } outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
378- )
379- break
380-
381- row [f"Output { output_idx } Error Max" ] = f"{ error_stats .error_max :.3f} "
382- row [f"Output { output_idx } Error MAE" ] = f"{ error_stats .error_mae :.3f} "
383- row [f"Output { output_idx } SNR" ] = f"{ error_stats .sqnr :.3f} "
384-
385- row ["Delegated Nodes" ] = _sum_op_counts (record .delegated_op_counts )
386- row ["Undelegated Nodes" ] = _sum_op_counts (record .undelegated_op_counts )
387- row ["Delegated Ops" ] = _serialize_op_counts (record .delegated_op_counts )
388- row ["Undelegated Ops" ] = _serialize_op_counts (record .undelegated_op_counts )
389- row ["PTE Size (Kb)" ] = (
390- f"{ record .pte_size_bytes / 1000.0 :.3f} " if record .pte_size_bytes else ""
317+ set (),
318+ )
319+ field_names += (s .capitalize () for s in param_names )
320+
321+ # Add tensor error statistic field names for each output index.
322+ max_outputs = max (
323+ len (s .tensor_error_statistics ) for s in summary .test_case_summaries
324+ )
325+ for i in range (max_outputs ):
326+ field_names .extend (
327+ [
328+ f"Output { i } Error Max" ,
329+ f"Output { i } Error MAE" ,
330+ f"Output { i } SNR" ,
331+ ]
332+ )
333+ field_names .extend (
334+ [
335+ "Delegated Nodes" ,
336+ "Undelegated Nodes" ,
337+ "Delegated Ops" ,
338+ "Undelegated Ops" ,
339+ "PTE Size (Kb)" ,
340+ ]
391341 )
392342
393- writer .writerow (row )
343+ writer = csv .DictWriter (output , field_names )
344+ writer .writeheader ()
345+
346+ for record in summary .test_case_summaries :
347+ row = {
348+ "Test ID" : record .name ,
349+ "Test Case" : record .base_name ,
350+ "Flow" : record .flow ,
351+ "Result" : record .result .to_short_str (),
352+ "Result Detail" : record .result .to_detail_str (),
353+ "Delegated" : "True" if record .is_delegated () else "False" ,
354+ "Quantize Time (s)" : (
355+ f"{ record .quantize_time .total_seconds ():.3f} "
356+ if record .quantize_time
357+ else None
358+ ),
359+ "Lower Time (s)" : (
360+ f"{ record .lower_time .total_seconds ():.3f} "
361+ if record .lower_time
362+ else None
363+ ),
364+ }
365+ if record .params is not None :
366+ row .update ({k .capitalize (): v for k , v in record .params .items ()})
367+
368+ for output_idx , error_stats in enumerate (record .tensor_error_statistics ):
369+ row [f"Output { output_idx } Error Max" ] = f"{ error_stats .error_max :.3f} "
370+ row [f"Output { output_idx } Error MAE" ] = f"{ error_stats .error_mae :.3f} "
371+ row [f"Output { output_idx } SNR" ] = f"{ error_stats .sqnr :.3f} "
372+
373+ row ["Delegated Nodes" ] = _sum_op_counts (record .delegated_op_counts )
374+ row ["Undelegated Nodes" ] = _sum_op_counts (record .undelegated_op_counts )
375+ row ["Delegated Ops" ] = _serialize_op_counts (record .delegated_op_counts )
376+ row ["Undelegated Ops" ] = _serialize_op_counts (record .undelegated_op_counts )
377+ row ["PTE Size (Kb)" ] = (
378+ f"{ record .pte_size_bytes / 1000.0 :.3f} " if record .pte_size_bytes else ""
379+ )
380+
381+ writer .writerow (row )
0 commit comments