Skip to content

Commit 3418b61

Browse files
authored
Merge pull request #1120 from codeflash-ai/comparator-wrapped-exceptions
Comparator for wrapped exceptions
2 parents 8245bb5 + 332f888 commit 3418b61

File tree

2 files changed

+378
-14
lines changed

2 files changed

+378
-14
lines changed

codeflash/verification/comparator.py

Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import types
99
from collections import ChainMap, OrderedDict, deque
1010
from importlib.util import find_spec
11-
from typing import Any
11+
from typing import Any, Optional
1212

1313
import sentry_sdk
1414

@@ -26,9 +26,78 @@
2626
HAS_TENSORFLOW = find_spec("tensorflow") is not None
2727

2828

29+
def _extract_exception_from_message(msg: str) -> Optional[BaseException]: # noqa: FA100
30+
"""Try to extract a wrapped exception type from an error message.
31+
32+
Looks for patterns like "got ExceptionType('..." that indicate a wrapped exception.
33+
Returns a synthetic exception of that type if found in builtins, None otherwise.
34+
"""
35+
# Pattern: "got ExceptionType('message')" or "got ExceptionType("message")"
36+
# This pattern is used by torch._dynamo and potentially other libraries
37+
match = re.search(r"got (\w+)\(['\"]", msg)
38+
if match:
39+
exc_name = match.group(1)
40+
# Try to find this exception type in builtins
41+
import builtins
42+
43+
exc_class = getattr(builtins, exc_name, None)
44+
if exc_class is not None and isinstance(exc_class, type) and issubclass(exc_class, BaseException):
45+
return exc_class()
46+
return None
47+
48+
49+
def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # noqa: FA100
50+
"""Get the wrapped exception if this is a simple wrapper.
51+
52+
Returns the inner exception if:
53+
- exc is an ExceptionGroup with exactly one exception
54+
- exc has a __cause__ (explicit chaining via 'raise X from Y')
55+
- exc message contains a wrapped exception type pattern (e.g., "got IndexError('...")")
56+
57+
Returns None if exc is not a wrapper or wraps multiple exceptions.
58+
"""
59+
# Check for ExceptionGroup with single exception (Python 3.11+)
60+
if hasattr(exc, "exceptions"):
61+
exceptions = exc.exceptions
62+
if len(exceptions) == 1:
63+
return exceptions[0]
64+
# Check for explicit exception chaining (__cause__)
65+
if exc.__cause__ is not None:
66+
return exc.__cause__
67+
# Try to extract wrapped exception type from the message (library-agnostic)
68+
return _extract_exception_from_message(str(exc))
69+
70+
2971
def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
3072
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
3173
try:
74+
# Handle exceptions specially - before type check to allow wrapper comparison
75+
if isinstance(orig, BaseException) and isinstance(new, BaseException):
76+
if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError):
77+
# If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object.
78+
# The test results should be rejected as the behavior of the unpickleable object is unknown.
79+
logger.debug("Unable to verify behavior of unpickleable object in replay test")
80+
return False
81+
82+
# If types match exactly, compare attributes
83+
if type(orig) is type(new):
84+
orig_dict = {k: v for k, v in orig.__dict__.items() if not k.startswith("_")}
85+
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
86+
return comparator(orig_dict, new_dict, superset_obj)
87+
88+
# Types differ - check if one is a wrapper over the other
89+
# Check if orig wraps something that matches new
90+
wrapped_orig = _get_wrapped_exception(orig)
91+
if wrapped_orig is not None and comparator(wrapped_orig, new, superset_obj):
92+
return True
93+
94+
# Check if new wraps something that matches orig
95+
wrapped_new = _get_wrapped_exception(new)
96+
if wrapped_new is not None and comparator(orig, wrapped_new, superset_obj): # noqa: SIM103
97+
return True
98+
99+
return False
100+
32101
if type(orig) is not type(new):
33102
type_obj = type(orig)
34103
new_type_obj = type(new)
@@ -67,18 +136,6 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
67136
if math.isnan(orig) and math.isnan(new):
68137
return True
69138
return math.isclose(orig, new)
70-
if isinstance(orig, BaseException):
71-
if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError):
72-
# If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object.
73-
# The test results should be rejected as the behavior of the unpickleable object is unknown.
74-
logger.debug("Unable to verify behavior of unpickleable object in replay test")
75-
return False
76-
# if str(orig) != str(new):
77-
# return False
78-
# compare the attributes of the two exception objects to determine if they are equivalent.
79-
orig_dict = {k: v for k, v in orig.__dict__.items() if not k.startswith("_")}
80-
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
81-
return comparator(orig_dict, new_dict, superset_obj)
82139

83140
if HAS_JAX:
84141
import jax # type: ignore # noqa: PGH003

0 commit comments

Comments
 (0)