1+ import csv
12from collections import Counter
23from dataclasses import dataclass
34from enum import IntEnum
45from functools import reduce
56from typing import TextIO
67
7- import csv
8-
98from executorch .backends .test .harness .error_statistics import ErrorStatistics
109
10+
1111class TestResult (IntEnum ):
1212 """Represents the result of a test case run, indicating success or a specific failure reason."""
1313
@@ -80,13 +80,13 @@ class TestCaseSummary:
8080 """
8181 Contains summary results for the execution of a single test case.
8282 """
83-
83+
8484 backend : str
8585 """ The name of the target backend. """
8686
8787 base_name : str
8888 """ The base name of the test, not including flow or parameter suffixes. """
89-
89+
9090 flow : str
9191 """ The backend-specific flow name. Corresponds to flows registered in backends/test/suite/__init__.py. """
9292
@@ -101,7 +101,7 @@ class TestCaseSummary:
101101
102102 error : Exception | None
103103 """ The Python exception object, if any. """
104-
104+
105105 tensor_error_statistics : list [ErrorStatistics ]
106106 """
107107 Statistics about the error between the backend and reference outputs. Each element of this list corresponds to
@@ -180,8 +180,9 @@ def complete_test_session() -> RunSummary:
180180
181181 return summary
182182
183+
183184def generate_csv_report (summary : RunSummary , output : TextIO ):
184- """ Write a run summary report to a file in CSV format. """
185+ """Write a run summary report to a file in CSV format."""
185186
186187 field_names = [
187188 "Test ID" ,
@@ -190,30 +191,38 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
190191 "Flow" ,
191192 "Result" ,
192193 ]
193-
194+
194195 # Tests can have custom parameters. We'll want to report them here, so we need
195196 # a list of all unique parameter names.
196197 param_names = reduce (
197198 lambda a , b : a .union (b ),
198- (set (s .params .keys ()) for s in summary .test_case_summaries if s .params is not None ),
199- set ()
199+ (
200+ set (s .params .keys ())
201+ for s in summary .test_case_summaries
202+ if s .params is not None
203+ ),
204+ set (),
200205 )
201206 field_names += (s .capitalize () for s in param_names )
202207
203208 # Add tensor error statistic field names for each output index.
204- max_outputs = max (len (s .tensor_error_statistics ) for s in summary .test_case_summaries )
209+ max_outputs = max (
210+ len (s .tensor_error_statistics ) for s in summary .test_case_summaries
211+ )
205212 for i in range (max_outputs ):
206- field_names .extend ([
207- f"Output { i } Error Max" ,
208- f"Output { i } Error MAE" ,
209- f"Output { i } Error MSD" ,
210- f"Output { i } Error L2" ,
211- f"Output { i } SQNR" ,
212- ])
213+ field_names .extend (
214+ [
215+ f"Output { i } Error Max" ,
216+ f"Output { i } Error MAE" ,
217+ f"Output { i } Error MSD" ,
218+ f"Output { i } Error L2" ,
219+ f"Output { i } SQNR" ,
220+ ]
221+ )
213222
214223 writer = csv .DictWriter (output , field_names )
215224 writer .writeheader ()
216-
225+
217226 for record in summary .test_case_summaries :
218227 row = {
219228 "Test ID" : record .name ,
@@ -223,10 +232,8 @@ def generate_csv_report(summary: RunSummary, output: TextIO):
223232 "Result" : record .result .display_name (),
224233 }
225234 if record .params is not None :
226- row .update ({
227- k .capitalize (): v for k , v in record .params .items ()
228- })
229-
235+ row .update ({k .capitalize (): v for k , v in record .params .items ()})
236+
230237 for output_idx , error_stats in enumerate (record .tensor_error_statistics ):
231238 row [f"Output { output_idx } Error Max" ] = error_stats .error_max
232239 row [f"Output { output_idx } Error MAE" ] = error_stats .error_mae
0 commit comments