55from graph_net import analysis_util
66
77
8- def compare_single_tolerance_level (
9- tolerance_level : int ,
10- micro_es : float ,
8+ def es_result_checker (
9+ es_from_microscopic : float , es_from_macro : float , atol : float , rtol : float
10+ ) -> bool :
11+ """
12+ Check if ES(t) values from microscopic and macro calculations match.
13+
14+ Args:
15+ es_from_microscopic: ES(t) value from microscopic-level calculation
16+ es_from_macro: ES(t) value from aggregated-level calculation
17+ atol: Absolute tolerance for comparison
18+ rtol: Relative tolerance for comparison
19+
20+ Returns:
21+ True if values match within tolerance, False otherwise
22+ """
23+ diff = abs (es_from_microscopic - es_from_macro )
24+ return diff < atol or diff < rtol * max (
25+ abs (es_from_microscopic ), abs (es_from_macro ), 1e-10
26+ )
27+
28+
29+ def compare_aggregated_es_and_microscopic_es (
30+ tolerance : int ,
31+ microscopic_es : float ,
1132 aggregated_es : float | None ,
12- tolerance_threshold : float ,
33+ atol : float = 1e-3 ,
34+ rtol : float = 1e-3 ,
1335) -> tuple [bool , float , float ]:
1436 """
15- Compare micro and aggregated ES(t) values for a single tolerance level.
37+ Compare ES(t) values from aggregated and microscopic calculations at a tolerance level.
1638
1739 Args:
18- tolerance_level : Tolerance level t
19- micro_es : ES(t) value from micro -level calculation
40+ tolerance : Tolerance level t
41+ microscopic_es : ES(t) value from microscopic -level calculation
2042 aggregated_es: ES(t) value from aggregated-level calculation, or None if missing
21- tolerance_threshold: Floating point comparison tolerance
43+ atol: Absolute tolerance for comparison
44+ rtol: Relative tolerance for comparison
2245
2346 Returns:
2447 Tuple of (is_matched, diff, relative_diff)
2548 """
2649 if aggregated_es is None :
2750 return False , 0.0 , 0.0
2851
29- diff = abs (micro_es - aggregated_es )
30- relative_diff = diff / max (abs (micro_es ), abs (aggregated_es ), 1e-10 )
31- is_matched = diff < tolerance_threshold or relative_diff < tolerance_threshold
52+ diff = abs (microscopic_es - aggregated_es )
53+ relative_diff = diff / max (abs (microscopic_es ), abs (aggregated_es ), 1e-10 )
54+ is_matched = es_result_checker ( microscopic_es , aggregated_es , atol , rtol )
3255
3356 return is_matched , diff , relative_diff
3457
3558
3659def print_verification_result (
37- tolerance_level : int ,
38- micro_es : float ,
60+ tolerance : int ,
61+ microscopic_es : float ,
3962 aggregated_es : float | None ,
4063 diff : float ,
4164 relative_diff : float ,
4265 is_matched : bool ,
4366) -> None :
4467 """Print verification result for a single tolerance level."""
4568 if aggregated_es is None :
46- print (f"ERROR: No aggregated result for t={ tolerance_level } , cannot verify" )
69+ print (f"ERROR: No aggregated result for t={ tolerance } , cannot verify" )
4770 elif is_matched :
4871 print (
49- f"t={ tolerance_level :3d} : MATCHED - Micro : { micro_es :.6f} , Aggregated: { aggregated_es :.6f} , Diff: { diff :.2e} "
72+ f"t={ tolerance :3d} : MATCHED - Microscopic : { microscopic_es :.6f} , Aggregated: { aggregated_es :.6f} , Diff: { diff :.2e} "
5073 )
5174 else :
5275 print (
53- f"t={ tolerance_level :3d} : MISMATCH - Micro : { micro_es :.6f} , Aggregated: { aggregated_es :.6f} , Diff: { diff :.2e} ({ relative_diff * 100 :.4f} %)"
76+ f"t={ tolerance :3d} : MISMATCH - Microscopic : { microscopic_es :.6f} , Aggregated: { aggregated_es :.6f} , Diff: { diff :.2e} ({ relative_diff * 100 :.4f} %)"
5477 )
5578
5679
57- def verify_aggregated_micro_consistency (
58- es_scores : dict , folder_name : str , tolerance_threshold : float
59- ) -> tuple [dict , bool ]:
80+ def get_verified_aggregated_es_values (es_scores : dict , folder_name : str ) -> dict :
6081 """
61- Verify consistency between aggregated and micro -level ES(t) calculations.
82+ Get verified ES(t) values by checking consistency between aggregated and microscopic -level calculations.
6283
6384 Args:
64- es_scores: Dictionary of ES(t) scores from micro -level calculation
85+ es_scores: Dictionary of ES(t) scores from microscopic -level calculation
6586 folder_name: Name of the folder being verified
66- tolerance_threshold: Floating point comparison tolerance
6787
6888 Returns:
69- Tuple of (verified_scores, all_matched):
70- - verified_scores: Dictionary of verified scores (only matched tolerance levels)
71- - all_matched: True if all tolerance levels matched, False otherwise
89+ Dictionary of verified ES(t) values (only matched tolerance levels).
90+ Returns empty dict if validation fails.
7291 """
7392 aggregated_results = getattr (es_scores , "_aggregated_results" , {})
74- verified_scores = {}
93+ verified_es_values = {}
7594 all_matched = True
7695
7796 print (f"\n { '=' * 80 } " )
78- print (f"Verifying Aggregated/Micro Consistency for '{ folder_name } '" )
97+ print (f"Verifying Aggregated/Microscopic Consistency for '{ folder_name } '" )
7998 print (f"{ '=' * 80 } " )
8099
81- for tolerance_level , micro_es in es_scores .items ():
82- aggregated_es = aggregated_results .get (tolerance_level )
83- is_matched , diff , relative_diff = compare_single_tolerance_level (
84- tolerance_level , micro_es , aggregated_es , tolerance_threshold
100+ for tolerance , microscopic_es in es_scores .items ():
101+ aggregated_es = aggregated_results .get (tolerance )
102+ is_matched , diff , relative_diff = compare_aggregated_es_and_microscopic_es (
103+ tolerance , microscopic_es , aggregated_es
85104 )
86105
87106 print_verification_result (
88- tolerance_level , micro_es , aggregated_es , diff , relative_diff , is_matched
107+ tolerance ,
108+ microscopic_es ,
109+ aggregated_es ,
110+ diff ,
111+ relative_diff ,
112+ is_matched ,
89113 )
90114
91115 if aggregated_es is None or not is_matched :
92116 all_matched = False
93117 if is_matched :
94- verified_scores [ tolerance_level ] = micro_es
118+ verified_es_values [ tolerance ] = microscopic_es
95119
96120 if not all_matched :
97121 print (
98- f"\n ERROR: Aggregated and micro results do not match for '{ folder_name } '!"
122+ f"\n ERROR: Aggregated and microscopic results do not match for '{ folder_name } '!"
99123 )
100124 print ("Calculation validation failed. Results will NOT be used for plotting." )
101125 print ("Please verify the calculation logic using verify_aggregated_params.py" )
102126 print (f"{ '=' * 80 } \n " )
127+ return {}
103128 else :
104- print (f"\n SUCCESS: All aggregated and micro results match for '{ folder_name } '." )
129+ print (
130+ f"\n SUCCESS: All aggregated and microscopic results match for '{ folder_name } '."
131+ )
105132 print (f"{ '=' * 80 } \n " )
106-
107- return verified_scores , all_matched
133+ return verified_es_values
108134
109135
110136def plot_ES_results (s_scores : dict , cli_args : argparse .Namespace ):
@@ -232,6 +258,18 @@ def main():
232258 default = 0.1 ,
233259 help = "Base penalty for severe errors (e.g., crashes, correctness failures)." ,
234260 )
261+ parser .add_argument (
262+ "--enable-aggregation-mode" ,
263+ action = "store_true" ,
264+ help = "Enable aggregation mode to verify aggregated/microscopic consistency. Default: enabled." ,
265+ )
266+ parser .add_argument (
267+ "--disable-aggregation-mode" ,
268+ dest = "enable_aggregation_mode" ,
269+ action = "store_false" ,
270+ help = "Disable aggregation mode verification." ,
271+ )
272+ parser .set_defaults (enable_aggregation_mode = True )
235273 args = parser .parse_args ()
236274
237275 # 1. Scan folders to get data
@@ -240,9 +278,8 @@ def main():
240278 print ("No valid data found. Exiting." )
241279 return
242280
243- # 2. Calculate scores for each curve and verify aggregated/micro consistency
281+ # 2. Calculate scores for each curve and verify aggregated/microscopic consistency
244282 all_es_scores = {}
245- tolerance_threshold = 1e-6 # Tolerance for floating point comparison
246283
247284 for folder_name , samples in all_results .items ():
248285 _ , es_scores = analysis_util .calculate_s_scores (
@@ -255,15 +292,16 @@ def main():
255292 # Keep original behavior: assign es_scores directly
256293 all_es_scores [folder_name ] = es_scores
257294
258- # Verify aggregated/micro consistency
259- verified_scores , all_matched = verify_aggregated_micro_consistency (
260- es_scores , folder_name , tolerance_threshold
261- )
295+ # Verify aggregated/microscopic consistency if aggregation mode is enabled
296+ if args .enable_aggregation_mode :
297+ verified_es_values = get_verified_aggregated_es_values (
298+ es_scores , folder_name
299+ )
262300
263- if not all_matched :
264- continue # Skip this curve if validation fails
301+ if not verified_es_values :
302+ continue # Skip this curve if validation fails
265303
266- all_es_scores [folder_name ] = verified_scores
304+ all_es_scores [folder_name ] = verified_es_values
267305
268306 # 3. Plot the results
269307 if any (all_es_scores .values ()):
0 commit comments