import csv

from collections import Counter
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from datetime import timedelta
from enum import IntEnum
from functools import reduce
from torch.export import ExportedProgram


+# The maximum number of model output tensors to log statistics for. Most model tests
+# will only have one output, but some may return more than one tensor. This upper
+# bound is needed upfront since the file is written progressively; outputs beyond
+# this limit will not have stats logged.
+MAX_LOGGED_MODEL_OUTPUTS = 2
+
+
+# Field names for the CSV report.
+CSV_FIELD_NAMES = [
+    "Test ID",
+    "Test Case",
+    "Flow",
+    "Params",
+    "Result",
+    "Result Detail",
+    "Delegated",
+    "Quantize Time (s)",
+    "Lower Time (s)",
+    "Delegated Nodes",
+    "Undelegated Nodes",
+    "Delegated Ops",
+    "Undelegated Ops",
+    "PTE Size (Kb)",
+]
+
+for i in range(MAX_LOGGED_MODEL_OUTPUTS):
+    CSV_FIELD_NAMES.extend(
+        [
+            f"Output {i} Error Max",
+            f"Output {i} Error MAE",
+            f"Output {i} SNR",
+        ]
+    )
+
+
# Operators that are excluded from the counts returned by count_ops. These are used to
# exclude operations that are not logically relevant or delegatable to backends.
OP_COUNT_IGNORED_OPS = {
@@ -57,15 +91,15 @@ def is_non_backend_failure(self):

    def is_backend_failure(self):
        return not self.is_success() and not self.is_non_backend_failure()
-
+
    def to_short_str(self):
-        if self in { TestResult.SUCCESS, TestResult.SUCCESS_UNDELEGATED }:
+        if self in {TestResult.SUCCESS, TestResult.SUCCESS_UNDELEGATED}:
            return "Pass"
        elif self == TestResult.SKIPPED:
            return "Skip"
        else:
            return "Fail"
-
+
    def to_detail_str(self):
        if self == TestResult.SUCCESS:
            return ""
@@ -160,14 +194,22 @@ class TestCaseSummary:
    """The size of the PTE file in bytes."""

    def is_delegated(self):
-        return any(v > 0 for v in self.delegated_op_counts.values()) if self.delegated_op_counts else False
+        return (
+            any(v > 0 for v in self.delegated_op_counts.values())
+            if self.delegated_op_counts
+            else False
+        )


+@dataclass
class TestSessionState:
-    test_case_summaries: list[TestCaseSummary]
+    # True if the CSV header has been written to report_path.
+    has_written_report_header: bool = False

-    def __init__(self):
-        self.test_case_summaries = []
+    # The file path to write the detail report to, if enabled.
+    report_path: str | None = None
+
+    test_case_summaries: list[TestCaseSummary] = field(default_factory=list)


@dataclass
@@ -245,11 +287,11 @@ def count_ops(program: dict[str, ExportedProgram] | ExportedProgram) -> Counter:
    )


-def begin_test_session():
+def begin_test_session(report_path: str | None):
    global _active_session

    assert _active_session is None, "A test session is already active."
-    _active_session = TestSessionState()
+    _active_session = TestSessionState(report_path=report_path)


def log_test_summary(summary: TestCaseSummary):
@@ -258,6 +300,15 @@ def log_test_summary(summary: TestCaseSummary):
    if _active_session is not None:
        _active_session.test_case_summaries.append(summary)

+        if _active_session.report_path is not None:
+            file_mode = "a" if _active_session.has_written_report_header else "w"
+            with open(_active_session.report_path, file_mode) as f:
+                if not _active_session.has_written_report_header:
+                    write_csv_header(f)
+                    _active_session.has_written_report_header = True
+
+                write_csv_row(summary, f)
+

def complete_test_session() -> RunSummary:
    global _active_session
@@ -276,6 +327,13 @@ def _sum_op_counts(counter: Counter | None) -> int | None:
    return sum(counter.values()) if counter is not None else None


+def _serialize_params(params: dict[str, Any] | None) -> str:
+    if params is not None:
+        return str(dict(sorted(params.items())))
+    else:
+        return ""
+
+
def _serialize_op_counts(counter: Counter | None) -> str:
    """
    A utility function to serialize op counts to a string, for the purpose of including
@@ -287,87 +345,49 @@ def _serialize_op_counts(counter: Counter | None) -> str:
        return ""


-def generate_csv_report(summary: RunSummary, output: TextIO):
-    """Write a run summary report to a file in CSV format."""
-
-    field_names = [
-        "Test ID",
-        "Test Case",
-        "Flow",
-        "Result",
-        "Result Detail",
-        "Delegated",
-        "Quantize Time (s)",
-        "Lower Time (s)",
-    ]
-
-    # Tests can have custom parameters. We'll want to report them here, so we need
-    # a list of all unique parameter names.
-    param_names = reduce(
-        lambda a, b: a.union(b),
-        (
-            set(s.params.keys())
-            for s in summary.test_case_summaries
-            if s.params is not None
-        ),
-        set(),
-    )
-    field_names += (s.capitalize() for s in param_names)
-
-    # Add tensor error statistic field names for each output index.
-    max_outputs = max(
-        len(s.tensor_error_statistics) for s in summary.test_case_summaries
-    )
-    for i in range(max_outputs):
-        field_names.extend(
-            [
-                f"Output {i} Error Max",
-                f"Output {i} Error MAE",
-                f"Output {i} SNR",
-            ]
-        )
-    field_names.extend(
-        [
-            "Delegated Nodes",
-            "Undelegated Nodes",
-            "Delegated Ops",
-            "Undelegated Ops",
-            "PTE Size (Kb)",
-        ]
-    )
-
-    writer = csv.DictWriter(output, field_names)
+def write_csv_header(output: TextIO):
+    writer = csv.DictWriter(output, CSV_FIELD_NAMES)
    writer.writeheader()

-    for record in summary.test_case_summaries:
-        row = {
-            "Test ID": record.name,
-            "Test Case": record.base_name,
-            "Flow": record.flow,
-            "Result": record.result.to_short_str(),
-            "Result Detail": record.result.to_detail_str(),
-            "Delegated": "True" if record.is_delegated() else "False",
-            "Quantize Time (s)": (
-                f"{record.quantize_time.total_seconds():.3f}" if record.quantize_time else None
-            ),
-            "Lower Time (s)": (
-                f"{record.lower_time.total_seconds():.3f}" if record.lower_time else None
-            ),
-        }
-        if record.params is not None:
-            row.update({k.capitalize(): v for k, v in record.params.items()})
-
-        for output_idx, error_stats in enumerate(record.tensor_error_statistics):
-            row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
-            row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
-            row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
-
-        row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
-        row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
-        row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
-        row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
-        row["PTE Size (Kb)"] = (
-            f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
-        )

-        writer.writerow(row)
+def write_csv_row(record: TestCaseSummary, output: TextIO):
+    writer = csv.DictWriter(output, CSV_FIELD_NAMES)
+
+    row = {
+        "Test ID": record.name,
+        "Test Case": record.base_name,
+        "Flow": record.flow,
+        "Params": _serialize_params(record.params),
+        "Result": record.result.to_short_str(),
+        "Result Detail": record.result.to_detail_str(),
+        "Delegated": "True" if record.is_delegated() else "False",
+        "Quantize Time (s)": (
+            f"{record.quantize_time.total_seconds():.3f}"
+            if record.quantize_time
+            else None
+        ),
+        "Lower Time (s)": (
+            f"{record.lower_time.total_seconds():.3f}" if record.lower_time else None
+        ),
+    }
+
+    for output_idx, error_stats in enumerate(record.tensor_error_statistics):
+        if output_idx >= MAX_LOGGED_MODEL_OUTPUTS:
+            print(
+                f"Model output stats are truncated as the model has more than {MAX_LOGGED_MODEL_OUTPUTS} outputs. Consider increasing MAX_LOGGED_MODEL_OUTPUTS."
+            )
+            break
+
+        row[f"Output {output_idx} Error Max"] = f"{error_stats.error_max:.3f}"
+        row[f"Output {output_idx} Error MAE"] = f"{error_stats.error_mae:.3f}"
+        row[f"Output {output_idx} SNR"] = f"{error_stats.sqnr:.3f}"
+
+    row["Delegated Nodes"] = _sum_op_counts(record.delegated_op_counts)
+    row["Undelegated Nodes"] = _sum_op_counts(record.undelegated_op_counts)
+    row["Delegated Ops"] = _serialize_op_counts(record.delegated_op_counts)
+    row["Undelegated Ops"] = _serialize_op_counts(record.undelegated_op_counts)
+    row["PTE Size (Kb)"] = (
+        f"{record.pte_size_bytes / 1000.0:.3f}" if record.pte_size_bytes else ""
+    )
+
+    writer.writerow(row)
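
A note on the `csv.DictWriter` pattern the new helpers rely on: because every row is written against the fixed `CSV_FIELD_NAMES` schema, a row may omit fields (for example, a test with fewer outputs than `MAX_LOGGED_MODEL_OUTPUTS`), and the writer fills the gaps with its empty-string `restval` default. A minimal, self-contained sketch of that behavior; the field names below are illustrative stand-ins, not the real schema:

```python
import csv
import io

# A trimmed stand-in for CSV_FIELD_NAMES; the real schema is defined in the diff above.
FIELDS = ["Test ID", "Result", "Output 0 SNR"]

buf = io.StringIO()
writer = csv.DictWriter(buf, FIELDS)
writer.writeheader()  # written exactly once, as in write_csv_header

# A complete row, then a row with a missing field.
writer.writerow({"Test ID": "t1", "Result": "Pass", "Output 0 SNR": "32.100"})
writer.writerow({"Test ID": "t2", "Result": "Fail"})  # "Output 0 SNR" becomes ""

print(buf.getvalue())
# Test ID,Result,Output 0 SNR
# t1,Pass,32.100
# t2,Fail,
```

This is why the schema can be declared up front and the header written on the first `log_test_summary` call, with each later call simply appending a row.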
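Similarly, `_serialize_params` sorts the parameter dict before stringifying it, so the new consolidated "Params" column is stable regardless of the order in which test parameters were recorded. A quick standalone check, with the helper re-declared here purely for illustration:

```python
# Re-declaration of the helper for a standalone check; mirrors the diff above.
def _serialize_params(params: dict | None) -> str:
    if params is not None:
        return str(dict(sorted(params.items())))
    return ""

# Insertion order no longer affects the serialized value.
a = _serialize_params({"dtype": "fp16", "batch": 4})
b = _serialize_params({"batch": 4, "dtype": "fp16"})
assert a == b == "{'batch': 4, 'dtype': 'fp16'}"
```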