Commit 76d2b15
refactor: improve naming and semantics for ES calculation

- Rename verify_es_match_at_tolerance to compare_aggregated_es_and_microscopic_es
- Replace tolerance_level with tolerance parameter
- Replace tolerance_threshold with atol/rtol to avoid confusion
- Rename verify_aggregated_microscopic_consistency to get_verified_aggregated_es_values
- Change return type to dict only (remove all_matched)
- Rename verified_scores to verified_es_values
- Replace micro with microscopic throughout
- Rename check_sample_correctness to get_sample_correctness
- Rename t1 variables to first_errno_tolerance
- Rename es_components to es_constructor_params
- Rename calculate_parameters_for_tolerance to calculate_es_constructor_params_for_tolerance
- Rename custom_map to errno_tolerance_overrides
- Rename errno_as_tolerances to errno2tolerance
- Add enable_aggregation_mode command line option
1 parent a4aa31f · commit 76d2b15

4 files changed: +164 -109 lines
graph_net/analysis_util.py

Lines changed: 32 additions & 27 deletions
```diff
@@ -416,9 +416,9 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> b
     return False
 
 
-def check_sample_correctness(sample: dict, t_key: int) -> tuple[bool, str]:
+def get_sample_correctness(sample: dict, t_key: int) -> tuple[bool, str]:
     """
-    Check if a sample is correct at the given tolerance level.
+    Get sample correctness status at the given tolerance level.
 
     Args:
         sample: Sample data dictionary
@@ -487,9 +487,9 @@ def calculate_es_rectified_speedup(
     speedup: float,
     fail_type: str,
     t_key: int,
-    is_correct_at_t1: bool,
-    speedup_at_t1: float,
-    fail_type_at_t1: str,
+    is_correct_at_first_errno_tolerance: bool,
+    speedup_at_first_errno_tolerance: float,
+    fail_type_at_first_errno_tolerance: str,
     negative_speedup_penalty: float,
     fpdb: float,
 ) -> float:
@@ -500,9 +500,9 @@ def calculate_es_rectified_speedup(
         speedup: Current speedup value
         fail_type: Current error type
         t_key: Current tolerance level
-        is_correct_at_t1: Whether sample was correct at t=1
-        speedup_at_t1: Speedup value at t=1
-        fail_type_at_t1: Error type at t=1
+        is_correct_at_first_errno_tolerance: Whether sample was correct at first errno tolerance (t=1)
+        speedup_at_first_errno_tolerance: Speedup value at first errno tolerance (t=1)
+        fail_type_at_first_errno_tolerance: Error type at first errno tolerance (t=1)
         negative_speedup_penalty: Penalty power p
         fpdb: Base penalty for failures
 
@@ -515,13 +515,16 @@ def calculate_es_rectified_speedup(
             speedup, fail_type, negative_speedup_penalty, fpdb
         )
 
-    # For t >= 1, use frozen state from t=1
-    if not is_correct_at_t1 or speedup_at_t1 is None:
-        return fake_perf_degrad(t_key, fail_type_at_t1, fpdb)
+    # For t >= 1, use frozen state from first errno tolerance (t=1)
+    if (
+        not is_correct_at_first_errno_tolerance
+        or speedup_at_first_errno_tolerance is None
+    ):
+        return fake_perf_degrad(t_key, fail_type_at_first_errno_tolerance, fpdb)
 
-    if speedup_at_t1 < 1:
-        return speedup_at_t1 ** (negative_speedup_penalty + 1)
-    return speedup_at_t1
+    if speedup_at_first_errno_tolerance < 1:
+        return speedup_at_first_errno_tolerance ** (negative_speedup_penalty + 1)
+    return speedup_at_first_errno_tolerance
 
 
 def fake_perf_degrad(t, error_code, fpdb=0.1):
```
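A minimal, self-contained sketch of the frozen-state branch above, for reference; `fake_perf_degrad_stub` is a hypothetical stand-in, since the real `fake_perf_degrad` body is not shown in this diff:

```python
def fake_perf_degrad_stub(t: int, error_code: str, fpdb: float = 0.1) -> float:
    # Hypothetical stand-in: the real fake_perf_degrad is defined in
    # graph_net/analysis_util.py and its body is not part of this diff.
    return fpdb


def frozen_rectified_speedup(
    t_key: int,
    is_correct_at_first_errno_tolerance: bool,
    speedup_at_first_errno_tolerance: float | None,
    fail_type_at_first_errno_tolerance: str,
    negative_speedup_penalty: float,
    fpdb: float,
) -> float:
    # Mirrors the t >= 1 branch above: samples that were incorrect (or had
    # no measured speedup) at t=1 fall back to the failure penalty...
    if (
        not is_correct_at_first_errno_tolerance
        or speedup_at_first_errno_tolerance is None
    ):
        return fake_perf_degrad_stub(t_key, fail_type_at_first_errno_tolerance, fpdb)
    # ...while frozen slowdowns (< 1) are amplified by the penalty power.
    if speedup_at_first_errno_tolerance < 1:
        return speedup_at_first_errno_tolerance ** (negative_speedup_penalty + 1)
    return speedup_at_first_errno_tolerance


# A 0.8x slowdown frozen at t=1 with penalty p=1 yields 0.8 ** 2 = 0.64.
print(frozen_rectified_speedup(3, True, 0.8, "CORRECT", 1.0, 0.1))
```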
```diff
@@ -621,12 +624,12 @@ def print_stat_info(
         return expected_s, expected_es
 
     # pi is a tuple of constants for t > 0 for each group: (pi[0], pi[1])
-    # Calculated at t=1, used for all t >= 1
+    # Calculated at first errno tolerance (t=1), used for all t >= 1
     pi = (0.0, 0.0)
 
-    is_correct_at_t1 = [False] * total_samples
-    speedup_at_t1 = [None] * total_samples
-    fail_type_at_t1 = ["CORRECT"] * total_samples
+    is_correct_at_first_errno_tolerance = [False] * total_samples
+    speedup_at_first_errno_tolerance = [None] * total_samples
+    fail_type_at_first_errno_tolerance = ["CORRECT"] * total_samples
 
     final_correct_count = 0
     final_correct_negative_speedup_count = 0
@@ -646,8 +649,8 @@ def print_stat_info(
         performance_data = sample.get("performance", {})
         speedup = performance_data.get("speedup", {}).get("e2e")
 
-        # Check correctness using dedicated function
-        is_correct, fail_type = check_sample_correctness(sample, t_key)
+        # Get correctness using dedicated function
+        is_correct, fail_type = get_sample_correctness(sample, t_key)
 
         # Collect statistics
         if is_correct:
@@ -662,11 +665,13 @@ def print_stat_info(
             errno = get_errno_from_error_type(fail_type)
             errno2count[errno] = errno2count.get(errno, 0) + 1
 
-        # Store state at t=1 for ES(t) calculation
+        # Store state at first errno tolerance (t=1) for ES(t) calculation
         if t_key == 1:
-            is_correct_at_t1[idx] = is_correct
-            speedup_at_t1[idx] = speedup
-            fail_type_at_t1[idx] = fail_type if fail_type is not None else "CORRECT"
+            is_correct_at_first_errno_tolerance[idx] = is_correct
+            speedup_at_first_errno_tolerance[idx] = speedup
+            fail_type_at_first_errno_tolerance[idx] = (
+                fail_type if fail_type is not None else "CORRECT"
+            )
 
         # Calculate rectified speedups using dedicated functions
         regularized_speedup = calculate_rectified_speedup(
@@ -678,9 +683,9 @@ def print_stat_info(
             speedup,
             fail_type,
             t_key,
-            is_correct_at_t1[idx],
-            speedup_at_t1[idx],
-            fail_type_at_t1[idx],
+            is_correct_at_first_errno_tolerance[idx],
+            speedup_at_first_errno_tolerance[idx],
+            fail_type_at_first_errno_tolerance[idx],
             negative_speedup_penalty,
             fpdb,
         )
```

graph_net/plot_ESt.py

Lines changed: 84 additions & 46 deletions
```diff
@@ -5,106 +5,132 @@
 from graph_net import analysis_util
 
 
-def compare_single_tolerance_level(
-    tolerance_level: int,
-    micro_es: float,
+def es_result_checker(
+    es_from_microscopic: float, es_from_macro: float, atol: float, rtol: float
+) -> bool:
+    """
+    Check if ES(t) values from microscopic and macro calculations match.
+
+    Args:
+        es_from_microscopic: ES(t) value from microscopic-level calculation
+        es_from_macro: ES(t) value from aggregated-level calculation
+        atol: Absolute tolerance for comparison
+        rtol: Relative tolerance for comparison
+
+    Returns:
+        True if values match within tolerance, False otherwise
+    """
+    diff = abs(es_from_microscopic - es_from_macro)
+    return diff < atol or diff < rtol * max(
+        abs(es_from_microscopic), abs(es_from_macro), 1e-10
+    )
+
+
+def compare_aggregated_es_and_microscopic_es(
+    tolerance: int,
+    microscopic_es: float,
     aggregated_es: float | None,
-    tolerance_threshold: float,
+    atol: float = 1e-3,
+    rtol: float = 1e-3,
 ) -> tuple[bool, float, float]:
     """
-    Compare micro and aggregated ES(t) values for a single tolerance level.
+    Compare ES(t) values from aggregated and microscopic calculations at a tolerance level.
 
     Args:
-        tolerance_level: Tolerance level t
-        micro_es: ES(t) value from micro-level calculation
+        tolerance: Tolerance level t
+        microscopic_es: ES(t) value from microscopic-level calculation
         aggregated_es: ES(t) value from aggregated-level calculation, or None if missing
-        tolerance_threshold: Floating point comparison tolerance
+        atol: Absolute tolerance for comparison
+        rtol: Relative tolerance for comparison
 
     Returns:
         Tuple of (is_matched, diff, relative_diff)
     """
     if aggregated_es is None:
         return False, 0.0, 0.0
 
-    diff = abs(micro_es - aggregated_es)
-    relative_diff = diff / max(abs(micro_es), abs(aggregated_es), 1e-10)
-    is_matched = diff < tolerance_threshold or relative_diff < tolerance_threshold
+    diff = abs(microscopic_es - aggregated_es)
+    relative_diff = diff / max(abs(microscopic_es), abs(aggregated_es), 1e-10)
+    is_matched = es_result_checker(microscopic_es, aggregated_es, atol, rtol)
 
     return is_matched, diff, relative_diff
 
 
 def print_verification_result(
-    tolerance_level: int,
-    micro_es: float,
+    tolerance: int,
+    microscopic_es: float,
     aggregated_es: float | None,
     diff: float,
     relative_diff: float,
     is_matched: bool,
 ) -> None:
     """Print verification result for a single tolerance level."""
     if aggregated_es is None:
-        print(f"ERROR: No aggregated result for t={tolerance_level}, cannot verify")
+        print(f"ERROR: No aggregated result for t={tolerance}, cannot verify")
     elif is_matched:
         print(
-            f"t={tolerance_level:3d}: MATCHED - Micro: {micro_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e}"
+            f"t={tolerance:3d}: MATCHED - Microscopic: {microscopic_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e}"
         )
     else:
         print(
-            f"t={tolerance_level:3d}: MISMATCH - Micro: {micro_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e} ({relative_diff*100:.4f}%)"
+            f"t={tolerance:3d}: MISMATCH - Microscopic: {microscopic_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e} ({relative_diff*100:.4f}%)"
         )
 
 
-def verify_aggregated_micro_consistency(
-    es_scores: dict, folder_name: str, tolerance_threshold: float
-) -> tuple[dict, bool]:
+def get_verified_aggregated_es_values(es_scores: dict, folder_name: str) -> dict:
     """
-    Verify consistency between aggregated and micro-level ES(t) calculations.
+    Get verified ES(t) values by checking consistency between aggregated and microscopic-level calculations.
 
     Args:
-        es_scores: Dictionary of ES(t) scores from micro-level calculation
+        es_scores: Dictionary of ES(t) scores from microscopic-level calculation
         folder_name: Name of the folder being verified
-        tolerance_threshold: Floating point comparison tolerance
 
     Returns:
-        Tuple of (verified_scores, all_matched):
-        - verified_scores: Dictionary of verified scores (only matched tolerance levels)
-        - all_matched: True if all tolerance levels matched, False otherwise
+        Dictionary of verified ES(t) values (only matched tolerance levels).
+        Returns empty dict if validation fails.
     """
     aggregated_results = getattr(es_scores, "_aggregated_results", {})
-    verified_scores = {}
+    verified_es_values = {}
     all_matched = True
 
     print(f"\n{'='*80}")
-    print(f"Verifying Aggregated/Micro Consistency for '{folder_name}'")
+    print(f"Verifying Aggregated/Microscopic Consistency for '{folder_name}'")
     print(f"{'='*80}")
 
-    for tolerance_level, micro_es in es_scores.items():
-        aggregated_es = aggregated_results.get(tolerance_level)
-        is_matched, diff, relative_diff = compare_single_tolerance_level(
-            tolerance_level, micro_es, aggregated_es, tolerance_threshold
+    for tolerance, microscopic_es in es_scores.items():
+        aggregated_es = aggregated_results.get(tolerance)
+        is_matched, diff, relative_diff = compare_aggregated_es_and_microscopic_es(
+            tolerance, microscopic_es, aggregated_es
         )
 
         print_verification_result(
-            tolerance_level, micro_es, aggregated_es, diff, relative_diff, is_matched
+            tolerance,
+            microscopic_es,
+            aggregated_es,
+            diff,
+            relative_diff,
+            is_matched,
         )
 
         if aggregated_es is None or not is_matched:
             all_matched = False
         if is_matched:
-            verified_scores[tolerance_level] = micro_es
+            verified_es_values[tolerance] = microscopic_es
 
     if not all_matched:
         print(
-            f"\nERROR: Aggregated and micro results do not match for '{folder_name}'!"
+            f"\nERROR: Aggregated and microscopic results do not match for '{folder_name}'!"
         )
         print("Calculation validation failed. Results will NOT be used for plotting.")
         print("Please verify the calculation logic using verify_aggregated_params.py")
         print(f"{'='*80}\n")
+        return {}
     else:
-        print(f"\nSUCCESS: All aggregated and micro results match for '{folder_name}'.")
+        print(
+            f"\nSUCCESS: All aggregated and microscopic results match for '{folder_name}'."
+        )
         print(f"{'='*80}\n")
-
-    return verified_scores, all_matched
+    return verified_es_values
 
 
 def plot_ES_results(s_scores: dict, cli_args: argparse.Namespace):
```
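The new `es_result_checker` ORs an absolute test with a relative test floored at 1e-10, similar in spirit to `math.isclose` with both `abs_tol` and `rel_tol` set. A quick self-contained demonstration, with the body copied from the diff above:

```python
def es_result_checker(es_from_microscopic, es_from_macro, atol, rtol):
    # Body copied from the diff: match if the absolute difference is below
    # atol, or below rtol relative to the larger magnitude (floored at 1e-10).
    diff = abs(es_from_microscopic - es_from_macro)
    return diff < atol or diff < rtol * max(
        abs(es_from_microscopic), abs(es_from_macro), 1e-10
    )


print(es_result_checker(1.234567, 1.234568, atol=1e-3, rtol=1e-3))  # True (tiny diff)
print(es_result_checker(1.0, 1.1, atol=1e-3, rtol=1e-3))  # False (diff 0.1 fails both tests)
print(es_result_checker(0.0, 5e-4, atol=1e-3, rtol=1e-3))  # True (within atol near zero)
```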
```diff
@@ -232,6 +258,18 @@ def main():
         default=0.1,
         help="Base penalty for severe errors (e.g., crashes, correctness failures).",
     )
+    parser.add_argument(
+        "--enable-aggregation-mode",
+        action="store_true",
+        help="Enable aggregation mode to verify aggregated/microscopic consistency. Default: enabled.",
+    )
+    parser.add_argument(
+        "--disable-aggregation-mode",
+        dest="enable_aggregation_mode",
+        action="store_false",
+        help="Disable aggregation mode verification.",
+    )
+    parser.set_defaults(enable_aggregation_mode=True)
     args = parser.parse_args()
 
     # 1. Scan folders to get data
```
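The paired flags share one `enable_aggregation_mode` destination, and `parser.set_defaults` makes the mode opt-out rather than opt-in. A standalone sketch of the pattern (on Python 3.9+, `argparse.BooleanOptionalAction` would give the same `--flag/--no-flag` behavior):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-aggregation-mode", action="store_true")
parser.add_argument(
    "--disable-aggregation-mode",
    dest="enable_aggregation_mode",  # both flags write to one destination
    action="store_false",
)
parser.set_defaults(enable_aggregation_mode=True)  # opt-out, not opt-in

print(parser.parse_args([]).enable_aggregation_mode)  # True (default)
print(parser.parse_args(["--disable-aggregation-mode"]).enable_aggregation_mode)  # False
```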
```diff
@@ -240,9 +278,8 @@ def main():
         print("No valid data found. Exiting.")
         return
 
-    # 2. Calculate scores for each curve and verify aggregated/micro consistency
+    # 2. Calculate scores for each curve and verify aggregated/microscopic consistency
     all_es_scores = {}
-    tolerance_threshold = 1e-6  # Tolerance for floating point comparison
 
     for folder_name, samples in all_results.items():
         _, es_scores = analysis_util.calculate_s_scores(
@@ -255,15 +292,16 @@ def main():
         # Keep original behavior: assign es_scores directly
         all_es_scores[folder_name] = es_scores
 
-        # Verify aggregated/micro consistency
-        verified_scores, all_matched = verify_aggregated_micro_consistency(
-            es_scores, folder_name, tolerance_threshold
-        )
+        # Verify aggregated/microscopic consistency if aggregation mode is enabled
+        if args.enable_aggregation_mode:
+            verified_es_values = get_verified_aggregated_es_values(
+                es_scores, folder_name
+            )
 
-        if not all_matched:
-            continue  # Skip this curve if validation fails
+            if not verified_es_values:
+                continue  # Skip this curve if validation fails
 
-        all_es_scores[folder_name] = verified_scores
+            all_es_scores[folder_name] = verified_es_values
 
     # 3. Plot the results
     if any(all_es_scores.values()):
```
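With the `(verified_scores, all_matched)` tuple removed, the caller relies on dict truthiness: an empty result signals validation failure. A self-contained sketch of the calling pattern above, using a hypothetical stub in place of the real verifier:

```python
def get_verified_aggregated_es_values_stub(es_scores: dict, folder_name: str) -> dict:
    # Hypothetical stub: the real function (defined earlier in this diff)
    # prints a report and returns {} when any tolerance level mismatches.
    return {} if folder_name == "mismatched" else dict(es_scores)


all_es_scores = {}
for folder_name, es_scores in {"ok": {1: 1.5}, "mismatched": {1: 1.2}}.items():
    all_es_scores[folder_name] = es_scores  # keep original behavior
    verified = get_verified_aggregated_es_values_stub(es_scores, folder_name)
    if not verified:  # empty dict replaces the old all_matched=False flag
        continue  # skip overwriting with verified values
    all_es_scores[folder_name] = verified

print(all_es_scores)  # {'ok': {1: 1.5}, 'mismatched': {1: 1.2}}
```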
