Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 70 additions & 13 deletions codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import types
from collections import ChainMap, OrderedDict, deque
from importlib.util import find_spec
from typing import Any
from typing import Any, Optional

import sentry_sdk

Expand All @@ -26,9 +26,78 @@
HAS_TENSORFLOW = find_spec("tensorflow") is not None


def _extract_exception_from_message(msg: str) -> Optional[BaseException]: # noqa: FA100
"""Try to extract a wrapped exception type from an error message.

Looks for patterns like "got ExceptionType('..." that indicate a wrapped exception.
Returns a synthetic exception of that type if found in builtins, None otherwise.
"""
# Pattern: "got ExceptionType('message')" or "got ExceptionType("message")"
# This pattern is used by torch._dynamo and potentially other libraries
match = re.search(r"got (\w+)\(['\"]", msg)
if match:
exc_name = match.group(1)
# Try to find this exception type in builtins
import builtins

exc_class = getattr(builtins, exc_name, None)
if exc_class is not None and isinstance(exc_class, type) and issubclass(exc_class, BaseException):
return exc_class()
return None


def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # noqa: FA100
"""Get the wrapped exception if this is a simple wrapper.
Comment on lines +49 to +50
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe I should just rename it to Claude.md since we've all started using it, will do that


Returns the inner exception if:
- exc is an ExceptionGroup with exactly one exception
- exc has a __cause__ (explicit chaining via 'raise X from Y')
- exc message contains a wrapped exception type pattern (e.g., "got IndexError('...")")

Returns None if exc is not a wrapper or wraps multiple exceptions.
"""
# Check for ExceptionGroup with single exception (Python 3.11+)
if hasattr(exc, "exceptions"):
exceptions = exc.exceptions
if len(exceptions) == 1:
return exceptions[0]
# Check for explicit exception chaining (__cause__)
if exc.__cause__ is not None:
return exc.__cause__
# Try to extract wrapped exception type from the message (library-agnostic)
return _extract_exception_from_message(str(exc))


def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911
"""Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent."""
try:
# Handle exceptions specially - before type check to allow wrapper comparison
if isinstance(orig, BaseException) and isinstance(new, BaseException):
if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError):
# If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object.
# The test results should be rejected as the behavior of the unpickleable object is unknown.
logger.debug("Unable to verify behavior of unpickleable object in replay test")
return False

# If types match exactly, compare attributes
if type(orig) is type(new):
orig_dict = {k: v for k, v in orig.__dict__.items() if not k.startswith("_")}
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
return comparator(orig_dict, new_dict, superset_obj)

# Types differ - check if one is a wrapper over the other
# Check if orig wraps something that matches new
wrapped_orig = _get_wrapped_exception(orig)
if wrapped_orig is not None and comparator(wrapped_orig, new, superset_obj):
return True

# Check if new wraps something that matches orig
wrapped_new = _get_wrapped_exception(new)
if wrapped_new is not None and comparator(orig, wrapped_new, superset_obj): # noqa: SIM103
return True

return False

if type(orig) is not type(new):
type_obj = type(orig)
new_type_obj = type(new)
Expand Down Expand Up @@ -67,18 +136,6 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
if math.isnan(orig) and math.isnan(new):
return True
return math.isclose(orig, new)
if isinstance(orig, BaseException):
if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError):
# If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object.
# The test results should be rejected as the behavior of the unpickleable object is unknown.
logger.debug("Unable to verify behavior of unpickleable object in replay test")
return False
# if str(orig) != str(new):
# return False
# compare the attributes of the two exception objects to determine if they are equivalent.
orig_dict = {k: v for k, v in orig.__dict__.items() if not k.startswith("_")}
new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")}
return comparator(orig_dict, new_dict, superset_obj)

if HAS_JAX:
import jax # type: ignore # noqa: PGH003
Expand Down
Loading
Loading