Skip to content

Commit 85b4721

Browse files
committed
code Fix 4
修复了枚举逻辑和映射逻辑,按照新的映射逻辑修复了容忍度判断
1 parent 06cab86 commit 85b4721

File tree

5 files changed

+78
-55
lines changed

5 files changed

+78
-55
lines changed

graph_net/default_positive_tolerance_interpretation.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,28 +8,24 @@ class DefaultErrorEnum(IntEnum):
88
Values correspond to the minimum tolerance level required.
99
"""
1010

11-
kAccuracyViolation = 1
12-
kValueTypeOrMetaMismatch = 3
13-
kExecutionFailed = 3
14-
kCompilationFailed = 3
11+
kAccuracyViolation = 1 # Accuracy
12+
kRuntimeFailure = 2 # Includes Runtime, NaN, Inf, TypeMismatch, etc.
13+
kCompilationFailed = 3 # Compile Failure
1514

1615
@classmethod
1716
def get_error_enum(cls, base_error_type: str) -> "DefaultErrorEnum":
1817
if not base_error_type:
19-
return cls.kExecutionFailed
18+
return cls.kRuntimeFailure
2019

2120
etype = base_error_type.lower()
2221

2322
if "accuracy" in etype:
2423
return cls.kAccuracyViolation
2524

26-
if any(x in etype for x in ["nan", "inf", "type_mismatch", "shape_mismatch"]):
27-
return cls.kValueTypeOrMetaMismatch
28-
2925
if "compile_fail" in etype:
3026
return cls.kCompilationFailed
3127

32-
return cls.kExecutionFailed
28+
return cls.kRuntimeFailure
3329

3430

3531
class DefaultPositiveToleranceInterpretation(PositiveToleranceInterpretation):
@@ -46,24 +42,17 @@ def type_name(self) -> str:
4642
return "default"
4743

4844
def get_errno(self, error_type: str) -> int:
49-
if not error_type:
50-
return 2
51-
etype = error_type.lower()
52-
if "accuracy" in etype:
53-
return 1
54-
if "compile_fail" in etype:
55-
return 3
56-
return 2
45+
return DefaultErrorEnum.get_error_enum(error_type).value
5746

5847
def get_error_type(self, errno: int) -> str:
5948
mapping = {1: "accuracy", 2: "runtime_fail", 3: "compile_fail"}
6049
return mapping.get(errno, "unknown_error")
6150

6251
def get_tolerance_mapping(self) -> dict[int, int]:
6352
return {
64-
1: 1, # Accuracy -> t >= 1
65-
2: 3, # Runtime -> t >= 3
66-
3: 3, # Compile -> t >= 3
53+
DefaultErrorEnum.kAccuracyViolation.value: 1,
54+
DefaultErrorEnum.kRuntimeFailure.value: 3,
55+
DefaultErrorEnum.kCompilationFailed.value: 3,
6756
}
6857

6958
def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
@@ -72,8 +61,17 @@ def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
7261
if base_error_code in ["eager_fail", "reference_fail"]:
7362
return False
7463

75-
try:
76-
error_level = DefaultErrorEnum.get_error_enum(base_error_code)
77-
return tolerance >= error_level.value
78-
except (ValueError, KeyError):
79-
return False
64+
error_enum = DefaultErrorEnum.get_error_enum(base_error_code)
65+
mapping = self.get_tolerance_mapping()
66+
required_threshold = mapping.get(error_enum.value, 999)
67+
68+
return tolerance >= required_threshold
69+
70+
def num_errno_enum_values(self) -> int:
71+
"""
72+
Default mode defines 3 levels of errors:
73+
1: Accuracy
74+
2: Runtime (Generic)
75+
3: Compilation
76+
"""
77+
return len(DefaultErrorEnum)

graph_net/mismatch_extended_positive_tolerance_interpretation.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,17 +45,7 @@ def type_name(self) -> str:
4545
return "mismatch_extended"
4646

4747
def get_errno(self, error_type: str) -> int:
48-
if not error_type:
49-
return 3
50-
etype = error_type.lower()
51-
52-
if "accuracy" in etype:
53-
return 1
54-
if any(k in etype for k in ["nan", "inf", "type_mismatch", "shape_mismatch"]):
55-
return 2
56-
if "compile_fail" in etype:
57-
return 4
58-
return 3
48+
return MismatchExtendedErrorEnum.get_error_enum(error_type).value
5949

6050
def get_error_type(self, errno: int) -> str:
6151
mapping = {
@@ -67,15 +57,31 @@ def get_error_type(self, errno: int) -> str:
6757
return mapping.get(errno, "unknown_error")
6858

6959
def get_tolerance_mapping(self) -> dict[int, int]:
70-
return {1: 1, 2: 2, 3: 3, 4: 4}
60+
return {
61+
MismatchExtendedErrorEnum.kAccuracyViolation.value: 1,
62+
MismatchExtendedErrorEnum.kValueTypeOrMetaMismatch.value: 2,
63+
MismatchExtendedErrorEnum.kExecutionFailed.value: 3,
64+
MismatchExtendedErrorEnum.kCompilationFailed.value: 4,
65+
}
7166

7267
def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
7368
if base_error_code == "correct":
7469
return True
7570
if base_error_code in ["eager_fail", "reference_fail"]:
7671
return False
77-
try:
78-
error_level = MismatchExtendedErrorEnum.get_error_enum(base_error_code)
79-
return tolerance >= error_level.value
80-
except (ValueError, KeyError):
81-
return False
72+
73+
error_enum = MismatchExtendedErrorEnum.get_error_enum(base_error_code)
74+
mapping = self.get_tolerance_mapping()
75+
required_threshold = mapping.get(error_enum.value, 999)
76+
77+
return tolerance >= required_threshold
78+
79+
def num_errno_enum_values(self) -> int:
80+
"""
81+
Extended mode defines 4 levels of errors:
82+
1: Accuracy
83+
2: Type/Shape/NaN
84+
3: Runtime
85+
4: Compilation
86+
"""
87+
return len(MismatchExtendedErrorEnum)

graph_net/positive_tolerance_interpretation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,14 @@ def is_error_tolerated(self, tolerance: int, base_error_code: str) -> bool:
4040
Replaces the old 'fake_perf_degrad' logic.
4141
"""
4242
raise NotImplementedError
43+
44+
@abstractmethod
45+
def num_errno_enum_values(self) -> int:
46+
"""
47+
Return the number of defined error categories; this equals the maximum errno when errno values are numbered contiguously from 1.
48+
49+
Example:
50+
- Default: returns 3 (Accuracy, Runtime, Compile)
51+
- MismatchExtended: returns 4 (Accuracy, Type/Shape/NaN mismatch, Runtime, Compile)
52+
"""
53+
raise NotImplementedError

graph_net/samples_statistics.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def calculate_pi(
155155

156156
def resolve_errno_tolerance(
157157
errno2count: dict[Union[int, str], int],
158+
positive_tolerance_interpretation: PositiveToleranceInterpretation,
158159
errno_tolerance_overrides: Optional[dict[Union[int, str], int]] = None,
159160
) -> dict[Union[int, str], int]:
160161
"""
@@ -176,12 +177,20 @@ def resolve_errno_tolerance(
176177
"""
177178
errno_tolerance_overrides = errno_tolerance_overrides or {}
178179

179-
def tolerance_for(errno: Union[int, str]) -> int:
180-
if errno in errno_tolerance_overrides:
181-
return errno_tolerance_overrides[errno]
180+
base_mapping = positive_tolerance_interpretation.get_tolerance_mapping()
181+
182+
def tolerance_for(err_key: Union[int, str]) -> int:
183+
if err_key in errno_tolerance_overrides:
184+
return errno_tolerance_overrides[err_key]
182185

183-
if isinstance(errno, int):
184-
return 1 if errno == 1 else 3
186+
errno_id = None
187+
if isinstance(err_key, int):
188+
errno_id = err_key
189+
elif isinstance(err_key, str):
190+
errno_id = positive_tolerance_interpretation.get_errno(err_key)
191+
192+
if errno_id is not None and errno_id in base_mapping:
193+
return base_mapping[errno_id]
185194

186195
return 999
187196

@@ -310,7 +319,9 @@ def calculate_es_components_values(
310319
errno_to_tolerance, positive_tolerance_interpretation
311320
)
312321

313-
errno2tolerance = resolve_errno_tolerance(errno2count, errno_to_tolerance)
322+
errno2tolerance = resolve_errno_tolerance(
323+
errno2count, positive_tolerance_interpretation, errno_to_tolerance
324+
)
314325

315326
def pi_value4errno(errno: Union[int, str]) -> float:
316327
return pi.get(errno, 0.0)

graph_net/verify_aggregated_params.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,11 @@ def determine_tolerances(
1313
positive_tolerance_interpretation: PositiveToleranceInterpretation,
1414
) -> range:
1515
"""Determine tolerance range based on observed errno categories."""
16-
# Currently errno categories are 1=accuracy, 2=runtime, 3=compile.
17-
# Keep logic data-driven for future extension.
18-
mapping = positive_tolerance_interpretation.get_tolerance_mapping()
19-
20-
if not mapping:
21-
max_errno = 3
16+
if samples:
17+
max_errno = len(samples)
2218
else:
23-
max_errno = max(mapping.keys())
19+
max_errno = positive_tolerance_interpretation.num_errno_enum_values()
20+
2421
return range(-10, max_errno + 2)
2522

2623

0 commit comments

Comments
 (0)