1
1
import csv
2
2
3
3
from collections import Counter
4
- from dataclasses import dataclass
4
+ from dataclasses import dataclass , field
5
5
from datetime import timedelta
6
6
from enum import IntEnum
7
7
from functools import reduce
11
11
from torch .export import ExportedProgram
12
12
13
13
14
+ # The maximum number of model output tensors to log statistics for. Most model tests will
15
+ # only have one output, but some may return more than one tensor. This upper bound is needed
16
+ # upfront since the file is written progressively. Any outputs beyond these will not have stats logged.
17
+ MAX_LOGGED_MODEL_OUTPUTS = 2
18
+
19
+
20
+ # Field names for the CSV report.
21
+ CSV_FIELD_NAMES = [
22
+ "Test ID" ,
23
+ "Test Case" ,
24
+ "Flow" ,
25
+ "Params" ,
26
+ "Result" ,
27
+ "Result Detail" ,
28
+ "Delegated" ,
29
+ "Quantize Time (s)" ,
30
+ "Lower Time (s)" ,
31
+ "Delegated Nodes" ,
32
+ "Undelegated Nodes" ,
33
+ "Delegated Ops" ,
34
+ "Undelegated Ops" ,
35
+ "PTE Size (Kb)" ,
36
+ ]
37
+
38
+ for i in range (MAX_LOGGED_MODEL_OUTPUTS ):
39
+ CSV_FIELD_NAMES .extend (
40
+ [
41
+ f"Output { i } Error Max" ,
42
+ f"Output { i } Error MAE" ,
43
+ f"Output { i } SNR" ,
44
+ ]
45
+ )
46
+
47
+
14
48
# Operators that are excluded from the counts returned by count_ops. These are used to
15
49
# exclude operatations that are not logically relevant or delegatable to backends.
16
50
OP_COUNT_IGNORED_OPS = {
@@ -167,11 +201,15 @@ def is_delegated(self):
167
201
)
168
202
169
203
204
+ @dataclass
170
205
class TestSessionState :
171
- test_case_summaries : list [TestCaseSummary ]
206
+ # True if the CSV header has been written to report__path.
207
+ has_written_report_header : bool = False
172
208
173
- def __init__ (self ):
174
- self .test_case_summaries = []
209
+ # The file path to write the detail report to, if enabled.
210
+ report_path : str | None = None
211
+
212
+ test_case_summaries : list [TestCaseSummary ] = field (default_factory = list )
175
213
176
214
177
215
@dataclass
@@ -249,11 +287,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
249
287
)
250
288
251
289
252
- def begin_test_session ():
290
+ def begin_test_session (report_path : str | None ):
253
291
global _active_session
254
292
255
293
assert _active_session is None , "A test session is already active."
256
- _active_session = TestSessionState ()
294
+ _active_session = TestSessionState (report_path = report_path )
257
295
258
296
259
297
def log_test_summary (summary : TestCaseSummary ):
@@ -262,6 +300,15 @@ def log_test_summary(summary: TestCaseSummary):
262
300
if _active_session is not None :
263
301
_active_session .test_case_summaries .append (summary )
264
302
303
+ if _active_session .report_path is not None :
304
+ file_mode = "a" if _active_session .has_written_report_header else "w"
305
+ with open (_active_session .report_path , file_mode ) as f :
306
+ if not _active_session .has_written_report_header :
307
+ write_csv_header (f )
308
+ _active_session .has_written_report_header = True
309
+
310
+ write_csv_row (summary , f )
311
+
265
312
266
313
def complete_test_session () -> RunSummary :
267
314
global _active_session
@@ -280,6 +327,13 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
280
327
return sum (counter .values ()) if counter is not None else None
281
328
282
329
330
+ def _serialize_params (params : dict [str , Any ] | None ) -> str :
331
+ if params is not None :
332
+ return str (dict (sorted (params .items ())))
333
+ else :
334
+ return ""
335
+
336
+
283
337
def _serialize_op_counts (counter : Counter | None ) -> str :
284
338
"""
285
339
A utility function to serialize op counts to a string, for the purpose of including
@@ -291,91 +345,49 @@ def _serialize_op_counts(counter: Counter | None) -> str:
291
345
return ""
292
346
293
347
294
- def generate_csv_report (summary : RunSummary , output : TextIO ):
295
- """Write a run summary report to a file in CSV format."""
296
-
297
- field_names = [
298
- "Test ID" ,
299
- "Test Case" ,
300
- "Flow" ,
301
- "Result" ,
302
- "Result Detail" ,
303
- "Delegated" ,
304
- "Quantize Time (s)" ,
305
- "Lower Time (s)" ,
306
- ]
307
-
308
- # Tests can have custom parameters. We'll want to report them here, so we need
309
- # a list of all unique parameter names.
310
- param_names = reduce (
311
- lambda a , b : a .union (b ),
312
- (
313
- set (s .params .keys ())
314
- for s in summary .test_case_summaries
315
- if s .params is not None
316
- ),
317
- set (),
318
- )
319
- field_names += (s .capitalize () for s in param_names )
320
-
321
- # Add tensor error statistic field names for each output index.
322
- max_outputs = max (
323
- len (s .tensor_error_statistics ) for s in summary .test_case_summaries
324
- )
325
- for i in range (max_outputs ):
326
- field_names .extend (
327
- [
328
- f"Output { i } Error Max" ,
329
- f"Output { i } Error MAE" ,
330
- f"Output { i } SNR" ,
331
- ]
332
- )
333
- field_names .extend (
334
- [
335
- "Delegated Nodes" ,
336
- "Undelegated Nodes" ,
337
- "Delegated Ops" ,
338
- "Undelegated Ops" ,
339
- "PTE Size (Kb)" ,
340
- ]
341
- )
342
-
343
- writer = csv .DictWriter (output , field_names )
348
+ def write_csv_header (output : TextIO ):
349
+ writer = csv .DictWriter (output , CSV_FIELD_NAMES )
344
350
writer .writeheader ()
345
351
346
- for record in summary .test_case_summaries :
347
- row = {
348
- "Test ID" : record .name ,
349
- "Test Case" : record .base_name ,
350
- "Flow" : record .flow ,
351
- "Result" : record .result .to_short_str (),
352
- "Result Detail" : record .result .to_detail_str (),
353
- "Delegated" : "True" if record .is_delegated () else "False" ,
354
- "Quantize Time (s)" : (
355
- f"{ record .quantize_time .total_seconds ():.3f} "
356
- if record .quantize_time
357
- else None
358
- ),
359
- "Lower Time (s)" : (
360
- f"{ record .lower_time .total_seconds ():.3f} "
361
- if record .lower_time
362
- else None
363
- ),
364
- }
365
- if record .params is not None :
366
- row .update ({k .capitalize (): v for k , v in record .params .items ()})
367
-
368
- for output_idx , error_stats in enumerate (record .tensor_error_statistics ):
369
- row [f"Output { output_idx } Error Max" ] = f"{ error_stats .error_max :.3f} "
370
- row [f"Output { output_idx } Error MAE" ] = f"{ error_stats .error_mae :.3f} "
371
- row [f"Output { output_idx } SNR" ] = f"{ error_stats .sqnr :.3f} "
372
-
373
- row ["Delegated Nodes" ] = _sum_op_counts (record .delegated_op_counts )
374
- row ["Undelegated Nodes" ] = _sum_op_counts (record .undelegated_op_counts )
375
- row ["Delegated Ops" ] = _serialize_op_counts (record .delegated_op_counts )
376
- row ["Undelegated Ops" ] = _serialize_op_counts (record .undelegated_op_counts )
377
- row ["PTE Size (Kb)" ] = (
378
- f"{ record .pte_size_bytes / 1000.0 :.3f} " if record .pte_size_bytes else ""
379
- )
380
352
381
- writer .writerow (row )
353
+ def write_csv_row (record : TestCaseSummary , output : TextIO ):
354
+ writer = csv .DictWriter (output , CSV_FIELD_NAMES )
355
+
356
+ row = {
357
+ "Test ID" : record .name ,
358
+ "Test Case" : record .base_name ,
359
+ "Flow" : record .flow ,
360
+ "Params" : _serialize_params (record .params ),
361
+ "Result" : record .result .to_short_str (),
362
+ "Result Detail" : record .result .to_detail_str (),
363
+ "Delegated" : "True" if record .is_delegated () else "False" ,
364
+ "Quantize Time (s)" : (
365
+ f"{ record .quantize_time .total_seconds ():.3f} "
366
+ if record .quantize_time
367
+ else None
368
+ ),
369
+ "Lower Time (s)" : (
370
+ f"{ record .lower_time .total_seconds ():.3f} " if record .lower_time else None
371
+ ),
372
+ }
373
+
374
+ for output_idx , error_stats in enumerate (record .tensor_error_statistics ):
375
+ if output_idx >= MAX_LOGGED_MODEL_OUTPUTS :
376
+ print (
377
+ f"Model output stats are truncated as model has more than { MAX_LOGGED_MODEL_OUTPUTS } outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
378
+ )
379
+ break
380
+
381
+ row [f"Output { output_idx } Error Max" ] = f"{ error_stats .error_max :.3f} "
382
+ row [f"Output { output_idx } Error MAE" ] = f"{ error_stats .error_mae :.3f} "
383
+ row [f"Output { output_idx } SNR" ] = f"{ error_stats .sqnr :.3f} "
384
+
385
+ row ["Delegated Nodes" ] = _sum_op_counts (record .delegated_op_counts )
386
+ row ["Undelegated Nodes" ] = _sum_op_counts (record .undelegated_op_counts )
387
+ row ["Delegated Ops" ] = _serialize_op_counts (record .delegated_op_counts )
388
+ row ["Undelegated Ops" ] = _serialize_op_counts (record .undelegated_op_counts )
389
+ row ["PTE Size (Kb)" ] = (
390
+ f"{ record .pte_size_bytes / 1000.0 :.3f} " if record .pte_size_bytes else ""
391
+ )
392
+
393
+ writer .writerow (row )
0 commit comments