@@ -149,6 +149,10 @@ def calculate_pi(
149149 """
150150 correct_count = len (correct_speedups )
151151 error_count = total_samples - correct_count
152+ counted_errors = sum (errno2count .values ())
153+ assert (
154+ error_count == counted_errors
155+ ), f"error_count mismatch: got { error_count } , but errno2count sums to { counted_errors } "
152156 if error_count == 0 :
153157 return {errno : 0.0 for errno in errno2count .keys ()}
154158
@@ -158,10 +162,36 @@ def calculate_pi(
158162 return pi
159163
160164
165+ def resolve_errno_tolerance (
166+ errno2count : dict [int , int ], custom_map : dict [int , int ] | None
167+ ) -> dict [int , int ]:
168+ """
169+ Build a sorted errno -> tolerance map for downstream gamma calculation.
170+
171+ Args:
172+ errno2count: Observed errno occurrences in the dataset.
173+ custom_map: Optional overrides mapping errno to its minimal tolerated tolerance.
174+
175+ Returns:
176+ Ordered dict (by errno) mapping each errno seen in errno2count
177+ to the tolerance value where it becomes tolerated. Defaults to:
178+ - errno 1 (accuracy) -> 1
179+ - errno >=2 (runtime/compile) -> 3
180+ """
181+ custom_map = custom_map or {}
182+
183+ def tolerance_for (errno : int ) -> int :
184+ if errno in custom_map :
185+ return custom_map [errno ]
186+ return 1 if errno == 1 else 3
187+
188+ return {errno : tolerance_for (errno ) for errno in sorted (errno2count .keys ())}
189+
190+
161191def calculate_gamma (
162192 tolerance : int ,
163- get_pi : Callable [[int ], float ],
164- errno_tolerances : list [ int ],
193+ pi_value4errno : Callable [[int ], float ],
194+ errno_as_tolerances : dict [ int , int ],
165195 b : float = 0.1 ,
166196) -> float :
167197 """
@@ -172,26 +202,24 @@ def calculate_gamma(
172202
173203 Args:
174204 tolerance: Tolerance level t
175- get_pi: Function that takes error type index c and returns π_c (proportion of error type c)
176- errno_tolerances: List of tolerance thresholds for each error type.
177- Index corresponds to error type index c, value is the threshold.
178- An error type is tolerated (not penalized) when t >= threshold.
205+ pi_value4errno: Function that takes errno and returns π_c (proportion of error type c).
206+ errno_as_tolerances: Mapping of errno to tolerance thresholds.
207+ An error type is tolerated (not penalized) when t >= threshold for that errno.
179208 b: Base penalty for severe errors (default: 0.1)
180209
181210 Returns:
182211 Gamma value (average error penalty)
183212 """
184- if len ( errno_tolerances ) = = 0 :
213+ if tolerance < = 0 :
185214 return b
186215
187- # Calculate indicator for each error type: 1 if not tolerated, 0 if tolerated
188- pi_sum = 0.0
189- for error_type_index in range (len (errno_tolerances )):
190- pi_c = get_pi (error_type_index )
191- threshold_c = errno_tolerances [error_type_index ]
192- # Error type is not tolerated (penalized) when t < threshold
193- indicator = 1 if tolerance < threshold_c else 0
194- pi_sum += pi_c * indicator
216+ # Calculate indicator-weighted pi sum for errnos that are not tolerated
217+ pi_sum = sum (
218+ pi_value
219+ for errno , errno_tolerance in errno_as_tolerances .items ()
220+ for pi_value in [pi_value4errno (errno )]
221+ if tolerance < errno_tolerance
222+ )
195223
196224 return b ** pi_sum
197225
@@ -202,7 +230,7 @@ def calculate_s_t_from_aggregated(
202230 lambda_ : float ,
203231 eta : float ,
204232 negative_speedup_penalty : float ,
205- fpdb : float ,
233+ b : float ,
206234) -> float :
207235 """
208236 Calculate S(t) from aggregated parameters.
@@ -215,15 +243,15 @@ def calculate_s_t_from_aggregated(
215243 lambda_: Fraction of correct samples
216244 eta: Fraction of slowdown cases within correct samples
217245 negative_speedup_penalty: Penalty power p for negative speedup
218- fpdb : Base penalty b for severe errors
246+ b : Base penalty for severe errors or accuracy violation
219247
220248 Returns:
221249 S(t) value calculated from aggregated parameters
222250 """
223251 return (
224252 alpha ** lambda_
225253 * beta ** (lambda_ * eta * negative_speedup_penalty )
226- * fpdb ** (1 - lambda_ )
254+ * b ** (1 - lambda_ )
227255 )
228256
229257
@@ -258,85 +286,60 @@ def calculate_es_t_from_aggregated(
258286 )
259287
260288
261- def calculate_all_aggregated_parameters (
289+ def calculate_es_components_values (
262290 total_samples : int ,
263291 correct_speedups : list [float ],
264292 errno2count : dict [int , int ],
265- t_key : int ,
293+ tolerance : int ,
266294 negative_speedup_penalty : float = 0.0 ,
267- fpdb : float = 0.1 ,
295+ b : float = 0.1 ,
268296 pi : dict [int , float ] | None = None ,
269- errno_tolerance_thresholds : dict [int , int ] | None = None ,
297+ errno_as_tolerance : dict [int , int ] | None = None ,
270298) -> dict :
271299 """
272- Calculate all aggregated parameters for a given tolerance level.
273-
274- This is a convenience function that calculates all aggregated parameters at once.
300+ Calculate aggregated parameters for a given tolerance level.
275301
276302 Args:
277303 total_samples: Total number of samples
278304 correct_speedups: List of speedup values for correct samples
279305 errno2count: Dictionary mapping errno (error number) to their counts.
280306 Errno values: 1=accuracy, 2=runtime, 3=compilation.
281- t_key : Tolerance level
307+ tolerance : Tolerance level
282308 negative_speedup_penalty: Penalty power p for negative speedup
283- fpdb : Base penalty b for severe errors
309+ b : Base penalty for severe errors or accuracy violation
284310 pi: Dictionary mapping errno to their proportions (calculated at t=1).
285311 If None, will be calculated from errno2count.
286- errno_tolerance_thresholds: Dictionary mapping errno to their tolerance thresholds .
287- An error type is tolerated (not penalized) when t >= threshold .
288- If None, uses default thresholds: {1: 1} for accuracy errors (errno=1) , {2: 3, 3: 3} for others.
312+ errno_as_tolerance: Mapping from errno to its minimum tolerated tolerance .
313+ An error type is tolerated (not penalized) when tolerance >= its value .
314+ If None, defaults to {1: 1} for accuracy, {2: 3, 3: 3} for others.
289315
290316 Returns:
291- Dictionary containing all aggregated parameters and calculated scores :
317+ Dictionary containing ES(t) component values :
292318 {
293319 'alpha': float,
294320 'beta': float,
295321 'lambda': float,
296322 'eta': float,
297323 'gamma': float,
298- 'pi': dict[int, float],
299- 's_t': float,
300- 'es_t': float
324+ 'pi': dict[int, float]
301325 }
302326 """
303- # Use default error tolerance thresholds if not provided
304- if errno_tolerance_thresholds is None :
305- errno_tolerance_thresholds = {}
306- for errno in errno2count .keys ():
307- if errno == 1 : # accuracy errors
308- errno_tolerance_thresholds [errno ] = 1
309- else : # runtime (2) or compilation (3) errors
310- errno_tolerance_thresholds [errno ] = 3
311-
312327 # Calculate pi if not provided
313328 if pi is None :
314329 pi = calculate_pi (errno2count , total_samples , correct_speedups )
315330
316- # Convert dictionary-based pi and thresholds to indexed format for calculate_gamma
317- # Create ordered list of errnos for consistent indexing (sorted by errno)
318- errnos = sorted (errno2count .keys ())
319- errno_tolerances = [errno_tolerance_thresholds .get (errno , 3 ) for errno in errnos ]
331+ # Prepare errno-ordered tolerance mapping for calculate_gamma
332+ errno_as_tolerances = resolve_errno_tolerance (errno2count , errno_as_tolerance )
320333
321- # Create get_pi function that maps error type index to pi value
322- def get_pi (error_type_index : int ) -> float :
323- if error_type_index < len (errnos ):
324- errno = errnos [error_type_index ]
325- return pi .get (errno , 0.0 )
326- return 0.0
334+ # Create pi_value4errno function that maps errno to pi value
335+ def pi_value4errno (errno : int ) -> float :
336+ return pi .get (errno , 0.0 )
327337
328338 alpha = calculate_alpha (correct_speedups )
329339 beta = calculate_beta (correct_speedups )
330340 lambda_ = calculate_lambda (correct_speedups , total_samples )
331341 eta = calculate_eta (correct_speedups )
332- gamma = calculate_gamma (t_key , get_pi , errno_tolerances , fpdb )
333-
334- s_t = calculate_s_t_from_aggregated (
335- alpha , beta , lambda_ , eta , negative_speedup_penalty , fpdb
336- )
337- es_t = calculate_es_t_from_aggregated (
338- alpha , beta , lambda_ , eta , gamma , negative_speedup_penalty
339- )
342+ gamma = calculate_gamma (tolerance , pi_value4errno , errno_as_tolerances , b )
340343
341344 return {
342345 "alpha" : alpha ,
@@ -345,6 +348,4 @@ def get_pi(error_type_index: int) -> float:
345348 "eta" : eta ,
346349 "gamma" : gamma ,
347350 "pi" : pi ,
348- "s_t" : s_t ,
349- "es_t" : es_t ,
350351 }
0 commit comments