99from collections .abc import Callable
1010
1111
def get_errno_from_error_type(error_type: str) -> int:
    """
    Map an error type string to its errno (error number) for sorting.

    According to the paper:
        - c=1: accuracy errors
        - c=2: runtime crashes
        - c=3: compilation failures

    Args:
        error_type: Error type string (e.g., "accuracy", "eager", "compiled").

    Returns:
        Errno (1, 2, or 3) based on the error type. Any string that is not
        listed below is treated as a runtime error and maps to 2.
    """
    # Several spellings exist for the same category; all known aliases are
    # listed explicitly so the table documents what callers may pass.
    errno_by_error_type = {
        "accuracy": 1,
        "eager": 2,
        "other": 2,
        "runtime_fail": 2,
        "eager_fail": 2,
        "compiled": 3,
        "compile_fail": 3,
    }
    # Unknown error types default to 2 (runtime errors).
    return errno_by_error_type.get(error_type, 2)
36+
37+
def get_error_type_from_errno(errno: int) -> str:
    """
    Map an errno (error number) back to a canonical error type string.

    This is the counterpart of get_errno_from_error_type. It is not a strict
    inverse: several error type strings map to the same errno, so only one
    canonical name per errno is returned here.

    Args:
        errno: Error number (1, 2, or 3).

    Returns:
        Error type string:
            - 1 -> "accuracy"
            - 2 -> "runtime_fail"
            - 3 -> "compile_fail"
        Any other value falls back to "runtime_fail".
    """
    errno_to_error_type = {
        1: "accuracy",
        2: "runtime_fail",
        3: "compile_fail",
    }
    # Unknown errnos are reported as runtime failures.
    return errno_to_error_type.get(errno, "runtime_fail")
60+
61+
1262def calculate_alpha (correct_speedups : list [float ]) -> float :
1363 """
1464 Calculate alpha: geometric mean of correct sample speedups.
@@ -80,30 +130,31 @@ def calculate_eta(correct_speedups: list[float]) -> float:
80130
81131
def calculate_pi(
    errno2count: dict[int, int], total_samples: int, correct_speedups: list[float]
) -> dict[int, float]:
    """
    Calculate pi: error type proportions for t > 0.

    According to Appendix C: pi_c is the proportion of error type c among all
    error samples.

    Args:
        errno2count: Dictionary mapping errno (error number) to their counts.
            Errno values: 1=accuracy, 2=runtime, 3=compilation.
        total_samples: Total number of samples.
        correct_speedups: List of speedup values for correct samples; only its
            length is used here.

    Returns:
        Dictionary mapping errno to their proportions among error samples.
        The number of error samples is total_samples minus the number of
        correct samples. If that number is zero (or negative because the
        inputs are inconsistent), every proportion is set to 0.0.
    """
    error_count = total_samples - len(correct_speedups)

    # No error samples: avoid dividing by zero, and never emit negative
    # proportions when there are more correct samples than total samples.
    if error_count <= 0:
        return {errno: 0.0 for errno in errno2count}

    return {errno: count / error_count for errno, count in errno2count.items()}
108159
109160
@@ -210,12 +261,12 @@ def calculate_es_t_from_aggregated(
210261def calculate_all_aggregated_parameters (
211262 total_samples : int ,
212263 correct_speedups : list [float ],
213- error_type_counts : dict [str , int ],
264+ errno2count : dict [int , int ],
214265 t_key : int ,
215266 negative_speedup_penalty : float = 0.0 ,
216267 fpdb : float = 0.1 ,
217- pi : dict [str , float ] | None = None ,
218- error_tolerance_thresholds : dict [str , int ] | None = None ,
268+ pi : dict [int , float ] | None = None ,
269+ errno_tolerance_thresholds : dict [int , int ] | None = None ,
219270) -> dict :
220271 """
221272 Calculate all aggregated parameters for a given tolerance level.
@@ -225,15 +276,16 @@ def calculate_all_aggregated_parameters(
225276 Args:
226277 total_samples: Total number of samples
227278 correct_speedups: List of speedup values for correct samples
228- error_type_counts: Dictionary mapping error type names to their counts
279+ errno2count: Dictionary mapping errno (error number) to their counts.
280+ Errno values: 1=accuracy, 2=runtime, 3=compilation.
229281 t_key: Tolerance level
230282 negative_speedup_penalty: Penalty power p for negative speedup
231283 fpdb: Base penalty b for severe errors
232- pi: Dictionary mapping error type names to their proportions (calculated at t=1).
233- If None, will be calculated from error_type_counts .
234- error_tolerance_thresholds : Dictionary mapping error type names to their tolerance thresholds.
284+ pi: Dictionary mapping errno to their proportions (calculated at t=1).
285+ If None, will be calculated from errno2count .
286+ errno_tolerance_thresholds : Dictionary mapping errno to their tolerance thresholds.
235287 An error type is tolerated (not penalized) when t >= threshold.
236- If None, uses default thresholds: {"accuracy" : 1} for accuracy errors, 3 for others.
288+ If None, uses default thresholds: {1 : 1} for accuracy errors (errno=1), {2: 3, 3: 3} for others.
237289
238290 Returns:
239291 Dictionary containing all aggregated parameters and calculated scores:
@@ -243,36 +295,34 @@ def calculate_all_aggregated_parameters(
243295 'lambda': float,
244296 'eta': float,
245297 'gamma': float,
246- 'pi': dict[str , float],
298+ 'pi': dict[int , float],
247299 's_t': float,
248300 'es_t': float
249301 }
250302 """
251303 # Use default error tolerance thresholds if not provided
252- if error_tolerance_thresholds is None :
253- error_tolerance_thresholds = {}
254- for error_type in error_type_counts .keys ():
255- if error_type == "accuracy" :
256- error_tolerance_thresholds [ error_type ] = 1
257- else :
258- error_tolerance_thresholds [ error_type ] = 3
304+ if errno_tolerance_thresholds is None :
305+ errno_tolerance_thresholds = {}
306+ for errno in errno2count .keys ():
307+ if errno == 1 : # accuracy errors
308+ errno_tolerance_thresholds [ errno ] = 1
309+ else : # runtime (2) or compilation (3) errors
310+ errno_tolerance_thresholds [ errno ] = 3
259311
260312 # Calculate pi if not provided
261313 if pi is None :
262- pi = calculate_pi (error_type_counts , total_samples , correct_speedups )
314+ pi = calculate_pi (errno2count , total_samples , correct_speedups )
263315
264316 # Convert dictionary-based pi and thresholds to indexed format for calculate_gamma
265- # Create ordered list of error types for consistent indexing
266- error_types = sorted (error_type_counts .keys ())
267- errno_tolerances = [
268- error_tolerance_thresholds .get (error_type , 3 ) for error_type in error_types
269- ]
317+ # Create ordered list of errnos for consistent indexing (sorted by errno)
318+ errnos = sorted (errno2count .keys ())
319+ errno_tolerances = [errno_tolerance_thresholds .get (errno , 3 ) for errno in errnos ]
270320
271321 # Create get_pi function that maps error type index to pi value
272322 def get_pi (error_type_index : int ) -> float :
273- if error_type_index < len (error_types ):
274- error_type = error_types [error_type_index ]
275- return pi .get (error_type , 0.0 )
323+ if error_type_index < len (errnos ):
324+ errno = errnos [error_type_index ]
325+ return pi .get (errno , 0.0 )
276326 return 0.0
277327
278328 alpha = calculate_alpha (correct_speedups )
0 commit comments