Skip to content

Commit 22339b3

Browse files
committed
refactor: unify error type to errno mapping for better sorting
- Replace error_type_counts (dict[str, int]) with errno2count (dict[int, int])
- Add get_errno_from_error_type() to map error type strings to errno (1, 2, 3)
- Add get_error_type_from_errno() for reverse mapping when error type strings are needed
- Update calculate_pi() to use errno2count and return dict[int, float]
- Update calculate_all_aggregated_parameters() to use errno2count and errno_tolerance_thresholds
- Update analysis_util.py and verify_aggregated_params.py to use errno2count
- Improve code maintainability by using integer errno for sorting and comparison
1 parent 498f60d commit 22339b3

File tree

3 files changed

+120
-56
lines changed

3 files changed

+120
-56
lines changed

graph_net/analysis_util.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections import OrderedDict, defaultdict
77
from graph_net.config.datatype_tolerance_config import get_precision
88
from graph_net import samples_statistics
9+
from graph_net.samples_statistics import get_errno_from_error_type
910

1011

1112
def extract_speedup_data_from_subdirs(benchmark_path: str) -> dict:
@@ -568,7 +569,7 @@ def calculate_s_scores(
568569
def print_stat_info(
569570
t_key,
570571
correct_count,
571-
error_type_counts,
572+
errno2count,
572573
pi,
573574
correct_negative_speedup_count,
574575
correct_speedups,
@@ -584,7 +585,7 @@ def print_stat_info(
584585
aggregated_params = samples_statistics.calculate_all_aggregated_parameters(
585586
total_samples=total_samples,
586587
correct_speedups=correct_speedups,
587-
error_type_counts=error_type_counts,
588+
errno2count=errno2count,
588589
t_key=t_key,
589590
negative_speedup_penalty=negative_speedup_penalty,
590591
fpdb=fpdb,
@@ -626,13 +627,13 @@ def print_stat_info(
626627
final_correct_count = 0
627628
final_correct_negative_speedup_count = 0
628629
final_correct_speedups = []
629-
final_error_type_counts = {} # Store error type counts at t=1
630+
final_errno2count = {} # Store error type counts at t=1 (using errno)
630631

631632
for t_key in t_keys:
632633
rectified_speedups = []
633634
rectified_speedups_fake_degrad = []
634635
correct_count = 0
635-
error_type_counts = {} # Dictionary to count errors by type
636+
errno2count = {} # Dictionary to count errors by errno
636637
correct_negative_speedup_count = 0
637638
correct_speedups = []
638639

@@ -652,9 +653,10 @@ def print_stat_info(
652653
if speedup is not None and speedup < 1:
653654
correct_negative_speedup_count += 1
654655

655-
# Count errors by type
656+
# Count errors by errno (convert error type string to errno)
656657
if fail_type is not None:
657-
error_type_counts[fail_type] = error_type_counts.get(fail_type, 0) + 1
658+
errno = get_errno_from_error_type(fail_type)
659+
errno2count[errno] = errno2count.get(errno, 0) + 1
658660

659661
# Store state at t=1 for ES(t) calculation
660662
if t_key == 1:
@@ -683,12 +685,12 @@ def print_stat_info(
683685
if t_key == 1:
684686
# Calculate pi at t=1 using the dedicated function
685687
pi = samples_statistics.calculate_pi(
686-
error_type_counts, total_samples, correct_speedups
688+
errno2count, total_samples, correct_speedups
687689
)
688690
final_correct_count = correct_count
689691
final_correct_negative_speedup_count = correct_negative_speedup_count
690692
final_correct_speedups = correct_speedups
691-
final_error_type_counts = error_type_counts.copy() # Save for t >= 1
693+
final_errno2count = errno2count.copy() # Save for t >= 1
692694

693695
if rectified_speedups:
694696
s_scores[t_key] = gmean(rectified_speedups)
@@ -700,17 +702,17 @@ def print_stat_info(
700702
expected_s, expected_es = print_stat_info(
701703
t_key,
702704
correct_count,
703-
error_type_counts,
705+
errno2count,
704706
pi,
705707
correct_negative_speedup_count,
706708
correct_speedups,
707709
)
708710
else:
709-
# For t >= 1, use error_type_counts from t=1 (frozen state)
711+
# For t >= 1, use errno2count from t=1 (frozen state)
710712
expected_s, expected_es = print_stat_info(
711713
t_key,
712714
final_correct_count,
713-
final_error_type_counts, # Use the frozen error_type_counts from t=1
715+
final_errno2count, # Use the frozen errno2count from t=1
714716
pi,
715717
final_correct_negative_speedup_count,
716718
final_correct_speedups,
@@ -722,6 +724,6 @@ def print_stat_info(
722724
s_scores._aggregated_results[t_key] = expected_s
723725
s_scores_fake_degrad._aggregated_results[t_key] = expected_es
724726

725-
print(f" - pi: {list(pi)}")
727+
print(f" - pi: {dict(sorted(pi.items()))}")
726728

727729
return s_scores, s_scores_fake_degrad

graph_net/samples_statistics.py

Lines changed: 82 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,56 @@
99
from collections.abc import Callable
1010

1111

12+
def get_errno_from_error_type(error_type: str) -> int:
13+
"""
14+
Map error type string to errno (error number) for sorting.
15+
16+
According to the paper:
17+
- c=1: accuracy errors (精度错误)
18+
- c=2: runtime crashes (运行时崩溃)
19+
- c=3: compilation failures (编译失败)
20+
21+
Args:
22+
error_type: Error type string (e.g., "accuracy", "eager", "compiled")
23+
24+
Returns:
25+
Errno (1, 2, or 3) based on error type
26+
"""
27+
if error_type == "accuracy":
28+
return 1
29+
elif error_type in ("eager", "other", "runtime_fail", "eager_fail"):
30+
return 2
31+
elif error_type in ("compiled", "compile_fail"):
32+
return 3
33+
else:
34+
# Default to 2 for unknown error types (runtime errors)
35+
return 2
36+
37+
38+
def get_error_type_from_errno(errno: int) -> str:
39+
"""
40+
Map errno (error number) back to error type string.
41+
42+
This is the reverse mapping of get_errno_from_error_type.
43+
Used when error type string information is needed.
44+
45+
Args:
46+
errno: Error number (1, 2, or 3)
47+
48+
Returns:
49+
Error type string:
50+
- 1 -> "accuracy"
51+
- 2 -> "runtime_fail"
52+
- 3 -> "compile_fail"
53+
"""
54+
errno_to_error_type = {
55+
1: "accuracy",
56+
2: "runtime_fail",
57+
3: "compile_fail",
58+
}
59+
return errno_to_error_type.get(errno, "runtime_fail")
60+
61+
1262
def calculate_alpha(correct_speedups: list[float]) -> float:
1363
"""
1464
Calculate alpha: geometric mean of correct sample speedups.
@@ -80,30 +130,31 @@ def calculate_eta(correct_speedups: list[float]) -> float:
80130

81131

82132
def calculate_pi(
83-
error_type_counts: dict[str, int], total_samples: int, correct_speedups: list[float]
84-
) -> dict[str, float]:
133+
errno2count: dict[int, int], total_samples: int, correct_speedups: list[float]
134+
) -> dict[int, float]:
85135
"""
86136
Calculate pi: error type proportions for t > 0.
87137
88138
According to Appendix C: pi_c is the proportion of error type c among all error samples.
89139
90140
Args:
91-
error_type_counts: Dictionary mapping error type names to their counts
141+
errno2count: Dictionary mapping errno (error number) to their counts.
142+
Errno values: 1=accuracy, 2=runtime, 3=compilation.
92143
total_samples: Total number of samples
93144
correct_speedups: List of speedup values for correct samples
94145
95146
Returns:
96-
Dictionary mapping error type names to their proportions among error samples.
147+
Dictionary mapping errno to their proportions among error samples.
97148
If error_count is 0, returns a dictionary with all proportions set to 0.0.
98149
"""
99150
correct_count = len(correct_speedups)
100151
error_count = total_samples - correct_count
101152
if error_count == 0:
102-
return {error_type: 0.0 for error_type in error_type_counts.keys()}
153+
return {errno: 0.0 for errno in errno2count.keys()}
103154

104155
pi = {}
105-
for error_type, count in error_type_counts.items():
106-
pi[error_type] = count / error_count
156+
for errno, count in errno2count.items():
157+
pi[errno] = count / error_count
107158
return pi
108159

109160

@@ -210,12 +261,12 @@ def calculate_es_t_from_aggregated(
210261
def calculate_all_aggregated_parameters(
211262
total_samples: int,
212263
correct_speedups: list[float],
213-
error_type_counts: dict[str, int],
264+
errno2count: dict[int, int],
214265
t_key: int,
215266
negative_speedup_penalty: float = 0.0,
216267
fpdb: float = 0.1,
217-
pi: dict[str, float] | None = None,
218-
error_tolerance_thresholds: dict[str, int] | None = None,
268+
pi: dict[int, float] | None = None,
269+
errno_tolerance_thresholds: dict[int, int] | None = None,
219270
) -> dict:
220271
"""
221272
Calculate all aggregated parameters for a given tolerance level.
@@ -225,15 +276,16 @@ def calculate_all_aggregated_parameters(
225276
Args:
226277
total_samples: Total number of samples
227278
correct_speedups: List of speedup values for correct samples
228-
error_type_counts: Dictionary mapping error type names to their counts
279+
errno2count: Dictionary mapping errno (error number) to their counts.
280+
Errno values: 1=accuracy, 2=runtime, 3=compilation.
229281
t_key: Tolerance level
230282
negative_speedup_penalty: Penalty power p for negative speedup
231283
fpdb: Base penalty b for severe errors
232-
pi: Dictionary mapping error type names to their proportions (calculated at t=1).
233-
If None, will be calculated from error_type_counts.
234-
error_tolerance_thresholds: Dictionary mapping error type names to their tolerance thresholds.
284+
pi: Dictionary mapping errno to their proportions (calculated at t=1).
285+
If None, will be calculated from errno2count.
286+
errno_tolerance_thresholds: Dictionary mapping errno to their tolerance thresholds.
235287
An error type is tolerated (not penalized) when t >= threshold.
236-
If None, uses default thresholds: {"accuracy": 1} for accuracy errors, 3 for others.
288+
If None, uses default thresholds: {1: 1} for accuracy errors (errno=1), {2: 3, 3: 3} for others.
237289
238290
Returns:
239291
Dictionary containing all aggregated parameters and calculated scores:
@@ -243,36 +295,34 @@ def calculate_all_aggregated_parameters(
243295
'lambda': float,
244296
'eta': float,
245297
'gamma': float,
246-
'pi': dict[str, float],
298+
'pi': dict[int, float],
247299
's_t': float,
248300
'es_t': float
249301
}
250302
"""
251303
# Use default error tolerance thresholds if not provided
252-
if error_tolerance_thresholds is None:
253-
error_tolerance_thresholds = {}
254-
for error_type in error_type_counts.keys():
255-
if error_type == "accuracy":
256-
error_tolerance_thresholds[error_type] = 1
257-
else:
258-
error_tolerance_thresholds[error_type] = 3
304+
if errno_tolerance_thresholds is None:
305+
errno_tolerance_thresholds = {}
306+
for errno in errno2count.keys():
307+
if errno == 1: # accuracy errors
308+
errno_tolerance_thresholds[errno] = 1
309+
else: # runtime (2) or compilation (3) errors
310+
errno_tolerance_thresholds[errno] = 3
259311

260312
# Calculate pi if not provided
261313
if pi is None:
262-
pi = calculate_pi(error_type_counts, total_samples, correct_speedups)
314+
pi = calculate_pi(errno2count, total_samples, correct_speedups)
263315

264316
# Convert dictionary-based pi and thresholds to indexed format for calculate_gamma
265-
# Create ordered list of error types for consistent indexing
266-
error_types = sorted(error_type_counts.keys())
267-
errno_tolerances = [
268-
error_tolerance_thresholds.get(error_type, 3) for error_type in error_types
269-
]
317+
# Create ordered list of errnos for consistent indexing (sorted by errno)
318+
errnos = sorted(errno2count.keys())
319+
errno_tolerances = [errno_tolerance_thresholds.get(errno, 3) for errno in errnos]
270320

271321
# Create get_pi function that maps error type index to pi value
272322
def get_pi(error_type_index: int) -> float:
273-
if error_type_index < len(error_types):
274-
error_type = error_types[error_type_index]
275-
return pi.get(error_type, 0.0)
323+
if error_type_index < len(errnos):
324+
errno = errnos[error_type_index]
325+
return pi.get(errno, 0.0)
276326
return 0.0
277327

278328
alpha = calculate_alpha(correct_speedups)

graph_net/verify_aggregated_params.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
from collections import OrderedDict, Counter
55
from graph_net import analysis_util
66
from graph_net import samples_statistics
7+
from graph_net.samples_statistics import (
8+
get_errno_from_error_type,
9+
get_error_type_from_errno,
10+
)
711

812

913
def calculate_aggregated_parameters(
@@ -45,7 +49,7 @@ def calculate_aggregated_parameters(
4549
final_correct_negative_speedup_count = 0
4650
final_correct_speedups = []
4751
final_slowdown_speedups = []
48-
final_error_type_counts = {} # Store error type counts at t=1
52+
final_errno2count = {} # Store error type counts at t=1 (using errno)
4953

5054
results = OrderedDict()
5155

@@ -74,10 +78,10 @@ def calculate_aggregated_parameters(
7478
slowdown_speedups = [speedup for speedup in correct_speedups if speedup < 1]
7579
correct_negative_speedup_count = len(slowdown_speedups)
7680

77-
# Count errors by type using Counter
78-
error_type_counts = dict(
81+
# Count errors by errno using Counter (convert error type string to errno)
82+
errno2count = dict(
7983
Counter(
80-
fail_type
84+
get_errno_from_error_type(fail_type)
8185
for _, _, _, _, fail_type in sample_data
8286
if fail_type is not None
8387
)
@@ -101,13 +105,13 @@ def calculate_aggregated_parameters(
101105
# Calculate pi at t=1 using the dedicated function
102106
if t_key == 1:
103107
pi = samples_statistics.calculate_pi(
104-
error_type_counts, total_samples, correct_speedups
108+
errno2count, total_samples, correct_speedups
105109
)
106110
final_correct_count = correct_count
107111
final_correct_negative_speedup_count = correct_negative_speedup_count
108112
final_correct_speedups = correct_speedups
109113
final_slowdown_speedups = slowdown_speedups
110-
final_error_type_counts = error_type_counts.copy() # Save for t >= 1
114+
final_errno2count = errno2count.copy() # Save for t >= 1
111115

112116
# Calculate aggregated parameters
113117
if total_samples > 0:
@@ -127,16 +131,16 @@ def calculate_aggregated_parameters(
127131
stats_slowdown_speedups = final_slowdown_speedups
128132

129133
# Calculate all aggregated parameters using the dedicated module
130-
# For t >= 1, use error_type_counts from t=1 (frozen state)
134+
# For t >= 1, use errno2count from t=1 (frozen state)
131135
if t_key < 1:
132-
stats_error_type_counts = error_type_counts
136+
stats_errno2count = errno2count
133137
else:
134-
stats_error_type_counts = final_error_type_counts # Use frozen from t=1
138+
stats_errno2count = final_errno2count # Use frozen from t=1
135139

136140
aggregated_params = samples_statistics.calculate_all_aggregated_parameters(
137141
total_samples=total_samples,
138142
correct_speedups=stats_correct_speedups,
139-
error_type_counts=stats_error_type_counts,
143+
errno2count=stats_errno2count,
140144
t_key=t_key,
141145
negative_speedup_penalty=negative_speedup_penalty,
142146
fpdb=fpdb,
@@ -184,9 +188,17 @@ def calculate_aggregated_parameters(
184188
)
185189
print(f" gamma (average error penalty): {gamma:.6f}")
186190
if t_key >= 1:
191+
# pi is now dict[int, float], convert to list for display
192+
errnos = sorted(pi.keys())
193+
pi_list = [pi[errno] for errno in errnos]
187194
indicator = [1 if t_key < 1 else 0, 1 if t_key < 3 else 0]
188-
pi_indicator_sum = sum(pi[i] * indicator[i] for i in range(len(pi)))
189-
print(f" - pi: {list(pi)}")
195+
# Calculate pi_indicator_sum using errno-based pi
196+
pi_indicator_sum = sum(
197+
pi.get(errno, 0.0) * indicator[min(i, len(indicator) - 1)]
198+
for i, errno in enumerate(errnos)
199+
)
200+
print(f" - pi (errno -> proportion): {dict(sorted(pi.items()))}")
201+
print(f" - pi (as list): {pi_list}")
190202
print(f" - indicator: {indicator}")
191203
print(
192204
f" - gamma = fpdb^(sum(pi[i] * indicator[i])) = {fpdb}^{pi_indicator_sum:.6f} = {gamma:.6f}"

0 commit comments

Comments (0)