Skip to content

Commit 9557747

Browse files
committed
tidy error code
1 parent 9eb8dbd commit 9557747

File tree

1 file changed

+50
-32
lines changed

1 file changed

+50
-32
lines changed

graph_net/analysis_util.py

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,49 +19,48 @@ def detect_sample_error_code(log_text: str) -> str:
1919
Error code string. Possible values:
2020
- "correct": Sample executed successfully
2121
- "eager_fail": Eager model execution failed
22+
- "compile_fail": Compiled model failed to load
2223
- "shape_mismatch": Output shape mismatch between eager and compiled
2324
- "type_mismatch": Data type mismatch between eager and compiled
2425
- "runtime_fail": Runtime error during execution
25-
- "unknown": Unable to determine error type
2626
"""
2727
lines = log_text.split("\n") if isinstance(log_text, str) else log_text
2828

2929
# Track phase status and mismatch types
30-
eager_status = None
31-
shape_mismatch = False
32-
type_mismatch = False
30+
eager_success = False
31+
compile_success = False
32+
shape_match = False
33+
type_match = False
34+
runtime_fail = False
3335

3436
# Scan for status and mismatch markers
3537
for line in lines:
36-
if "[Result][status]" in line and "eager:" in line:
37-
eager_status = line.split("eager:")[1].strip()
38-
elif "[Shape] eager:" in line and "compiled:" in line:
39-
if "match:False" in line:
40-
shape_mismatch = True
41-
elif "[DataType] eager:" in line and "compiled:" in line:
42-
if "match:False" in line:
43-
type_mismatch = True
38+
if "[Result][status] eager:success" in line:
39+
eager_success = True
40+
elif "[Datatype][compiled]" in line:
41+
compile_success = True
42+
elif "[Shape]" in line and "match:True" in line:
43+
shape_match = True
44+
elif "[DataType]" in line and "match:True" in line:
45+
type_match = True
46+
47+
if any("Exception:" in line or "Error:" in line for line in lines):
48+
runtime_fail = True
4449

4550
# Determine error type
46-
if eager_status == "failed":
51+
if not eager_success:
4752
return "eager_fail"
48-
elif shape_mismatch:
53+
elif not compile_success:
54+
return "compile_fail"
55+
elif not shape_match:
4956
return "shape_mismatch"
50-
elif type_mismatch:
57+
elif not type_match:
5158
return "type_mismatch"
52-
elif eager_status == "success":
53-
return "correct"
54-
55-
# Check for runtime errors if no explicit status markers
56-
if any("Exception:" in line or "Error:" in line for line in lines):
59+
elif runtime_fail:
5760
return "runtime_fail"
58-
59-
# Final fallback - check if there's any performance data
60-
if any("[Performance]" in line and ":" in line for line in lines):
61+
else:
6162
return "correct"
6263

63-
return "unknown"
64-
6564

6665
def parse_single_sample_log_to_data(log_text: str) -> dict:
6766
"""
@@ -286,15 +285,34 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> b
286285
return False
287286

288287

289-
def fake_perf_degrad(tolerance, error_code) -> str:
288+
def fake_perf_degrad(tolerance, error_code, type="default") -> str:
290289
"""
291-
Calculate current correctness based on tolerance t and error code.
290+
Judge current correctness based on tolerance t and error code.
292291
"""
293-
if error_code == "accuracy" and tolerance >= 1:
294-
return "correct"
295-
elif tolerance >= 3:
296-
return "correct"
297-
return error_code
292+
if type == "default":
293+
if tolerance >= 3:
294+
return "correct"
295+
elif error_code == "accuracy" and tolerance >= 1:
296+
return "correct"
297+
else:
298+
return error_code
299+
elif type == "extended":
300+
if (
301+
error_code == "compile_fail" or error_code == "runtime_fail"
302+
) and tolerance >= 4:
303+
return "correct"
304+
elif error_code == "eager_fail" and tolerance >= 3:
305+
return "correct"
306+
elif (
307+
error_code == "shape_mismatch" or error_code == "type_mismatch"
308+
) and tolerance >= 2:
309+
return "correct"
310+
elif error_code == "accuracy" and tolerance >= 1:
311+
return "correct"
312+
else:
313+
return error_code
314+
else:
315+
raise NotImplementedError
298316

299317

300318
def calculate_scores(

0 commit comments

Comments
 (0)