import csv

from collections import Counter
from dataclasses import dataclass, field
from datetime import timedelta
from enum import IntEnum
from functools import reduce
@@ -205,15 +205,11 @@ def is_delegated(self):
205205 )
206206
207207
@dataclass
class TestSessionState:
    """Mutable state accumulated over one active test session."""

    # One TestCaseSummary appended per test case logged during the session.
    # Use a dataclass field with a default_factory instead of a hand-written
    # __init__ — this matches the @dataclass style used elsewhere in this file
    # and avoids the shared-mutable-default pitfall.
    test_case_summaries: list[TestCaseSummary] = field(default_factory=list)
217213
218214
219215@dataclass
@@ -291,11 +287,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
291287 )
292288
293289
def begin_test_session():
    """Begin a new global test session.

    Raises:
        RuntimeError: If a test session is already active.
    """
    global _active_session

    # Explicit check instead of `assert`: asserts are stripped when Python
    # runs with optimizations (-O), which would silently allow a second
    # session to clobber the active one.
    if _active_session is not None:
        raise RuntimeError("A test session is already active.")
    _active_session = TestSessionState()
299295
300296
301297def log_test_summary (summary : TestCaseSummary ):
@@ -304,15 +300,6 @@ def log_test_summary(summary: TestCaseSummary):
304300 if _active_session is not None :
305301 _active_session .test_case_summaries .append (summary )
306302
307- if _active_session .report_path is not None :
308- file_mode = "a" if _active_session .has_written_report_header else "w"
309- with open (_active_session .report_path , file_mode ) as f :
310- if not _active_session .has_written_report_header :
311- write_csv_header (f )
312- _active_session .has_written_report_header = True
313-
314- write_csv_row (summary , f )
315-
316303
317304def complete_test_session () -> RunSummary :
318305 global _active_session
@@ -331,13 +318,6 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
331318 return sum (counter .values ()) if counter is not None else None
332319
333320
334- def _serialize_params (params : dict [str , Any ] | None ) -> str :
335- if params is not None :
336- return str (dict (sorted (params .items ())))
337- else :
338- return ""
339-
340-
341321def _serialize_op_counts (counter : Counter | None ) -> str :
342322 """
343323 A utility function to serialize op counts to a string, for the purpose of including
@@ -349,10 +329,19 @@ def _serialize_op_counts(counter: Counter | None) -> str:
349329 return ""
350330
351331
352- def write_csv_header (output : TextIO ):
353- writer = csv .DictWriter (output , CSV_FIELD_NAMES )
354- writer .writeheader ()
332+ def generate_csv_report (summary : RunSummary , output : TextIO ):
333+ """Write a run summary report to a file in CSV format."""
355334
335+ field_names = [
336+ "Test ID" ,
337+ "Test Case" ,
338+ "Flow" ,
339+ "Result" ,
340+ "Result Detail" ,
341+ "Delegated" ,
342+ "Quantize Time (s)" ,
343+ "Lower Time (s)" ,
344+ ]
356345
357346def write_csv_row (record : TestCaseSummary , output : TextIO ):
358347 writer = csv .DictWriter (output , CSV_FIELD_NAMES )
@@ -371,28 +360,68 @@ def write_csv_row(record: TestCaseSummary, output: TextIO):
371360 if record .quantize_time
372361 else None
373362 ),
374- "Lower Time (s)" : (
375- f"{ record .lower_time .total_seconds ():.3f} " if record .lower_time else None
376- ),
377- }
378-
379- for output_idx , error_stats in enumerate (record .tensor_error_statistics ):
380- if output_idx >= MAX_LOGGED_MODEL_OUTPUTS :
381- print (
382- f"Model output stats are truncated as model has more than { MAX_LOGGED_MODEL_OUTPUTS } outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
383- )
384- break
385-
386- row [f"Output { output_idx } Error Max" ] = f"{ error_stats .error_max :.3f} "
387- row [f"Output { output_idx } Error MAE" ] = f"{ error_stats .error_mae :.3f} "
388- row [f"Output { output_idx } SNR" ] = f"{ error_stats .sqnr :.3f} "
389-
390- row ["Delegated Nodes" ] = _sum_op_counts (record .delegated_op_counts )
391- row ["Undelegated Nodes" ] = _sum_op_counts (record .undelegated_op_counts )
392- row ["Delegated Ops" ] = _serialize_op_counts (record .delegated_op_counts )
393- row ["Undelegated Ops" ] = _serialize_op_counts (record .undelegated_op_counts )
394- row ["PTE Size (Kb)" ] = (
395- f"{ record .pte_size_bytes / 1000.0 :.3f} " if record .pte_size_bytes else ""
363+ set (),
364+ )
365+ field_names + = (s .capitalize () for s in param_names )
366+
367+ # Add tensor error statistic field names for each output index.
368+ max_outputs = max (
369+ len (s .tensor_error_statistics ) for s in summary .test_case_summaries
396370 )
371+ for i in range (max_outputs ):
372+ field_names .extend (
373+ [
374+ f"Output { i } Error Max" ,
375+ f"Output { i } Error MAE" ,
376+ f"Output { i } SNR" ,
377+ ]
378+ )
379+ field_names .extend (
380+ [
381+ "Delegated Nodes" ,
382+ "Undelegated Nodes" ,
383+ "Delegated Ops" ,
384+ "Undelegated Ops" ,
385+ "PTE Size (Kb)" ,
386+ ]
387+ )
388+
389+ writer = csv .DictWriter (output , field_names )
390+ writer .writeheader ()
391+
392+ for record in summary .test_case_summaries :
393+ row = {
394+ "Test ID" : record .name ,
395+ "Test Case" : record .base_name ,
396+ "Flow" : record .flow ,
397+ "Result" : record .result .to_short_str (),
398+ "Result Detail" : record .result .to_detail_str (),
399+ "Delegated" : "True" if record .is_delegated () else "False" ,
400+ "Quantize Time (s)" : (
401+ f"{ record .quantize_time .total_seconds ():.3f} "
402+ if record .quantize_time
403+ else None
404+ ),
405+ "Lower Time (s)" : (
406+ f"{ record .lower_time .total_seconds ():.3f} "
407+ if record .lower_time
408+ else None
409+ ),
410+ }
411+ if record .params is not None :
412+ row .update ({k .capitalize (): v for k , v in record .params .items ()})
413+
414+ for output_idx , error_stats in enumerate (record .tensor_error_statistics ):
415+ row [f"Output { output_idx } Error Max" ] = f"{ error_stats .error_max :.3f} "
416+ row [f"Output { output_idx } Error MAE" ] = f"{ error_stats .error_mae :.3f} "
417+ row [f"Output { output_idx } SNR" ] = f"{ error_stats .sqnr :.3f} "
418+
419+ row ["Delegated Nodes" ] = _sum_op_counts (record .delegated_op_counts )
420+ row ["Undelegated Nodes" ] = _sum_op_counts (record .undelegated_op_counts )
421+ row ["Delegated Ops" ] = _serialize_op_counts (record .delegated_op_counts )
422+ row ["Undelegated Ops" ] = _serialize_op_counts (record .undelegated_op_counts )
423+ row ["PTE Size (Kb)" ] = (
424+ f"{ record .pte_size_bytes / 1000.0 :.3f} " if record .pte_size_bytes else ""
425+ )
397426
398- writer .writerow (row )
427+ writer .writerow (row )
0 commit comments