
Commit 359ed0c

Code Fix 2
Update the code according to review comments.
1 parent 218c860 commit 359ed0c

5 files changed: 62 additions & 64 deletions

graph_net/analysis_util.py

Lines changed: 6 additions & 9 deletions
@@ -3,7 +3,7 @@
 import sys
 from scipy.stats import gmean
 from graph_net.config.datatype_tolerance_config import get_precision
-from graph_net.positive_tolerance_interpretation_manager import get_positive_tolerance_interpretation
+from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation
 from graph_net.verify_aggregated_params import determine_tolerances


@@ -294,25 +294,22 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> bool:
         return bool(result[index])
     return False

-def fake_perf_degrad(tolerance, error_code, type="default") -> str:
+def fake_perf_degrad(tolerance, error_code, positive_tolerance_interpretation: PositiveToleranceInterpretation,) -> str:
     """
     Judge current correctness based on tolerance t and status.
     Refactored to delegate logic to PositiveToleranceInterpretation classes.
     """
-
-    strategy = get_positive_tolerance_interpretation(type)
-
-    if strategy.is_error_tolerated(tolerance, error_code):
+    if positive_tolerance_interpretation.is_error_tolerated(tolerance, error_code):
         return "correct"

     return error_code

 def calculate_scores(
     samples: list,
+    positive_tolerance_interpretation: PositiveToleranceInterpretation,
     p: float = 0,
     b: float = 0.1,
     type: str = "ESt",
-    interpretation_type: str = "default",
 ) -> tuple:
     """
     Use a standard tolerance to evaluate all samples and calculate S(t) and ES(t) scores for each tolerance level.
@@ -324,8 +321,8 @@ def calculate_scores(

     scores = {}

-    strategy = get_positive_tolerance_interpretation(interpretation_type)
-    tolerances = determine_tolerances(samples,interpretation_type)
+    strategy = positive_tolerance_interpretation
+    tolerances = determine_tolerances(samples,positive_tolerance_interpretation)

     for tolerance in tolerances:
         rectified_speedups = []
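
For orientation, a minimal usage sketch of the refactored entry point, assuming the factory in graph_net.positive_tolerance_interpretation_manager and a samples list produced by the existing pipeline (neither is part of this hunk; load_samples below is a hypothetical helper):

from graph_net.analysis_util import calculate_scores
from graph_net.positive_tolerance_interpretation_manager import get_positive_tolerance_interpretation

# Resolve the strategy object once at the call site...
interpretation = get_positive_tolerance_interpretation("default")

# ...and inject it instead of the old interpretation_type string.
samples = load_samples()  # hypothetical helper; the sample format is not shown in this diff
result = calculate_scores(
    samples,
    positive_tolerance_interpretation=interpretation,
    p=0.0,
    b=0.1,
    type="ESt",
)  # returns a tuple, per the signature above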

graph_net/plot_ESt.py

Lines changed: 7 additions & 5 deletions
@@ -4,8 +4,9 @@
 import matplotlib.pyplot as plt
 from graph_net import analysis_util
 from graph_net import verify_aggregated_params
-from graph_net.positive_tolerance_interpretation_manager import g_type_name2_positive_tolerance_interpretation_cls
-
+from graph_net.positive_tolerance_interpretation_manager import (
+    get_supported_positive_tolerance_interpretation_types, get_positive_tolerance_interpretation
+)

 class ESScoresWrapper:
     """Wrapper for es_scores dict to allow attribute assignment."""
@@ -263,6 +264,7 @@ def main(args):
     # 2. Calculate scores for each curve and verify aggregated/microscopic consistency
     all_es_scores = {}
     all_aggregated_results = {}
+    positive_tolerance_interpretation = get_positive_tolerance_interpretation(args.interpretation_type)

     for folder_name, samples in all_results.items():
         print(f"\nCalculating ESt scores for '{folder_name}'...")
@@ -272,7 +274,7 @@ def main(args):
             p=args.negative_speedup_penalty,
             b=args.fpdb,
             type="ESt",
-            interpretation_type=args.interpretation_type,
+            positive_tolerance_interpretation=positive_tolerance_interpretation,
         )

         # Keep original behavior: assign es_scores directly
@@ -287,7 +289,7 @@ def main(args):
                 folder_name,
                 negative_speedup_penalty=args.negative_speedup_penalty,
                 fpdb=args.fpdb,
-                interpretation_type=args.interpretation_type
+                positive_tolerance_interpretation=positive_tolerance_interpretation,
             )
         )
         # Store aggregated results for plotting
@@ -435,7 +437,7 @@ def main(args):
     parser.add_argument(
         "--positive-tolerance-interpretation",
         dest="interpretation_type",
-        choices=list(g_type_name2_positive_tolerance_interpretation_cls.keys()),
+        choices=get_supported_positive_tolerance_interpretation_types(),
         default="default",
         help="Select how positive tolerance values are interpreted into error types."
     )
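
Condensed, the new wiring in this file looks roughly like this (a sketch using only the flag and helpers from the diff above; the rest of main() is elided):

import argparse
from graph_net.positive_tolerance_interpretation_manager import (
    get_supported_positive_tolerance_interpretation_types,
    get_positive_tolerance_interpretation,
)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--positive-tolerance-interpretation",
    dest="interpretation_type",
    choices=get_supported_positive_tolerance_interpretation_types(),
    default="default",
    help="Select how positive tolerance values are interpreted into error types.",
)
args = parser.parse_args(["--positive-tolerance-interpretation", "mismatch_extended"])

# The string flag is resolved into a strategy object once, up front,
# and that object is what calculate_scores and the verification helpers receive.
positive_tolerance_interpretation = get_positive_tolerance_interpretation(args.interpretation_type)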
graph_net/positive_tolerance_interpretation_manager.py

Lines changed: 14 additions & 12 deletions
@@ -1,16 +1,9 @@
-from typing import Type
+from typing import Type, List
 from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation
 from graph_net.default_positive_tolerance_interpretation import DefaultPositiveToleranceInterpretation
 from graph_net.mismatch_extended_positive_tolerance_interpretation import \
     MismatchExtendedPositiveToleranceInterpretation

-# Registry of available classes
-g_type_name2_positive_tolerance_interpretation_cls: dict[str, Type[PositiveToleranceInterpretation]] = {
-    'default': DefaultPositiveToleranceInterpretation,
-    'mismatch_extended': MismatchExtendedPositiveToleranceInterpretation
-}
-
-
 def get_positive_tolerance_interpretation(type_name: str) -> PositiveToleranceInterpretation:
     """
     Factory function to retrieve an instance of the requested interpretation strategy.
@@ -24,11 +17,20 @@ def get_positive_tolerance_interpretation(type_name: str) -> PositiveToleranceInterpretation:
     Raises:
         ValueError: If type_name is not registered.
     """
-    if type_name not in g_type_name2_positive_tolerance_interpretation_cls:
-        supported = list(g_type_name2_positive_tolerance_interpretation_cls.keys())
+    if type_name not in _g_type_name2_positive_tolerance_interpretation_cls:
+        supported = list(_g_type_name2_positive_tolerance_interpretation_cls.keys())
         raise ValueError(f"Unknown positive tolerance interpretation: '{type_name}'. Supported: {supported}")

     # Instantiate and return.
     # If stateful caching is needed, this logic can be modified to return singletons.
-    cls = g_type_name2_positive_tolerance_interpretation_cls[type_name]
-    return cls()
+    cls = _g_type_name2_positive_tolerance_interpretation_cls[type_name]
+    return cls()
+
+
+def get_supported_positive_tolerance_interpretation_types() -> List[str]:
+    return list(_g_type_name2_positive_tolerance_interpretation_cls.keys())
+
+
+# Registry of available classes
+_g_type_name2_positive_tolerance_interpretation_cls: dict[str, Type[PositiveToleranceInterpretation]] = {
+    'default': DefaultPositiveToleranceInterpretation,
+    'mismatch_extended': MismatchExtendedPositiveToleranceInterpretation
+}
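
A quick illustration of the two accessors above (a sketch; the output comments reflect the registry as defined in this file):

from graph_net.positive_tolerance_interpretation_manager import (
    get_positive_tolerance_interpretation,
    get_supported_positive_tolerance_interpretation_types,
)

print(get_supported_positive_tolerance_interpretation_types())
# ['default', 'mismatch_extended']

strategy = get_positive_tolerance_interpretation("mismatch_extended")

# Unknown names fail fast and report the supported list.
try:
    get_positive_tolerance_interpretation("typo")
except ValueError as exc:
    print(exc)

Because the registry is only read when the accessors are called, defining _g_type_name2_positive_tolerance_interpretation_cls after the functions is safe, and the leading underscore keeps callers on the two public accessors instead of the raw dict.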

graph_net/samples_statistics.py

Lines changed: 15 additions & 21 deletions
@@ -9,25 +9,24 @@
 import scipy
 from scipy.stats import gmean
 from collections.abc import Callable
-from abc import ABC, abstractmethod
+from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation

-from graph_net.positive_tolerance_interpretation_manager import get_positive_tolerance_interpretation
-
-def get_errno_from_error_type(error_type: str, interpretation_type: str = "default") -> int:
+def get_errno_from_error_type(error_type: str,
+                              positive_tolerance_interpretation: PositiveToleranceInterpretation) -> int:
     """
     Map error type string to errno (error number) using the appropriate strategy.

     Args:
         error_type: Error type string (e.g., "accuracy", "runtime_fail")
-        interpretation_type: Evaluation mode ("default" or "mismatch_extended")
+        positive_tolerance_interpretation: Evaluation mode ("default" or "mismatch_extended")

     Returns:
-        int: Errno based on the selected interpretation_type's logic.
+        int: Errno based on the selected positive_tolerance_interpretation's logic.
     """
-    interpreter = get_positive_tolerance_interpretation(interpretation_type)
-    return interpreter.get_errno(error_type)
+    return positive_tolerance_interpretation.get_errno(error_type)

-def get_errno_tolerance_mapping(custom_mapping, interpretation_type: str = "default"):
+def get_errno_tolerance_mapping(custom_mapping,
+                                positive_tolerance_interpretation: PositiveToleranceInterpretation):
     """
     Map errno (error number) back to error type string.

@@ -36,14 +35,13 @@ def get_errno_tolerance_mapping(custom_mapping, interpretation_type: str = "default"):

     Args:
         errno: Error number
-        interpretation_type: Evaluation mode ("default" or "mismatch_extended")
+        positive_tolerance_interpretation: Evaluation mode ("default" or "mismatch_extended")

     Returns:
         Representative error type string (e.g., "accuracy", "compile_fail")
     """
     if custom_mapping: return custom_mapping
-    interpreter = get_positive_tolerance_interpretation(interpretation_type)
-    return interpreter.get_tolerance_mapping()
+    return positive_tolerance_interpretation.get_tolerance_mapping()

 def calculate_alpha(correct_speedups: list[float]) -> float:
     """
@@ -116,8 +114,8 @@ def calculate_eta(correct_speedups: list[float]) -> float:


 def calculate_pi(
-    errno2count: dict[int, int], total_samples: int, correct_speedups: list[float]
-) -> dict[int, float]:
+    errno2count: dict[Union[int, str], int], total_samples: int, correct_speedups: list[float]
+) -> dict[Union[int, str], float]:
     """
     Calculate pi: error type proportions for t > 0.

@@ -285,26 +283,22 @@ def calculate_es_t_from_aggregated(
 def calculate_es_components_values(
     total_samples: int,
     correct_speedups: list[float],
-    errno2count: dict[Union[int, str], int],  # update type annotation to support str
+    errno2count: dict[Union[int, str], int],  # support str
     tolerance: int,
+    positive_tolerance_interpretation: PositiveToleranceInterpretation,
     negative_speedup_penalty: float = 0.0,
     b: float = 0.1,
     pi: Optional[dict[Union[int, str], float]] = None,
     errno_to_tolerance: Optional[dict[Union[int, str], int]] = None,
-    interpretation_type: str = "default"
 ) -> dict:
     """
     Calculate aggregated parameters for a given tolerance level.
-
-    Args:
-        ...
-        interpretation_type: "default" (int error codes) or "mismatch_extended" (str error codes).
     """

     if pi is None:
         pi = calculate_pi(errno2count, total_samples, correct_speedups)

-    errno_to_tolerance = get_errno_tolerance_mapping(errno_to_tolerance, interpretation_type)
+    errno_to_tolerance = get_errno_tolerance_mapping(errno_to_tolerance, positive_tolerance_interpretation)

     errno2tolerance = resolve_errno_tolerance(errno2count, errno_to_tolerance)

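
A small sketch of the new calling convention for these helpers, assuming the factory from graph_net.positive_tolerance_interpretation_manager; the error-type strings follow the examples in the docstrings, and the concrete errno values depend on the interpretation class chosen:

from graph_net.positive_tolerance_interpretation_manager import get_positive_tolerance_interpretation
from graph_net.samples_statistics import get_errno_from_error_type, get_errno_tolerance_mapping

interpretation = get_positive_tolerance_interpretation("default")

# The strategy object replaces the old interpretation_type string.
errno = get_errno_from_error_type("accuracy", interpretation)

# With no custom mapping, the strategy's own errno -> error-type mapping is returned.
mapping = get_errno_tolerance_mapping(None, interpretation)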
graph_net/verify_aggregated_params.py

Lines changed: 20 additions & 17 deletions
@@ -2,23 +2,26 @@
 from collections import OrderedDict, Counter
 from graph_net import analysis_util
 from graph_net import samples_statistics
+from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation
 from graph_net.samples_statistics import (
     get_errno_from_error_type,
 )


-def determine_tolerances(samples: list,interpretation_type = "default") -> range:
+def determine_tolerances(samples: list,
+                         positive_tolerance_interpretation: PositiveToleranceInterpretation,) -> range:
     """Determine tolerance range based on observed errno categories."""
     # Currently errno categories are 1=accuracy, 2=runtime, 3=compile.
     # Keep logic data-driven for future extension.
-    if interpretation_type == "default":
-        default_errnos = {1, 2, 3}
-    elif interpretation_type == "mismatch_extended":
-        default_errnos = {1, 2, 3, 4}
-    return range(-10, len(default_errnos) + 2)
+    mapping = positive_tolerance_interpretation.get_tolerance_mapping()

+    if not mapping:
+        max_errno = 3
+    else:
+        max_errno = max(mapping.keys())
+    return range(-10, max_errno + 2)

-def extract_statistics_at_tolerance(samples: list, tolerance: int,interpretation_type:str = "default") -> dict:
+def extract_statistics_at_tolerance(samples: list, tolerance: int,positive_tolerance_interpretation: PositiveToleranceInterpretation) -> dict:
     """Extract statistics for a given tolerance level."""
     sample_data = [
         (
@@ -42,7 +45,7 @@ def extract_statistics_at_tolerance(samples: list, tolerance: int,interpretation_type:str = "default") -> dict:

     errno2count = dict(
         Counter(
-            get_errno_from_error_type(fail_type,interpretation_type)
+            get_errno_from_error_type(fail_type,positive_tolerance_interpretation)
             for _, _, _, _, fail_type in sample_data
             if fail_type is not None
         )
@@ -105,7 +108,7 @@ def calculate_es_constructor_params_for_tolerance(
     pi: dict,
     negative_speedup_penalty: float,
     fpdb: float,
-    interpretation_type: str = "default",
+    positive_tolerance_interpretation: PositiveToleranceInterpretation
 ) -> dict:
     """Calculate ES(t) constructor parameters (alpha, beta, gamma, lambda, eta) and final scores for a tolerance level."""
     aggregated_params = samples_statistics.calculate_es_components_values(
@@ -116,7 +119,7 @@ def calculate_es_constructor_params_for_tolerance(
         negative_speedup_penalty=negative_speedup_penalty,
         b=fpdb,
         pi=pi,
-        interpretation_type=interpretation_type,
+        positive_tolerance_interpretation=positive_tolerance_interpretation,
     )

     alpha = aggregated_params["alpha"]
@@ -205,7 +208,7 @@ def __init__(
         total_samples: int,
         negative_speedup_penalty: float,
         fpdb: float,
-        interpretation_type: str = "default",
+        positive_tolerance_interpretation: PositiveToleranceInterpretation,
     ):
         self.samples = samples
         self.total_samples = total_samples
@@ -218,10 +221,10 @@ def __init__(
             "slowdown_speedups": [],
             "errno2count": {},
         }
-        self.interpretation_type = interpretation_type
+        self.positive_tolerance_interpretation = positive_tolerance_interpretation

     def build_report(self, tolerance: int) -> dict:
-        current_stats = extract_statistics_at_tolerance(self.samples, tolerance,self.interpretation_type)
+        current_stats = extract_statistics_at_tolerance(self.samples, tolerance,self.positive_tolerance_interpretation)

         if tolerance == 1:
             self.pi = _freeze_statistics_at_tolerance(
@@ -247,7 +250,7 @@ def build_report(self, tolerance: int) -> dict:
             pi_for_calc,
             self.negative_speedup_penalty,
             self.fpdb,
-            self.interpretation_type,
+            self.positive_tolerance_interpretation,
         )
         # Use calculated pi from es_constructor_params for display and return
         calculated_pi = es_constructor_params.get("pi", self.pi)
@@ -289,9 +292,9 @@ def _empty_report(self, tolerance: int) -> dict:
 def verify_es_constructor_params_across_tolerances(
     samples: list,
     folder_name: str,
+    positive_tolerance_interpretation: PositiveToleranceInterpretation,
     negative_speedup_penalty: float = 0,
     fpdb: float = 0.1,
-    interpretation_type: str = "default",
 ) -> dict:
     """
     Verify and print ES constructor parameters (alpha, beta, gamma, lambda, eta, pi) for each
@@ -307,13 +310,13 @@ def verify_es_constructor_params_across_tolerances(
     print(f"Verifying Aggregated Parameters for '{folder_name}'")
     print(f"{'=' * 80}")

-    tolerances = determine_tolerances(samples,interpretation_type)
+    tolerances = determine_tolerances(samples,positive_tolerance_interpretation)
     builder = ToleranceReportBuilder(
         samples=samples,
         total_samples=total_samples,
         negative_speedup_penalty=negative_speedup_penalty,
         fpdb=fpdb,
-        interpretation_type=interpretation_type,
+        positive_tolerance_interpretation=positive_tolerance_interpretation,
     )

     results = OrderedDict(
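
To make the reworked determine_tolerances concrete, a sketch with a hypothetical stand-in strategy (the stub class below is illustrative; the real implementations are DefaultPositiveToleranceInterpretation and MismatchExtendedPositiveToleranceInterpretation, and the label of the fourth error category is not shown in this diff):

from graph_net.verify_aggregated_params import determine_tolerances

class _StubInterpretation:
    """Hypothetical stand-in exposing only get_tolerance_mapping()."""
    def __init__(self, mapping):
        self._mapping = mapping

    def get_tolerance_mapping(self):
        return self._mapping

# determine_tolerances only consults the mapping, not the samples, so an empty list is fine here.
# Errnos 1..3 (accuracy / runtime / compile) reproduce the old default range:
print(determine_tolerances([], _StubInterpretation({1: "accuracy", 2: "runtime_fail", 3: "compile_fail"})))
# range(-10, 5)

# An interpretation with a fourth errno (the mismatch_extended case) widens the range by one:
print(determine_tolerances([], _StubInterpretation({1: "accuracy", 2: "runtime_fail", 3: "compile_fail", 4: "mismatch"})))
# range(-10, 6)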
