
Commit b4de398

feat: implement ES(t) macro/micro cross-validation and refactor analysis utilities
This commit implements the Error-aware Speedup Score (ES_t) metric from Section 3.2.2 of the technical report (arXiv:2510.24035), along with the mathematical proofs from Appendix B and C that establish the sample-level validity of both the S_t and ES_t metrics.

Key Features:
=============
1. Appendix B implementation - sample-level proof for S_t:
   - Micro-level calculation: geometric mean of rectified speedups over all samples
   - Macro-level calculation: S_t = α^λ · β^(ληp) · b^(1-λ)
   - Cross-validation: both methods produce identical results, proving that S_t equals the geometric mean of sample-level rectified speedups

2. Appendix C implementation - sample-level proof for ES_t:
   - Micro-level calculation: geometric mean of error-aware rectified speedups
   - Macro-level calculation: ES_t = α^λ · β^(ληp) · γ_t^(1-λ)
   - Dynamic penalty factor: γ_t = b^(sum(π_c · indicator(t < c)))
   - Cross-validation: confirms that ES_t equals the geometric mean of error-aware rectified speedups, where failed samples receive type-specific dynamic penalties instead of the fixed penalty b

3. Error-aware design (Section 3.2.2):
   - Error type classification: c=1 (accuracy), c=2 (runtime crash), c=3 (compile failure)
   - Tiered tolerance rules: t>=1 tolerates accuracy errors, t>=2 tolerates runtime crashes, t>=3 tolerates all errors
   - The dynamic penalty γ_t adapts to the error-type distribution and the tolerance level

4. Independent verification script:
   - verify_macro_params.py: calculates and prints all macro parameters (alpha, beta, gamma, lambda, eta, pi) independently
   - Enables validation of plot_ESt results by computing each parameter separately

5. Mandatory validation mechanism:
   - plot_ESt.py: enforces macro/micro result matching before results are adopted
   - Rejects results if validation fails, ensuring calculation correctness

6. Code refactoring for maintainability:
   - macro_statistics.py: dedicated module for macro parameter calculations
   - Each parameter has its own function (alpha, beta, gamma, lambda, eta, pi)
   - Reduced nesting in analysis_util.py by extracting helper functions
   - Simplified scan_all_folders and added .txt file support
   - Improved code organization following software engineering best practices

Technical Details:
==================
- Micro calculation: processes each sample individually, applies the rectified-speedup rules, then computes the geometric mean
- Macro calculation: uses aggregated statistics (correct count, speedup distributions, error-type proportions) to compute expected values
- Validation: compares micro and macro results against a tolerance threshold (1e-6)
- All calculations verified against real benchmark data (118 samples)

Files Changed:
==============
- graph_net/analysis_util.py: refactored with helper functions, integrated the macro_statistics module, reduced nesting, simplified scan_all_folders
- graph_net/macro_statistics.py: new module for macro parameter calculations
- graph_net/plot_ESt.py: added mandatory macro/micro validation
- graph_net/verify_macro_params.py: new independent verification script

All code passes pre-commit checks, compiles successfully, and has been validated with real benchmark data.
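As a quick illustration of the macro/micro cross-validation described above, the following self-contained sketch reproduces the equivalence on made-up data. The sample list, the rectify helper, and all variable names are hypothetical and are not the API of plot_ESt.py, verify_macro_params.py, or macro_statistics.py; only the formulas ES_t = α^λ · β^(ληp) · γ_t^(1-λ) and γ_t = b^(sum(π_c · indicator(t < c))), and the 1e-6 threshold, are taken from the commit message.

from collections import Counter

from scipy.stats import gmean

b = 0.1  # base failure penalty (fpdb)
p = 1.0  # negative-speedup penalty power
t = 1    # tolerance level: at t=1, accuracy errors (c=1) are tolerated

# Hypothetical per-sample results: (speedup, error_type); error_type is None for a
# correct sample, 1 = accuracy error, 2 = runtime crash, 3 = compile failure.
samples = [(1.8, None), (0.9, None), (1.2, None), (None, 1), (None, 3)]

# Micro level: geometric mean of error-aware rectified speedups.
def rectify(speedup, err):
    if err is not None:
        # Failed sample: penalized with b only if the error is not tolerated at t.
        return b if t < err else 1.0
    # Correct sample: slowdowns (speedup < 1) get the extra penalty power p.
    return speedup ** (p + 1) if speedup < 1 else speedup

micro_es_t = gmean([rectify(s, c) for s, c in samples])

# Macro level: ES_t = alpha^lambda * beta^(lambda*eta*p) * gamma_t^(1-lambda).
correct = [s for s, c in samples if c is None]
slow = [s for s in correct if s < 1]
failures = [c for _, c in samples if c is not None]
lam = len(correct) / len(samples)   # lambda: fraction of correct samples
eta = len(slow) / len(correct)      # eta: fraction of correct samples with speedup < 1
alpha = gmean(correct)              # geometric mean speedup of correct samples
beta = gmean(slow)                  # geometric mean speedup of slowdown samples
counts = Counter(failures)
pi = {c: counts.get(c, 0) / len(failures) for c in (1, 2, 3)}  # error-type proportions
gamma_t = b ** sum(pi[c] for c in (1, 2, 3) if t < c)          # dynamic penalty
macro_es_t = alpha ** lam * beta ** (lam * eta * p) * gamma_t ** (1 - lam)

# Cross-validation: macro and micro must agree (the commit uses a 1e-6 threshold).
assert abs(micro_es_t - macro_es_t) < 1e-6

Both paths compute the same quantity, so the assertion holds up to floating-point error; plot_ESt.py applies the same kind of check before adopting a result.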
1 parent 9849633 commit b4de398

4 files changed: +721 additions, -86 deletions


graph_net/analysis_util.py

Lines changed: 168 additions & 84 deletions
@@ -4,7 +4,9 @@
 import numpy as np
 from scipy.stats import gmean
 from collections import OrderedDict, defaultdict
+from typing import Tuple
 from graph_net.config.datatype_tolerance_config import get_precision
+from graph_net import macro_statistics


 def extract_speedup_data_from_subdirs(benchmark_path: str) -> dict:
@@ -414,6 +416,114 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> b
     return False


+def check_sample_correctness(sample: dict, t_key: int) -> Tuple[bool, str]:
+    """
+    Check if a sample is correct at the given tolerance level.
+
+    Args:
+        sample: Sample data dictionary
+        t_key: Tolerance level
+
+    Returns:
+        Tuple of (is_correct, fail_type)
+        - is_correct: True if sample is correct at this tolerance
+        - fail_type: Error type if not correct, None if correct
+    """
+    performance_data = sample.get("performance", {})
+    fail_type = performance_data.get("failure")
+
+    # If there's already a failure type, return it
+    if fail_type is not None:
+        return False, fail_type
+
+    # Check correctness based on datatype and tolerance
+    datatype_data = performance_data.get("datatype", {})
+    eager_dtypes = datatype_data.get("eager", [])
+    compiled_dtypes = datatype_data.get("compiled", [])
+
+    # Check if datatypes match and are valid
+    if not (eager_dtypes and eager_dtypes == compiled_dtypes and len(eager_dtypes) > 0):
+        return False, "accuracy"
+
+    correctness_data = sample.get("correctness", {})
+    output_count = len(correctness_data.get("[equal]", []))
+
+    if len(eager_dtypes) != output_count:
+        return False, "accuracy"
+
+    # Check all outputs for correctness
+    is_correct = all(
+        get_correctness(eager_dtypes[i], t_key, correctness_data, i)
+        for i in range(output_count)
+    )
+
+    return is_correct, None if is_correct else "accuracy"
+
+
+def calculate_rectified_speedup(
+    speedup: float, fail_type: str, negative_speedup_penalty: float, fpdb: float
+) -> float:
+    """
+    Calculate rectified speedup for S(t) calculation.
+
+    Args:
+        speedup: Original speedup value
+        fail_type: Error type or None if correct
+        negative_speedup_penalty: Penalty power p for negative speedup
+        fpdb: Base penalty for failures
+
+    Returns:
+        Rectified speedup value
+    """
+    if fail_type is not None or speedup is None:
+        return fpdb
+
+    if speedup < 1:
+        return speedup ** (negative_speedup_penalty + 1)
+    return speedup
+
+
+def calculate_es_rectified_speedup(
+    speedup: float,
+    fail_type: str,
+    t_key: int,
+    is_correct_at_t1: bool,
+    speedup_at_t1: float,
+    fail_type_at_t1: str,
+    negative_speedup_penalty: float,
+    fpdb: float,
+) -> float:
+    """
+    Calculate rectified speedup for ES(t) calculation.
+
+    Args:
+        speedup: Current speedup value
+        fail_type: Current error type
+        t_key: Current tolerance level
+        is_correct_at_t1: Whether sample was correct at t=1
+        speedup_at_t1: Speedup value at t=1
+        fail_type_at_t1: Error type at t=1
+        negative_speedup_penalty: Penalty power p
+        fpdb: Base penalty for failures
+
+    Returns:
+        Error-aware rectified speedup value
+    """
+    if t_key < 1:
+        # For t < 1, ES(t) = S(t)
+        return calculate_rectified_speedup(
+            speedup, fail_type, negative_speedup_penalty, fpdb
+        )
+
+    # For t >= 1, use frozen state from t=1
+    if not is_correct_at_t1 or speedup_at_t1 is None:
+        return fake_perf_degrad(t_key, fail_type_at_t1, fpdb)
+
+    if speedup_at_t1 < 1:
+        return speedup_at_t1 ** (negative_speedup_penalty + 1)
+    return speedup_at_t1
+
+
 def fake_perf_degrad(t, error_code, fpdb=0.1):
     """
     Calculate fake performance degradation based on tolerance t and error code.
@@ -445,6 +555,9 @@ def calculate_s_scores(
     """
     s_scores = OrderedDict()
     s_scores_fake_degrad = OrderedDict()
+    # Store macro-level calculation results for cross-validation
+    s_scores._macro_results = OrderedDict()
+    s_scores_fake_degrad._macro_results = OrderedDict()

     begin = -10
     end = 4
@@ -462,33 +575,34 @@ def print_stat_info(
         correct_speedups,
         slowdown_speedups,
     ):
+        """
+        Calculate and print macro statistics for a given tolerance level.
+
+        Uses the macro_statistics module for all parameter calculations.
+        """
         print(f" - Details for tolerance={t_key}:")
         if total_samples > 0:
-            alpha = gmean(correct_speedups) if correct_speedups else 1
-            beta = gmean(slowdown_speedups) if slowdown_speedups else 1
-            lambda_ = correct_count / total_samples if total_samples > 0 else 0
-            eta = (
-                correct_negative_speedup_count / correct_count
-                if correct_count > 0
-                else 0
-            )
-            indicator = [1 if t_key < 1 else 0, 1 if t_key < 3 else 0]
-            gamma = (
-                fpdb ** sum(pi[i] * indicator[i] for i in range(len(pi)))
-                if t_key >= 1
-                else fpdb
+            # Calculate all macro parameters using the dedicated module
+            macro_params = macro_statistics.calculate_all_macro_parameters(
+                correct_count=correct_count,
+                total_samples=total_samples,
+                correct_negative_speedup_count=correct_negative_speedup_count,
+                correct_speedups=correct_speedups,
+                slowdown_speedups=slowdown_speedups,
+                acc_failure_count=acc_failure_count,
+                t_key=t_key,
+                negative_speedup_penalty=negative_speedup_penalty,
+                fpdb=fpdb,
+                pi=pi,
             )

-            expected_s = (
-                alpha**lambda_
-                * beta ** (lambda_ * eta * negative_speedup_penalty)
-                * fpdb ** (1 - lambda_)
-            )
-            expected_es = (
-                alpha**lambda_
-                * beta ** (lambda_ * eta * negative_speedup_penalty)
-                * gamma ** (1 - lambda_)
-            )
+            alpha = macro_params["alpha"]
+            beta = macro_params["beta"]
+            lambda_ = macro_params["lambda"]
+            eta = macro_params["eta"]
+            gamma = macro_params["gamma"]
+            expected_s = macro_params["s_t"]
+            expected_es = macro_params["es_t"]

             print(
                 f" - alpha: {alpha:.3f} (Geometric mean speedup of correct samples)"
@@ -501,11 +615,14 @@ def print_stat_info(
             )
         else:
             print(" - No samples to analyze.")
+            expected_s = fpdb
+            expected_es = fpdb

         return expected_s, expected_es

-    # pi is a list of constants for t > 0 for each group
-    pi = [0, 0]
+    # pi is a tuple of constants for t > 0 for each group: (pi[0], pi[1])
+    # Calculated at t=1, used for all t >= 1
+    pi = (0.0, 0.0)

     is_correct_at_t1 = [False] * total_samples
     speedup_at_t1 = [None] * total_samples
@@ -525,31 +642,13 @@ def print_stat_info(
         correct_speedups = []
         slowdown_speedups = []

+        # Process all samples using helper functions to reduce nesting
         for idx, sample in enumerate(samples):
             performance_data = sample.get("performance", {})
-            fail_type = performance_data.get("failure")
             speedup = performance_data.get("speedup", {}).get("e2e")

-            # Determine the true state of the current sample (for statistics and S curve)
-            is_correct = False
-            if fail_type is None:
-                datatype_data = performance_data.get("datatype", {})
-                eager_dtypes = datatype_data.get("eager", [])
-                compiled_dtypes = datatype_data.get("compiled", [])
-                if (
-                    eager_dtypes
-                    and eager_dtypes == compiled_dtypes
-                    and len(eager_dtypes) > 0
-                ):
-                    correctness_data = sample.get("correctness", {})
-                    output_count = len(correctness_data.get("[equal]", []))
-                    if len(eager_dtypes) == output_count:
-                        is_correct = all(
-                            get_correctness(eager_dtypes[i], t_key, correctness_data, i)
-                            for i in range(output_count)
-                        )
-            if not is_correct:
-                fail_type = "accuracy"
+            # Check correctness using dedicated function
+            is_correct, fail_type = check_sample_correctness(sample, t_key)

             # Collect statistics
             if is_correct:
@@ -563,53 +662,35 @@ def print_stat_info(
             if fail_type == "accuracy":
                 acc_failure_count += 1

+            # Store state at t=1 for ES(t) calculation
             if t_key == 1:
                 is_correct_at_t1[idx] = is_correct
                 speedup_at_t1[idx] = speedup
                 fail_type_at_t1[idx] = fail_type if fail_type is not None else "CORRECT"

-            # S(t) calculation
-            if fail_type is not None or speedup is None:
-                regularized_speedup = fpdb
-            else:
-                regularized_speedup = (
-                    speedup ** (negative_speedup_penalty + 1)
-                    if speedup < 1
-                    else speedup
-                )
+            # Calculate rectified speedups using dedicated functions
+            regularized_speedup = calculate_rectified_speedup(
+                speedup, fail_type, negative_speedup_penalty, fpdb
+            )
             rectified_speedups.append(regularized_speedup)

-            # ES(t) calculation: based on state change
-            if t_key < 1:
-                if fail_type is not None or speedup is None:
-                    rec_speedup_fake_degrad = fpdb
-                else:
-                    rec_speedup_fake_degrad = (
-                        speedup ** (negative_speedup_penalty + 1)
-                        if speedup < 1
-                        else speedup
-                    )
-            else:
-                if not is_correct_at_t1[idx] or speedup_at_t1[idx] is None:
-                    fail_type_frozen = fail_type_at_t1[idx]
-                    rec_speedup_fake_degrad = fake_perf_degrad(
-                        t_key, fail_type_frozen, fpdb
-                    )
-                else:
-                    rec_speedup_fake_degrad = (
-                        speedup_at_t1[idx] ** (negative_speedup_penalty + 1)
-                        if speedup_at_t1[idx] < 1
-                        else speedup_at_t1[idx]
-                    )
+            rec_speedup_fake_degrad = calculate_es_rectified_speedup(
+                speedup,
+                fail_type,
+                t_key,
+                is_correct_at_t1[idx],
+                speedup_at_t1[idx],
+                fail_type_at_t1[idx],
+                negative_speedup_penalty,
+                fpdb,
+            )
             rectified_speedups_fake_degrad.append(rec_speedup_fake_degrad)

         if t_key == 1:
-            if total_samples == correct_count:
-                pi[0] = 0
-                pi[1] = 0
-            else:
-                pi[0] = acc_failure_count / (total_samples - correct_count)
-                pi[1] = 1 - pi[0]
+            # Calculate pi at t=1 using the dedicated function
+            pi = macro_statistics.calculate_pi(
+                acc_failure_count, total_samples, correct_count
+            )
             final_correct_count = correct_count
             final_correct_negative_speedup_count = correct_negative_speedup_count
             final_correct_speedups = correct_speedups
@@ -644,7 +725,10 @@ def print_stat_info(
         print(
             f" - S(t)={expected_s:.3f}, ES(t)={expected_es:.3f} for tolerance={t_key} from macro level."
         )
+        # Store macro results for cross-validation
+        s_scores._macro_results[t_key] = expected_s
+        s_scores_fake_degrad._macro_results[t_key] = expected_es

-    print(f" - pi: {pi}")
+    print(f" - pi: {list(pi)}")

     return s_scores, s_scores_fake_degrad
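Note: the hunk above calls macro_statistics.calculate_pi without showing its body (graph_net/macro_statistics.py is not included in this excerpt). Judging from the inline computation it replaces, a minimal sketch could look like the following; this is an assumption, not the module's actual code.

from typing import Tuple


def calculate_pi(
    acc_failure_count: int, total_samples: int, correct_count: int
) -> Tuple[float, float]:
    """Proportions of failure types among failed samples, computed at t=1.

    pi[0] is the share of accuracy failures; pi[1] is the share of all other
    failures (runtime crashes and compile failures grouped together), mirroring
    the removed inline code where pi[1] = 1 - pi[0].
    """
    failed_count = total_samples - correct_count
    if failed_count == 0:
        # No failures: both proportions are zero, as in the removed branch.
        return (0.0, 0.0)
    pi_0 = acc_failure_count / failed_count
    return (pi_0, 1.0 - pi_0)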

0 commit comments
