Skip to content

Commit 846e389

Browse files
【Hackathon 9th Sprint No.89】feat:Add extend logic for fake_perf_degrad (#423)
* Hackathon NO.89 Update 修改fake_perf_degrad及其后续逻辑 * Code Fix 修改了comment中的要求 * Update samples_statistics.py * Revert "Update samples_statistics.py" This reverts commit 5381b74. * Revert "Code Fix" This reverts commit 55caedc. * Code Fix 按照comment修复了代码 * Code Fix 2 按照comment更新代码 * code Fix 3 修复代码结构,使用pre-commit * code Fix 4 修复了枚举逻辑和映射逻辑,按照新的映射逻辑修复了容忍度判断 --------- Co-authored-by: 699574 <[email protected]>
1 parent d6bccea commit 846e389

8 files changed

+398
-117
lines changed

graph_net/analysis_util.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import sys
44
from scipy.stats import gmean
55
from graph_net.config.datatype_tolerance_config import get_precision
6+
from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation
7+
from graph_net.verify_aggregated_params import determine_tolerances
68

79

810
def detect_sample_status(log_text: str) -> str:
@@ -293,38 +295,24 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> b
293295
return False
294296

295297

296-
def fake_perf_degrad(tolerance, error_code, type="default") -> str:
298+
def fake_perf_degrad(
299+
tolerance,
300+
error_code,
301+
positive_tolerance_interpretation: PositiveToleranceInterpretation,
302+
) -> str:
297303
"""
298304
Judge current correctness based on tolerance t and status.
305+
Refactored to delegate logic to PositiveToleranceInterpretation classes.
299306
"""
300-
if type == "default":
301-
if tolerance >= 3:
302-
return "correct"
303-
elif error_code == "accuracy" and tolerance >= 1:
304-
return "correct"
305-
else:
306-
return error_code
307-
elif type == "extended":
308-
if (
309-
error_code == "compile_fail" or error_code == "runtime_fail"
310-
) and tolerance >= 4:
311-
return "correct"
312-
elif error_code == "eager_fail" and tolerance >= 3:
313-
return "correct"
314-
elif (
315-
error_code == "shape_mismatch" or error_code == "type_mismatch"
316-
) and tolerance >= 2:
317-
return "correct"
318-
elif error_code == "accuracy" and tolerance >= 1:
319-
return "correct"
320-
else:
321-
return error_code
322-
else:
323-
raise NotImplementedError
307+
if positive_tolerance_interpretation.is_error_tolerated(tolerance, error_code):
308+
return "correct"
309+
310+
return error_code
324311

325312

326313
def calculate_scores(
327314
samples: list,
315+
positive_tolerance_interpretation: PositiveToleranceInterpretation,
328316
p: float = 0,
329317
b: float = 0.1,
330318
type: str = "ESt",
@@ -339,7 +327,10 @@ def calculate_scores(
339327

340328
scores = {}
341329

342-
for tolerance in range(-10, 5):
330+
strategy = positive_tolerance_interpretation
331+
tolerances = determine_tolerances(samples, positive_tolerance_interpretation)
332+
333+
for tolerance in tolerances:
343334
rectified_speedups = []
344335
rectified_speedups_fake_degrad = []
345336

@@ -373,12 +364,10 @@ def calculate_scores(
373364
)
374365
else:
375366
if not is_correct_at_t1[idx]:
376-
current_correctness = fake_perf_degrad(
367+
is_tolerated = strategy.is_error_tolerated(
377368
tolerance, fail_type_at_t1[idx]
378369
)
379-
rec_speedup_fake_degrad = (
380-
1 if current_correctness == "correct" else b
381-
)
370+
rec_speedup_fake_degrad = 1 if is_tolerated else b
382371
else:
383372
rec_speedup_fake_degrad = (
384373
speedup_at_t1[idx] ** (p + 1)
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from enum import IntEnum
2+
3+
from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation
4+
5+
6+
class DefaultErrorEnum(IntEnum):
7+
"""
8+
Values correspond to the minimum tolerance level required.
9+
"""
10+
11+
kAccuracyViolation = 1 # Accuracy
12+
kRuntimeFailure = 2 # Includes Runtime, NaN, Inf, TypeMismatch, etc.
13+
kCompilationFailed = 3 # Compile Failure
14+
15+
@classmethod
16+
def get_error_enum(cls, base_error_type: str) -> "DefaultErrorEnum":
17+
if not base_error_type:
18+
return cls.kRuntimeFailure
19+
20+
etype = base_error_type.lower()
21+
22+
if "accuracy" in etype:
23+
return cls.kAccuracyViolation
24+
25+
if "compile_fail" in etype:
26+
return cls.kCompilationFailed
27+
28+
return cls.kRuntimeFailure
29+
30+
31+
class DefaultPositiveToleranceInterpretation(PositiveToleranceInterpretation):
32+
"""
33+
Legacy interpretation:
34+
- t=1: Accuracy errors tolerated.
35+
- t=3: Runtime/Compilation errors tolerated.
36+
"""
37+
38+
def __init__(self, *argc, **kwargs):
39+
super().__init__(*argc, **kwargs)
40+
41+
def type_name(self) -> str:
42+
return "default"
43+
44+
def get_errno(self, error_type: str) -> int:
45+
return DefaultErrorEnum.get_error_enum(error_type).value
46+
47+
def get_error_type(self, errno: int) -> str:
48+
mapping = {1: "accuracy", 2: "runtime_fail", 3: "compile_fail"}
49+
return mapping.get(errno, "unknown_error")
50+
51+
def get_tolerance_mapping(self) -> dict[int, int]:
52+
return {
53+
DefaultErrorEnum.kAccuracyViolation.value: 1,
54+
DefaultErrorEnum.kRuntimeFailure.value: 3,
55+
DefaultErrorEnum.kCompilationFailed.value: 3,
56+
}
57+
58+
def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
59+
if base_error_code == "correct":
60+
return True
61+
if base_error_code in ["eager_fail", "reference_fail"]:
62+
return False
63+
64+
error_enum = DefaultErrorEnum.get_error_enum(base_error_code)
65+
mapping = self.get_tolerance_mapping()
66+
required_threshold = mapping.get(error_enum.value, 999)
67+
68+
return tolerance >= required_threshold
69+
70+
def num_errno_enum_values(self) -> int:
71+
"""
72+
Default mode defines 3 levels of errors:
73+
1: Accuracy
74+
2: Runtime (Generic)
75+
3: Compilation
76+
"""
77+
return len(DefaultErrorEnum)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from enum import IntEnum
2+
3+
from graph_net.positive_tolerance_interpretation import PositiveToleranceInterpretation
4+
5+
6+
class MismatchExtendedErrorEnum(IntEnum):
7+
"""
8+
Values correspond to the minimum tolerance level required.
9+
"""
10+
11+
kAccuracyViolation = 1
12+
kValueTypeOrMetaMismatch = 2
13+
kExecutionFailed = 3
14+
kCompilationFailed = 4
15+
16+
@classmethod
17+
def get_error_enum(cls, base_error_type: str) -> "MismatchExtendedErrorEnum":
18+
if not base_error_type:
19+
return cls.kExecutionFailed
20+
21+
etype = base_error_type.lower()
22+
if "accuracy" in etype:
23+
return cls.kAccuracyViolation
24+
if any(x in etype for x in ["nan", "inf", "type_mismatch", "shape_mismatch"]):
25+
return cls.kValueTypeOrMetaMismatch
26+
if "compile_fail" in etype:
27+
return cls.kCompilationFailed
28+
29+
return cls.kExecutionFailed
30+
31+
32+
class MismatchExtendedPositiveToleranceInterpretation(PositiveToleranceInterpretation):
33+
"""
34+
Extended interpretation (ESt):
35+
- t=1: Accuracy
36+
- t=2: NaN/Inf/Type/Shape
37+
- t=3: Runtime
38+
- t=4: Compile
39+
"""
40+
41+
def __init__(self, *argc, **kwargs):
42+
super().__init__(*argc, **kwargs)
43+
44+
def type_name(self) -> str:
45+
return "mismatch_extended"
46+
47+
def get_errno(self, error_type: str) -> int:
48+
return MismatchExtendedErrorEnum.get_error_enum(error_type).value
49+
50+
def get_error_type(self, errno: int) -> str:
51+
mapping = {
52+
1: "accuracy",
53+
2: "type/shape_mismatch",
54+
3: "runtime_fail",
55+
4: "compile_fail",
56+
}
57+
return mapping.get(errno, "unknown_error")
58+
59+
def get_tolerance_mapping(self) -> dict[int, int]:
60+
return {
61+
MismatchExtendedErrorEnum.kAccuracyViolation.value: 1,
62+
MismatchExtendedErrorEnum.kValueTypeOrMetaMismatch.value: 2,
63+
MismatchExtendedErrorEnum.kExecutionFailed.value: 3,
64+
MismatchExtendedErrorEnum.kCompilationFailed.value: 4,
65+
}
66+
67+
def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
68+
if base_error_code == "correct":
69+
return True
70+
if base_error_code in ["eager_fail", "reference_fail"]:
71+
return False
72+
73+
error_enum = MismatchExtendedErrorEnum.get_error_enum(base_error_code)
74+
mapping = self.get_tolerance_mapping()
75+
required_threshold = mapping.get(error_enum.value, 999)
76+
77+
return tolerance >= required_threshold
78+
79+
def num_errno_enum_values(self) -> int:
80+
"""
81+
Extended mode defines 4 levels of errors:
82+
1: Accuracy
83+
2: Type/Shape/NaN
84+
3: Runtime
85+
4: Compilation
86+
"""
87+
return len(MismatchExtendedErrorEnum)

graph_net/plot_ESt.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
import matplotlib.pyplot as plt
55
from graph_net import analysis_util
66
from graph_net import verify_aggregated_params
7+
from graph_net.positive_tolerance_interpretation_manager import (
8+
get_supported_positive_tolerance_interpretation_types,
9+
get_positive_tolerance_interpretation,
10+
)
711

812

913
class ESScoresWrapper:
@@ -262,6 +266,9 @@ def main(args):
262266
# 2. Calculate scores for each curve and verify aggregated/microscopic consistency
263267
all_es_scores = {}
264268
all_aggregated_results = {}
269+
positive_tolerance_interpretation = get_positive_tolerance_interpretation(
270+
args.interpretation_type
271+
)
265272

266273
for folder_name, samples in all_results.items():
267274
print(f"\nCalculating ESt scores for '{folder_name}'...")
@@ -271,6 +278,7 @@ def main(args):
271278
p=args.negative_speedup_penalty,
272279
b=args.fpdb,
273280
type="ESt",
281+
positive_tolerance_interpretation=positive_tolerance_interpretation,
274282
)
275283

276284
# Keep original behavior: assign es_scores directly
@@ -285,6 +293,7 @@ def main(args):
285293
folder_name,
286294
negative_speedup_penalty=args.negative_speedup_penalty,
287295
fpdb=args.fpdb,
296+
positive_tolerance_interpretation=positive_tolerance_interpretation,
288297
)
289298
)
290299
# Store aggregated results for plotting
@@ -429,6 +438,13 @@ def main(args):
429438
action="store_false",
430439
help="Disable aggregation mode verification.",
431440
)
441+
parser.add_argument(
442+
"--positive-tolerance-interpretation",
443+
dest="interpretation_type",
444+
choices=get_supported_positive_tolerance_interpretation_types(),
445+
default="default",
446+
help="Select how positive tolerance values are interpreted into error types.",
447+
)
432448
parser.set_defaults(enable_aggregation_mode=True)
433449
args = parser.parse_args()
434450
main(args)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from abc import ABC, abstractmethod
2+
3+
4+
class PositiveToleranceInterpretation(ABC):
5+
"""
6+
Abstract base class defining how positive tolerance values (t > 0)
7+
are interpreted and mapped to specific error types.
8+
"""
9+
10+
def __init__(self, *argc, **kwargs):
11+
pass
12+
13+
@abstractmethod
14+
def type_name(self) -> str:
15+
"""Return the unique string identifier for this interpretation strategy."""
16+
raise NotImplementedError
17+
18+
@abstractmethod
19+
def get_errno(self, error_type: str) -> int:
20+
"""Map a raw error type string to an internal error number (errno)."""
21+
raise NotImplementedError
22+
23+
@abstractmethod
24+
def get_error_type(self, errno: int) -> str:
25+
"""Map an internal error number (errno) back to a representative string."""
26+
raise NotImplementedError
27+
28+
@abstractmethod
29+
def get_tolerance_mapping(self) -> dict[int, int]:
30+
"""
31+
Return the mapping of errno.
32+
Used for statistical calculations (Gamma/Pi).
33+
"""
34+
raise NotImplementedError
35+
36+
@abstractmethod
37+
def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
38+
"""
39+
Determine if a specific error is considered 'correct' under the given tolerance.
40+
Replaces the old 'fake_perf_degrad' logic.
41+
"""
42+
raise NotImplementedError
43+
44+
@abstractmethod
45+
def num_errno_enum_values(self) -> int:
46+
"""
47+
Return the number of defined error categories (or the maximum errno).
48+
49+
Example:
50+
- Default: returns 3 (Accuracy, Runtime, Compile)
51+
- MismatchExtended: returns 4 (Accuracy, Data, Runtime, Compile)
52+
"""
53+
raise NotImplementedError

0 commit comments

Comments
 (0)