11import csv
22
33from collections import Counter
4- from dataclasses import dataclass
4+ from dataclasses import dataclass , field
55from datetime import timedelta
66from enum import IntEnum
77from functools import reduce
@@ -205,11 +205,15 @@ def is_delegated(self):
205205 )
206206
207207
208+ @dataclass
208209class TestSessionState :
209- test_case_summaries : list [TestCaseSummary ]
210+ # True if the CSV header has been written to report__path.
211+ has_written_report_header : bool = False
212+
213+ # The file path to write the detail report to, if enabled.
214+ report_path : str | None = None
210215
211- def __init__ (self ):
212- self .test_case_summaries = []
216+ test_case_summaries : list [TestCaseSummary ] = field (default_factory = list )
213217
214218
215219@dataclass
@@ -287,11 +291,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
287291 )
288292
289293
290- def begin_test_session ():
294+ def begin_test_session (report_path : str | None ):
291295 global _active_session
292296
293297 assert _active_session is None , "A test session is already active."
294- _active_session = TestSessionState ()
298+ _active_session = TestSessionState (report_path = report_path )
295299
296300
297301def log_test_summary (summary : TestCaseSummary ):
@@ -300,6 +304,15 @@ def log_test_summary(summary: TestCaseSummary):
300304 if _active_session is not None :
301305 _active_session .test_case_summaries .append (summary )
302306
307+ if _active_session .report_path is not None :
308+ file_mode = "a" if _active_session .has_written_report_header else "w"
309+ with open (_active_session .report_path , file_mode ) as f :
310+ if not _active_session .has_written_report_header :
311+ write_csv_header (f )
312+ _active_session .has_written_report_header = True
313+
314+ write_csv_row (summary , f )
315+
303316
304317def complete_test_session () -> RunSummary :
305318 global _active_session
@@ -318,6 +331,13 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
318331 return sum (counter .values ()) if counter is not None else None
319332
320333
334+ def _serialize_params (params : dict [str , Any ] | None ) -> str :
335+ if params is not None :
336+ return str (dict (sorted (params .items ())))
337+ else :
338+ return ""
339+
340+
321341def _serialize_op_counts (counter : Counter | None ) -> str :
322342 """
323343 A utility function to serialize op counts to a string, for the purpose of including
@@ -329,19 +349,10 @@ def _serialize_op_counts(counter: Counter | None) -> str:
329349 return ""
330350
331351
332- def generate_csv_report (summary : RunSummary , output : TextIO ):
333- """Write a run summary report to a file in CSV format."""
352+ def write_csv_header (output : TextIO ):
353+ writer = csv .DictWriter (output , CSV_FIELD_NAMES )
354+ writer .writeheader ()
334355
335- field_names = [
336- "Test ID" ,
337- "Test Case" ,
338- "Flow" ,
339- "Result" ,
340- "Result Detail" ,
341- "Delegated" ,
342- "Quantize Time (s)" ,
343- "Lower Time (s)" ,
344- ]
345356
346357def write_csv_row (record : TestCaseSummary , output : TextIO ):
347358 writer = csv .DictWriter (output , CSV_FIELD_NAMES )
@@ -360,68 +371,28 @@ def write_csv_row(record: TestCaseSummary, output: TextIO):
360371 if record .quantize_time
361372 else None
362373 ),
363- set (),
364- )
365- field_names + = (s .capitalize () for s in param_names )
366-
367- # Add tensor error statistic field names for each output index.
368- max_outputs = max (
369- len (s .tensor_error_statistics ) for s in summary .test_case_summaries
370- )
371- for i in range (max_outputs ):
372- field_names .extend (
373- [
374- f"Output { i } Error Max" ,
375- f"Output { i } Error MAE" ,
376- f"Output { i } SNR" ,
377- ]
378- )
379- field_names .extend (
380- [
381- "Delegated Nodes" ,
382- "Undelegated Nodes" ,
383- "Delegated Ops" ,
384- "Undelegated Ops" ,
385- "PTE Size (Kb)" ,
386- ]
374+ "Lower Time (s)" : (
375+ f"{ record .lower_time .total_seconds ():.3f} " if record .lower_time else None
376+ ),
377+ }
378+
379+ for output_idx , error_stats in enumerate (record .tensor_error_statistics ):
380+ if output_idx >= MAX_LOGGED_MODEL_OUTPUTS :
381+ print (
382+ f"Model output stats are truncated as model has more than { MAX_LOGGED_MODEL_OUTPUTS } outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
383+ )
384+ break
385+
386+ row [f"Output { output_idx } Error Max" ] = f"{ error_stats .error_max :.3f} "
387+ row [f"Output { output_idx } Error MAE" ] = f"{ error_stats .error_mae :.3f} "
388+ row [f"Output { output_idx } SNR" ] = f"{ error_stats .sqnr :.3f} "
389+
390+ row ["Delegated Nodes" ] = _sum_op_counts (record .delegated_op_counts )
391+ row ["Undelegated Nodes" ] = _sum_op_counts (record .undelegated_op_counts )
392+ row ["Delegated Ops" ] = _serialize_op_counts (record .delegated_op_counts )
393+ row ["Undelegated Ops" ] = _serialize_op_counts (record .undelegated_op_counts )
394+ row ["PTE Size (Kb)" ] = (
395+ f"{ record .pte_size_bytes / 1000.0 :.3f} " if record .pte_size_bytes else ""
387396 )
388397
389- writer = csv .DictWriter (output , field_names )
390- writer .writeheader ()
391-
392- for record in summary .test_case_summaries :
393- row = {
394- "Test ID" : record .name ,
395- "Test Case" : record .base_name ,
396- "Flow" : record .flow ,
397- "Result" : record .result .to_short_str (),
398- "Result Detail" : record .result .to_detail_str (),
399- "Delegated" : "True" if record .is_delegated () else "False" ,
400- "Quantize Time (s)" : (
401- f"{ record .quantize_time .total_seconds ():.3f} "
402- if record .quantize_time
403- else None
404- ),
405- "Lower Time (s)" : (
406- f"{ record .lower_time .total_seconds ():.3f} "
407- if record .lower_time
408- else None
409- ),
410- }
411- if record .params is not None :
412- row .update ({k .capitalize (): v for k , v in record .params .items ()})
413-
414- for output_idx , error_stats in enumerate (record .tensor_error_statistics ):
415- row [f"Output { output_idx } Error Max" ] = f"{ error_stats .error_max :.3f} "
416- row [f"Output { output_idx } Error MAE" ] = f"{ error_stats .error_mae :.3f} "
417- row [f"Output { output_idx } SNR" ] = f"{ error_stats .sqnr :.3f} "
418-
419- row ["Delegated Nodes" ] = _sum_op_counts (record .delegated_op_counts )
420- row ["Undelegated Nodes" ] = _sum_op_counts (record .undelegated_op_counts )
421- row ["Delegated Ops" ] = _serialize_op_counts (record .delegated_op_counts )
422- row ["Undelegated Ops" ] = _serialize_op_counts (record .undelegated_op_counts )
423- row ["PTE Size (Kb)" ] = (
424- f"{ record .pte_size_bytes / 1000.0 :.3f} " if record .pte_size_bytes else ""
425- )
426-
427- writer .writerow (row )
398+ writer .writerow (row )
0 commit comments