Skip to content

Commit a4aa31f

Browse files
committed
refactor: split tolerance report generation
1 parent 22339b3 commit a4aa31f

File tree

3 files changed

+348
-263
lines changed

3 files changed

+348
-263
lines changed

graph_net/analysis_util.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -582,13 +582,13 @@ def print_stat_info(
582582
print(f" - Details for tolerance={t_key}:")
583583
if total_samples > 0:
584584
# Calculate all aggregated parameters using the dedicated module
585-
aggregated_params = samples_statistics.calculate_all_aggregated_parameters(
585+
aggregated_params = samples_statistics.calculate_es_components_values(
586586
total_samples=total_samples,
587587
correct_speedups=correct_speedups,
588588
errno2count=errno2count,
589-
t_key=t_key,
589+
tolerance=t_key,
590590
negative_speedup_penalty=negative_speedup_penalty,
591-
fpdb=fpdb,
591+
b=fpdb,
592592
pi=pi,
593593
)
594594

@@ -597,8 +597,12 @@ def print_stat_info(
597597
lambda_ = aggregated_params["lambda"]
598598
eta = aggregated_params["eta"]
599599
gamma = aggregated_params["gamma"]
600-
expected_s = aggregated_params["s_t"]
601-
expected_es = aggregated_params["es_t"]
600+
expected_s = samples_statistics.calculate_s_t_from_aggregated(
601+
alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb
602+
)
603+
expected_es = samples_statistics.calculate_es_t_from_aggregated(
604+
alpha, beta, lambda_, eta, gamma, negative_speedup_penalty
605+
)
602606

603607
print(
604608
f" - alpha: {alpha:.3f} (Geometric mean speedup of correct samples)"

graph_net/samples_statistics.py

Lines changed: 64 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,10 @@ def calculate_pi(
149149
"""
150150
correct_count = len(correct_speedups)
151151
error_count = total_samples - correct_count
152+
counted_errors = sum(errno2count.values())
153+
assert (
154+
error_count == counted_errors
155+
), f"error_count mismatch: got {error_count}, but errno2count sums to {counted_errors}"
152156
if error_count == 0:
153157
return {errno: 0.0 for errno in errno2count.keys()}
154158

@@ -158,10 +162,36 @@ def calculate_pi(
158162
return pi
159163

160164

165+
def resolve_errno_tolerance(
166+
errno2count: dict[int, int], custom_map: dict[int, int] | None
167+
) -> dict[int, int]:
168+
"""
169+
Build a sorted errno -> tolerance map for downstream gamma calculation.
170+
171+
Args:
172+
errno2count: Observed errno occurrences in the dataset.
173+
custom_map: Optional overrides mapping an errno to the minimum tolerance level at which that errno is tolerated.
174+
175+
Returns:
176+
Ordered dict (by errno) mapping each errno seen in errno2count
177+
to the tolerance value where it becomes tolerated. Defaults to:
178+
- errno 1 (accuracy) -> 1
179+
- errno >=2 (runtime/compile) -> 3
180+
"""
181+
custom_map = custom_map or {}
182+
183+
def tolerance_for(errno: int) -> int:
184+
if errno in custom_map:
185+
return custom_map[errno]
186+
return 1 if errno == 1 else 3
187+
188+
return {errno: tolerance_for(errno) for errno in sorted(errno2count.keys())}
189+
190+
161191
def calculate_gamma(
162192
tolerance: int,
163-
get_pi: Callable[[int], float],
164-
errno_tolerances: list[int],
193+
pi_value4errno: Callable[[int], float],
194+
errno_as_tolerances: dict[int, int],
165195
b: float = 0.1,
166196
) -> float:
167197
"""
@@ -172,26 +202,24 @@ def calculate_gamma(
172202
173203
Args:
174204
tolerance: Tolerance level t
175-
get_pi: Function that takes error type index c and returns π_c (proportion of error type c)
176-
errno_tolerances: List of tolerance thresholds for each error type.
177-
Index corresponds to error type index c, value is the threshold.
178-
An error type is tolerated (not penalized) when t >= threshold.
205+
pi_value4errno: Function that takes errno and returns π_c (proportion of error type c).
206+
errno_as_tolerances: Mapping of errno to tolerance thresholds.
207+
An error type is tolerated (not penalized) when t >= threshold for that errno.
179208
b: Base penalty for severe errors (default: 0.1)
180209
181210
Returns:
182211
Gamma value (average error penalty)
183212
"""
184-
if len(errno_tolerances) == 0:
213+
if tolerance <= 0:
185214
return b
186215

187-
# Calculate indicator for each error type: 1 if not tolerated, 0 if tolerated
188-
pi_sum = 0.0
189-
for error_type_index in range(len(errno_tolerances)):
190-
pi_c = get_pi(error_type_index)
191-
threshold_c = errno_tolerances[error_type_index]
192-
# Error type is not tolerated (penalized) when t < threshold
193-
indicator = 1 if tolerance < threshold_c else 0
194-
pi_sum += pi_c * indicator
216+
# Calculate indicator-weighted pi sum for errnos that are not tolerated
217+
pi_sum = sum(
218+
pi_value
219+
for errno, errno_tolerance in errno_as_tolerances.items()
220+
for pi_value in [pi_value4errno(errno)]
221+
if tolerance < errno_tolerance
222+
)
195223

196224
return b**pi_sum
197225

@@ -202,7 +230,7 @@ def calculate_s_t_from_aggregated(
202230
lambda_: float,
203231
eta: float,
204232
negative_speedup_penalty: float,
205-
fpdb: float,
233+
b: float,
206234
) -> float:
207235
"""
208236
Calculate S(t) from aggregated parameters.
@@ -215,15 +243,15 @@ def calculate_s_t_from_aggregated(
215243
lambda_: Fraction of correct samples
216244
eta: Fraction of slowdown cases within correct samples
217245
negative_speedup_penalty: Penalty power p for negative speedup
218-
fpdb: Base penalty b for severe errors
246+
b: Base penalty for severe errors or accuracy violation
219247
220248
Returns:
221249
S(t) value calculated from aggregated parameters
222250
"""
223251
return (
224252
alpha**lambda_
225253
* beta ** (lambda_ * eta * negative_speedup_penalty)
226-
* fpdb ** (1 - lambda_)
254+
* b ** (1 - lambda_)
227255
)
228256

229257

@@ -258,85 +286,60 @@ def calculate_es_t_from_aggregated(
258286
)
259287

260288

261-
def calculate_all_aggregated_parameters(
289+
def calculate_es_components_values(
262290
total_samples: int,
263291
correct_speedups: list[float],
264292
errno2count: dict[int, int],
265-
t_key: int,
293+
tolerance: int,
266294
negative_speedup_penalty: float = 0.0,
267-
fpdb: float = 0.1,
295+
b: float = 0.1,
268296
pi: dict[int, float] | None = None,
269-
errno_tolerance_thresholds: dict[int, int] | None = None,
297+
errno_as_tolerance: dict[int, int] | None = None,
270298
) -> dict:
271299
"""
272-
Calculate all aggregated parameters for a given tolerance level.
273-
274-
This is a convenience function that calculates all aggregated parameters at once.
300+
Calculate aggregated parameters for a given tolerance level.
275301
276302
Args:
277303
total_samples: Total number of samples
278304
correct_speedups: List of speedup values for correct samples
279305
errno2count: Dictionary mapping errno (error number) to their counts.
280306
Errno values: 1=accuracy, 2=runtime, 3=compilation.
281-
t_key: Tolerance level
307+
tolerance: Tolerance level
282308
negative_speedup_penalty: Penalty power p for negative speedup
283-
fpdb: Base penalty b for severe errors
309+
b: Base penalty for severe errors or accuracy violation
284310
pi: Dictionary mapping errno to their proportions (calculated at t=1).
285311
If None, will be calculated from errno2count.
286-
errno_tolerance_thresholds: Dictionary mapping errno to their tolerance thresholds.
287-
An error type is tolerated (not penalized) when t >= threshold.
288-
If None, uses default thresholds: {1: 1} for accuracy errors (errno=1), {2: 3, 3: 3} for others.
312+
errno_as_tolerance: Mapping from errno to the minimum tolerance level at which that errno is tolerated.
313+
An error type is tolerated (not penalized) when tolerance >= its value.
314+
Errnos missing from the mapping (all of them if None) default to 1 for accuracy errors (errno 1) and 3 for every other errno.
289315
290316
Returns:
291-
Dictionary containing all aggregated parameters and calculated scores:
317+
Dictionary containing ES(t) component values:
292318
{
293319
'alpha': float,
294320
'beta': float,
295321
'lambda': float,
296322
'eta': float,
297323
'gamma': float,
298-
'pi': dict[int, float],
299-
's_t': float,
300-
'es_t': float
324+
'pi': dict[int, float]
301325
}
302326
"""
303-
# Use default error tolerance thresholds if not provided
304-
if errno_tolerance_thresholds is None:
305-
errno_tolerance_thresholds = {}
306-
for errno in errno2count.keys():
307-
if errno == 1: # accuracy errors
308-
errno_tolerance_thresholds[errno] = 1
309-
else: # runtime (2) or compilation (3) errors
310-
errno_tolerance_thresholds[errno] = 3
311-
312327
# Calculate pi if not provided
313328
if pi is None:
314329
pi = calculate_pi(errno2count, total_samples, correct_speedups)
315330

316-
# Convert dictionary-based pi and thresholds to indexed format for calculate_gamma
317-
# Create ordered list of errnos for consistent indexing (sorted by errno)
318-
errnos = sorted(errno2count.keys())
319-
errno_tolerances = [errno_tolerance_thresholds.get(errno, 3) for errno in errnos]
331+
# Prepare errno-ordered tolerance mapping for calculate_gamma
332+
errno_as_tolerances = resolve_errno_tolerance(errno2count, errno_as_tolerance)
320333

321-
# Create get_pi function that maps error type index to pi value
322-
def get_pi(error_type_index: int) -> float:
323-
if error_type_index < len(errnos):
324-
errno = errnos[error_type_index]
325-
return pi.get(errno, 0.0)
326-
return 0.0
334+
# Create pi_value4errno function that maps errno to pi value
335+
def pi_value4errno(errno: int) -> float:
336+
return pi.get(errno, 0.0)
327337

328338
alpha = calculate_alpha(correct_speedups)
329339
beta = calculate_beta(correct_speedups)
330340
lambda_ = calculate_lambda(correct_speedups, total_samples)
331341
eta = calculate_eta(correct_speedups)
332-
gamma = calculate_gamma(t_key, get_pi, errno_tolerances, fpdb)
333-
334-
s_t = calculate_s_t_from_aggregated(
335-
alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb
336-
)
337-
es_t = calculate_es_t_from_aggregated(
338-
alpha, beta, lambda_, eta, gamma, negative_speedup_penalty
339-
)
342+
gamma = calculate_gamma(tolerance, pi_value4errno, errno_as_tolerances, b)
340343

341344
return {
342345
"alpha": alpha,
@@ -345,6 +348,4 @@ def get_pi(error_type_index: int) -> float:
345348
"eta": eta,
346349
"gamma": gamma,
347350
"pi": pi,
348-
"s_t": s_t,
349-
"es_t": es_t,
350351
}

0 commit comments

Comments
 (0)