Skip to content

Commit 6033716

Browse files
committed
refactor: rename macro to aggregated and improve code quality
This commit refactors the evaluation metrics calculation code with the following improvements:

1. Terminology refactoring: macro -> aggregated
   - Rename macro_statistics.py to samples_statistics.py
   - Rename verify_macro_params.py to verify_aggregated_params.py
   - Update all variable and function names accordingly

2. Code structure improvements
   - Extract verification logic in plot_ESt.py into separate functions
     * compare_single_tolerance_level (12 lines)
     * print_verification_result (1 line)
     * verify_aggregated_micro_consistency (28 lines, meets ≤30 line requirement)
   - Refactor verify_aggregated_params.py to use functional programming style
     * Replace structured loops with list comprehensions
     * Use Counter for error type counting
     * Reduce multiple traversals to single pass where possible

3. Reduce function parameter coupling
   - calculate_beta: derive slowdown_speedups internally from correct_speedups
   - calculate_lambda: derive correct_count internally from correct_speedups
   - calculate_eta: derive statistics internally from correct_speedups

4. Decouple error type handling
   - calculate_pi: accept error_type_counts (dict) instead of hardcoded types
   - calculate_gamma: accept generic parameters (tolerance, get_pi, errno_tolerances)
   - Support user-defined error codes instead of hardcoded error types

5. Code quality improvements
   - Use explicit len() checks instead of implicit boolean conversion
   - Use modern Python type hints (list/tuple instead of typing.List/Tuple)
   - Improve code readability and maintainability

All changes have been verified and pass pre-commit checks.
1 parent b4de398 commit 6033716

File tree

5 files changed

+511
-403
lines changed

5 files changed

+511
-403
lines changed

graph_net/analysis_util.py

Lines changed: 33 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
import numpy as np
55
from scipy.stats import gmean
66
from collections import OrderedDict, defaultdict
7-
from typing import Tuple
87
from graph_net.config.datatype_tolerance_config import get_precision
9-
from graph_net import macro_statistics
8+
from graph_net import samples_statistics
109

1110

1211
def extract_speedup_data_from_subdirs(benchmark_path: str) -> dict:
@@ -416,7 +415,7 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> b
416415
return False
417416

418417

419-
def check_sample_correctness(sample: dict, t_key: int) -> Tuple[bool, str]:
418+
def check_sample_correctness(sample: dict, t_key: int) -> tuple[bool, str]:
420419
"""
421420
Check if a sample is correct at the given tolerance level.
422421
@@ -555,9 +554,9 @@ def calculate_s_scores(
555554
"""
556555
s_scores = OrderedDict()
557556
s_scores_fake_degrad = OrderedDict()
558-
# Store macro-level calculation results for cross-validation
559-
s_scores._macro_results = OrderedDict()
560-
s_scores_fake_degrad._macro_results = OrderedDict()
557+
# Store aggregated-level calculation results for cross-validation
558+
s_scores._aggregated_results = OrderedDict()
559+
s_scores_fake_degrad._aggregated_results = OrderedDict()
561560

562561
begin = -10
563562
end = 4
@@ -569,40 +568,36 @@ def calculate_s_scores(
569568
def print_stat_info(
570569
t_key,
571570
correct_count,
572-
acc_failure_count,
571+
error_type_counts,
573572
pi,
574573
correct_negative_speedup_count,
575574
correct_speedups,
576-
slowdown_speedups,
577575
):
578576
"""
579-
Calculate and print macro statistics for a given tolerance level.
577+
Calculate and print aggregated statistics for a given tolerance level.
580578
581-
Uses the macro_statistics module for all parameter calculations.
579+
Uses the samples_statistics module for all parameter calculations.
582580
"""
583581
print(f" - Details for tolerance={t_key}:")
584582
if total_samples > 0:
585-
# Calculate all macro parameters using the dedicated module
586-
macro_params = macro_statistics.calculate_all_macro_parameters(
587-
correct_count=correct_count,
583+
# Calculate all aggregated parameters using the dedicated module
584+
aggregated_params = samples_statistics.calculate_all_aggregated_parameters(
588585
total_samples=total_samples,
589-
correct_negative_speedup_count=correct_negative_speedup_count,
590586
correct_speedups=correct_speedups,
591-
slowdown_speedups=slowdown_speedups,
592-
acc_failure_count=acc_failure_count,
587+
error_type_counts=error_type_counts,
593588
t_key=t_key,
594589
negative_speedup_penalty=negative_speedup_penalty,
595590
fpdb=fpdb,
596591
pi=pi,
597592
)
598593

599-
alpha = macro_params["alpha"]
600-
beta = macro_params["beta"]
601-
lambda_ = macro_params["lambda"]
602-
eta = macro_params["eta"]
603-
gamma = macro_params["gamma"]
604-
expected_s = macro_params["s_t"]
605-
expected_es = macro_params["es_t"]
594+
alpha = aggregated_params["alpha"]
595+
beta = aggregated_params["beta"]
596+
lambda_ = aggregated_params["lambda"]
597+
eta = aggregated_params["eta"]
598+
gamma = aggregated_params["gamma"]
599+
expected_s = aggregated_params["s_t"]
600+
expected_es = aggregated_params["es_t"]
606601

607602
print(
608603
f" - alpha: {alpha:.3f} (Geometric mean speedup of correct samples)"
@@ -631,16 +626,15 @@ def print_stat_info(
631626
final_correct_count = 0
632627
final_correct_negative_speedup_count = 0
633628
final_correct_speedups = []
634-
final_slowdown_speedups = []
629+
final_error_type_counts = {} # Store error type counts at t=1
635630

636631
for t_key in t_keys:
637632
rectified_speedups = []
638633
rectified_speedups_fake_degrad = []
639634
correct_count = 0
640-
acc_failure_count = 0
635+
error_type_counts = {} # Dictionary to count errors by type
641636
correct_negative_speedup_count = 0
642637
correct_speedups = []
643-
slowdown_speedups = []
644638

645639
# Process all samples using helper functions to reduce nesting
646640
for idx, sample in enumerate(samples):
@@ -657,10 +651,10 @@ def print_stat_info(
657651
correct_speedups.append(speedup)
658652
if speedup is not None and speedup < 1:
659653
correct_negative_speedup_count += 1
660-
slowdown_speedups.append(speedup)
661654

662-
if fail_type == "accuracy":
663-
acc_failure_count += 1
655+
# Count errors by type
656+
if fail_type is not None:
657+
error_type_counts[fail_type] = error_type_counts.get(fail_type, 0) + 1
664658

665659
# Store state at t=1 for ES(t) calculation
666660
if t_key == 1:
@@ -688,13 +682,13 @@ def print_stat_info(
688682

689683
if t_key == 1:
690684
# Calculate pi at t=1 using the dedicated function
691-
pi = macro_statistics.calculate_pi(
692-
acc_failure_count, total_samples, correct_count
685+
pi = samples_statistics.calculate_pi(
686+
error_type_counts, total_samples, correct_speedups
693687
)
694688
final_correct_count = correct_count
695689
final_correct_negative_speedup_count = correct_negative_speedup_count
696690
final_correct_speedups = correct_speedups
697-
final_slowdown_speedups = slowdown_speedups
691+
final_error_type_counts = error_type_counts.copy() # Save for t >= 1
698692

699693
if rectified_speedups:
700694
s_scores[t_key] = gmean(rectified_speedups)
@@ -706,28 +700,27 @@ def print_stat_info(
706700
expected_s, expected_es = print_stat_info(
707701
t_key,
708702
correct_count,
709-
acc_failure_count,
703+
error_type_counts,
710704
pi,
711705
correct_negative_speedup_count,
712706
correct_speedups,
713-
slowdown_speedups,
714707
)
715708
else:
709+
# For t >= 1, use error_type_counts from t=1 (frozen state)
716710
expected_s, expected_es = print_stat_info(
717711
t_key,
718712
final_correct_count,
719-
acc_failure_count,
713+
final_error_type_counts, # Use the frozen error_type_counts from t=1
720714
pi,
721715
final_correct_negative_speedup_count,
722716
final_correct_speedups,
723-
final_slowdown_speedups,
724717
)
725718
print(
726-
f" - S(t)={expected_s:.3f}, ES(t)={expected_es:.3f} for tolerance={t_key} from macro level."
719+
f" - S(t)={expected_s:.3f}, ES(t)={expected_es:.3f} for tolerance={t_key} from aggregated level."
727720
)
728-
# Store macro results for cross-validation
729-
s_scores._macro_results[t_key] = expected_s
730-
s_scores_fake_degrad._macro_results[t_key] = expected_es
721+
# Store aggregated results for cross-validation
722+
s_scores._aggregated_results[t_key] = expected_s
723+
s_scores_fake_degrad._aggregated_results[t_key] = expected_es
731724

732725
print(f" - pi: {list(pi)}")
733726

0 commit comments

Comments (0)