|
8 | 8 | import types |
9 | 9 | from collections import ChainMap, OrderedDict, deque |
10 | 10 | from importlib.util import find_spec |
11 | | -from typing import Any |
| 11 | +from typing import Any, Optional |
12 | 12 |
|
13 | 13 | import sentry_sdk |
14 | 14 |
|
|
26 | 26 | HAS_TENSORFLOW = find_spec("tensorflow") is not None |
27 | 27 |
|
28 | 28 |
|
| 29 | +def _extract_exception_from_message(msg: str) -> Optional[BaseException]: # noqa: FA100 |
| 30 | + """Try to extract a wrapped exception type from an error message. |
| 31 | +
|
| 32 | + Looks for patterns like "got ExceptionType('..." that indicate a wrapped exception. |
| 33 | + Returns a synthetic exception of that type if found in builtins, None otherwise. |
| 34 | + """ |
| 35 | + # Pattern: "got ExceptionType('message')" or "got ExceptionType("message")" |
| 36 | + # This pattern is used by torch._dynamo and potentially other libraries |
| 37 | + match = re.search(r"got (\w+)\(['\"]", msg) |
| 38 | + if match: |
| 39 | + exc_name = match.group(1) |
| 40 | + # Try to find this exception type in builtins |
| 41 | + import builtins |
| 42 | + |
| 43 | + exc_class = getattr(builtins, exc_name, None) |
| 44 | + if exc_class is not None and isinstance(exc_class, type) and issubclass(exc_class, BaseException): |
| 45 | + return exc_class() |
| 46 | + return None |
| 47 | + |
| 48 | + |
| 49 | +def _get_wrapped_exception(exc: BaseException) -> Optional[BaseException]: # noqa: FA100 |
| 50 | + """Get the wrapped exception if this is a simple wrapper. |
| 51 | +
|
| 52 | + Returns the inner exception if: |
| 53 | + - exc is an ExceptionGroup with exactly one exception |
| 54 | + - exc has a __cause__ (explicit chaining via 'raise X from Y') |
| 55 | + - exc message contains a wrapped exception type pattern (e.g., "got IndexError('...")") |
| 56 | +
|
| 57 | + Returns None if exc is not a wrapper or wraps multiple exceptions. |
| 58 | + """ |
| 59 | + # Check for ExceptionGroup with single exception (Python 3.11+) |
| 60 | + if hasattr(exc, "exceptions"): |
| 61 | + exceptions = exc.exceptions |
| 62 | + if len(exceptions) == 1: |
| 63 | + return exceptions[0] |
| 64 | + # Check for explicit exception chaining (__cause__) |
| 65 | + if exc.__cause__ is not None: |
| 66 | + return exc.__cause__ |
| 67 | + # Try to extract wrapped exception type from the message (library-agnostic) |
| 68 | + return _extract_exception_from_message(str(exc)) |
| 69 | + |
| 70 | + |
29 | 71 | def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001, ANN401, FBT002, PLR0911 |
30 | 72 | """Compare two objects for equality recursively. If superset_obj is True, the new object is allowed to have more keys than the original object. However, the existing keys/values must be equivalent.""" |
31 | 73 | try: |
| 74 | + # Handle exceptions specially - before type check to allow wrapper comparison |
| 75 | + if isinstance(orig, BaseException) and isinstance(new, BaseException): |
| 76 | + if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError): |
| 77 | + # If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object. |
| 78 | + # The test results should be rejected as the behavior of the unpickleable object is unknown. |
| 79 | + logger.debug("Unable to verify behavior of unpickleable object in replay test") |
| 80 | + return False |
| 81 | + |
| 82 | + # If types match exactly, compare attributes |
| 83 | + if type(orig) is type(new): |
| 84 | + orig_dict = {k: v for k, v in orig.__dict__.items() if not k.startswith("_")} |
| 85 | + new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")} |
| 86 | + return comparator(orig_dict, new_dict, superset_obj) |
| 87 | + |
| 88 | + # Types differ - check if one is a wrapper over the other |
| 89 | + # Check if orig wraps something that matches new |
| 90 | + wrapped_orig = _get_wrapped_exception(orig) |
| 91 | + if wrapped_orig is not None and comparator(wrapped_orig, new, superset_obj): |
| 92 | + return True |
| 93 | + |
| 94 | + # Check if new wraps something that matches orig |
| 95 | + wrapped_new = _get_wrapped_exception(new) |
| 96 | + if wrapped_new is not None and comparator(orig, wrapped_new, superset_obj): # noqa: SIM103 |
| 97 | + return True |
| 98 | + |
| 99 | + return False |
| 100 | + |
32 | 101 | if type(orig) is not type(new): |
33 | 102 | type_obj = type(orig) |
34 | 103 | new_type_obj = type(new) |
@@ -67,18 +136,6 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001 |
67 | 136 | if math.isnan(orig) and math.isnan(new): |
68 | 137 | return True |
69 | 138 | return math.isclose(orig, new) |
70 | | - if isinstance(orig, BaseException): |
71 | | - if isinstance(orig, PicklePlaceholderAccessError) or isinstance(new, PicklePlaceholderAccessError): |
72 | | - # If this error was raised, there was an attempt to access the PicklePlaceholder, which represents an unpickleable object. |
73 | | - # The test results should be rejected as the behavior of the unpickleable object is unknown. |
74 | | - logger.debug("Unable to verify behavior of unpickleable object in replay test") |
75 | | - return False |
76 | | - # if str(orig) != str(new): |
77 | | - # return False |
78 | | - # compare the attributes of the two exception objects to determine if they are equivalent. |
79 | | - orig_dict = {k: v for k, v in orig.__dict__.items() if not k.startswith("_")} |
80 | | - new_dict = {k: v for k, v in new.__dict__.items() if not k.startswith("_")} |
81 | | - return comparator(orig_dict, new_dict, superset_obj) |
82 | 139 |
|
83 | 140 | if HAS_JAX: |
84 | 141 | import jax # type: ignore # noqa: PGH003 |
|
0 commit comments