@@ -19,49 +19,48 @@ def detect_sample_error_code(log_text: str) -> str:
1919 Error code string. Possible values:
2020 - "correct": Sample executed successfully
2121 - "eager_fail": Eager model execution failed
22+ - "compile_fail": Compiled model failed to load
2223 - "shape_mismatch": Output shape mismatch between eager and compiled
2324 - "type_mismatch": Data type mismatch between eager and compiled
2425 - "runtime_fail": Runtime error during execution
25- - "unknown": Unable to determine error type
2626 """
2727 lines = log_text .split ("\n " ) if isinstance (log_text , str ) else log_text
2828
2929 # Track phase status and mismatch types
30- eager_status = None
31- shape_mismatch = False
32- type_mismatch = False
30+ eager_success = False
31+ compile_success = False
32+ shape_match = False
33+ type_match = False
34+ runtime_fail = False
3335
3436 # Scan for status and mismatch markers
3537 for line in lines :
36- if "[Result][status]" in line and "eager:" in line :
37- eager_status = line .split ("eager:" )[1 ].strip ()
38- elif "[Shape] eager:" in line and "compiled:" in line :
39- if "match:False" in line :
40- shape_mismatch = True
41- elif "[DataType] eager:" in line and "compiled:" in line :
42- if "match:False" in line :
43- type_mismatch = True
38+ if "[Result][status] eager:success" in line :
39+ eager_success = True
40+ elif "[Datatype][compiled]" in line :
41+ compile_success = True
42+ elif "[Shape]" in line and "match:True" in line :
43+ shape_match = True
44+ elif "[DataType]" in line and "match:True" in line :
45+ type_match = True
46+
47+ if any ("Exception:" in line or "Error:" in line for line in lines ):
48+ runtime_fail = True
4449
4550 # Determine error type
46- if eager_status == "failed" :
51+ if not eager_success :
4752 return "eager_fail"
48- elif shape_mismatch :
53+ elif not compile_success :
54+ return "compile_fail"
55+ elif not shape_match :
4956 return "shape_mismatch"
50- elif type_mismatch :
57+ elif not type_match :
5158 return "type_mismatch"
52- elif eager_status == "success" :
53- return "correct"
54-
55- # Check for runtime errors if no explicit status markers
56- if any ("Exception:" in line or "Error:" in line for line in lines ):
59+ elif runtime_fail :
5760 return "runtime_fail"
58-
59- # Final fallback - check if there's any performance data
60- if any ("[Performance]" in line and ":" in line for line in lines ):
61+ else :
6162 return "correct"
6263
63- return "unknown"
64-
6564
6665def parse_single_sample_log_to_data (log_text : str ) -> dict :
6766 """
@@ -286,15 +285,34 @@ def get_correctness(dtype: str, t: int, correctness_data: dict, index: int) -> b
286285 return False
287286
288287
def fake_perf_degrad(tolerance, error_code, type="default") -> str:
    """
    Judge current correctness based on tolerance t and error code.

    A sample whose error code is "forgiven" at the given tolerance is
    reported as "correct"; otherwise the error code is returned unchanged.

    Args:
        tolerance: Numeric tolerance level, compared against the thresholds
            below with ``>=``.
        error_code: Error code string of the sample (for example
            "accuracy", "eager_fail", "compile_fail").
        type: Judging mode.
            - "default": every error code is forgiven at tolerance >= 3;
              "accuracy" is already forgiven at tolerance >= 1.
            - "extended": per-code thresholds -- "compile_fail" and
              "runtime_fail" at >= 4, "eager_fail" at >= 3,
              "shape_mismatch" and "type_mismatch" at >= 2, "accuracy"
              at >= 1. Any other code is never forgiven.

    Returns:
        "correct" if the error is forgiven, otherwise ``error_code``.

    Raises:
        NotImplementedError: If ``type`` is neither "default" nor "extended".
    """
    # NOTE: the parameter name ``type`` shadows the builtin inside this
    # function; it is kept as-is because callers may pass it by keyword.
    if type == "default":
        if tolerance >= 3 or (error_code == "accuracy" and tolerance >= 1):
            return "correct"
        return error_code

    if type == "extended":
        # Each group of error codes is paired with the minimum tolerance at
        # which it is forgiven. A code belongs to at most one group, so the
        # first matching group decides the outcome.
        thresholds = (
            (("compile_fail", "runtime_fail"), 4),
            (("eager_fail",), 3),
            (("shape_mismatch", "type_mismatch"), 2),
            (("accuracy",), 1),
        )
        for codes, minimum in thresholds:
            if error_code in codes:
                return "correct" if tolerance >= minimum else error_code
        # Codes without a threshold are passed through untouched.
        return error_code

    raise NotImplementedError(f"Unsupported fake_perf_degrad type: {type!r}")
298316
299317
300318def calculate_scores (
0 commit comments