22from collections import OrderedDict , Counter
33from graph_net import analysis_util
44from graph_net import samples_statistics
5+ from graph_net .positive_tolerance_interpretation import PositiveToleranceInterpretation
56from graph_net .samples_statistics import (
67 get_errno_from_error_type ,
78)
89
910
10- def determine_tolerances (samples : list ,interpretation_type = "default" ) -> range :
11+ def determine_tolerances (samples : list ,
12+ positive_tolerance_interpretation : PositiveToleranceInterpretation ,) -> range :
1113 """Determine tolerance range based on observed errno categories."""
1214 # Currently errno categories are 1=accuracy, 2=runtime, 3=compile.
1315 # Keep logic data-driven for future extension.
14- if interpretation_type == "default" :
15- default_errnos = {1 , 2 , 3 }
16- elif interpretation_type == "mismatch_extended" :
17- default_errnos = {1 , 2 , 3 , 4 }
18- return range (- 10 , len (default_errnos ) + 2 )
16+ mapping = positive_tolerance_interpretation .get_tolerance_mapping ()
1917
18+ if not mapping :
19+ max_errno = 3
20+ else :
21+ max_errno = max (mapping .keys ())
22+ return range (- 10 , max_errno + 2 )
2023
21- def extract_statistics_at_tolerance (samples : list , tolerance : int ,interpretation_type : str = "default" ) -> dict :
24+ def extract_statistics_at_tolerance (samples : list , tolerance : int ,positive_tolerance_interpretation : PositiveToleranceInterpretation ) -> dict :
2225 """Extract statistics for a given tolerance level."""
2326 sample_data = [
2427 (
@@ -42,7 +45,7 @@ def extract_statistics_at_tolerance(samples: list, tolerance: int,interpretation
4245
4346 errno2count = dict (
4447 Counter (
45- get_errno_from_error_type (fail_type ,interpretation_type )
48+ get_errno_from_error_type (fail_type ,positive_tolerance_interpretation )
4649 for _ , _ , _ , _ , fail_type in sample_data
4750 if fail_type is not None
4851 )
@@ -105,7 +108,7 @@ def calculate_es_constructor_params_for_tolerance(
105108 pi : dict ,
106109 negative_speedup_penalty : float ,
107110 fpdb : float ,
108- interpretation_type : str = "default" ,
111+ positive_tolerance_interpretation : PositiveToleranceInterpretation
109112) -> dict :
110113 """Calculate ES(t) constructor parameters (alpha, beta, gamma, lambda, eta) and final scores for a tolerance level."""
111114 aggregated_params = samples_statistics .calculate_es_components_values (
@@ -116,7 +119,7 @@ def calculate_es_constructor_params_for_tolerance(
116119 negative_speedup_penalty = negative_speedup_penalty ,
117120 b = fpdb ,
118121 pi = pi ,
119- interpretation_type = interpretation_type ,
122+ positive_tolerance_interpretation = positive_tolerance_interpretation ,
120123 )
121124
122125 alpha = aggregated_params ["alpha" ]
@@ -205,7 +208,7 @@ def __init__(
205208 total_samples : int ,
206209 negative_speedup_penalty : float ,
207210 fpdb : float ,
208- interpretation_type : str = "default" ,
211+ positive_tolerance_interpretation : PositiveToleranceInterpretation ,
209212 ):
210213 self .samples = samples
211214 self .total_samples = total_samples
@@ -218,10 +221,10 @@ def __init__(
218221 "slowdown_speedups" : [],
219222 "errno2count" : {},
220223 }
221- self .interpretation_type = interpretation_type
224+ self .positive_tolerance_interpretation = positive_tolerance_interpretation
222225
223226 def build_report (self , tolerance : int ) -> dict :
224- current_stats = extract_statistics_at_tolerance (self .samples , tolerance ,self .interpretation_type )
227+ current_stats = extract_statistics_at_tolerance (self .samples , tolerance ,self .positive_tolerance_interpretation )
225228
226229 if tolerance == 1 :
227230 self .pi = _freeze_statistics_at_tolerance (
@@ -247,7 +250,7 @@ def build_report(self, tolerance: int) -> dict:
247250 pi_for_calc ,
248251 self .negative_speedup_penalty ,
249252 self .fpdb ,
250- self .interpretation_type ,
253+ self .positive_tolerance_interpretation ,
251254 )
252255 # Use calculated pi from es_constructor_params for display and return
253256 calculated_pi = es_constructor_params .get ("pi" , self .pi )
@@ -289,9 +292,9 @@ def _empty_report(self, tolerance: int) -> dict:
289292def verify_es_constructor_params_across_tolerances (
290293 samples : list ,
291294 folder_name : str ,
295+ positive_tolerance_interpretation : PositiveToleranceInterpretation ,
292296 negative_speedup_penalty : float = 0 ,
293297 fpdb : float = 0.1 ,
294- interpretation_type : str = "default" ,
295298) -> dict :
296299 """
297300 Verify and print ES constructor parameters (alpha, beta, gamma, lambda, eta, pi) for each
@@ -307,13 +310,13 @@ def verify_es_constructor_params_across_tolerances(
307310 print (f"Verifying Aggregated Parameters for '{ folder_name } '" )
308311 print (f"{ '=' * 80 } " )
309312
310- tolerances = determine_tolerances (samples ,interpretation_type )
313+ tolerances = determine_tolerances (samples ,positive_tolerance_interpretation )
311314 builder = ToleranceReportBuilder (
312315 samples = samples ,
313316 total_samples = total_samples ,
314317 negative_speedup_penalty = negative_speedup_penalty ,
315318 fpdb = fpdb ,
316- interpretation_type = interpretation_type ,
319+ positive_tolerance_interpretation = positive_tolerance_interpretation ,
317320 )
318321
319322 results = OrderedDict (
0 commit comments