-
Notifications
You must be signed in to change notification settings - Fork 21
torch compile exception wrapper handling for better comparator #1119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,14 +8,72 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import types | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from collections import ChainMap, OrderedDict, deque | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from importlib.util import find_spec | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from typing import Any, Optional | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import sentry_sdk | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from codeflash.cli_cmds.console import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from codeflash.picklepatch.pickle_placeholder import PicklePlaceholderAccessError | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| HAS_NUMPY = find_spec("numpy") is not None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _is_torch_runtime_error(exc: BaseException) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Check if an exception is a TorchRuntimeError from torch.compile/Dynamo.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Check by class name to avoid importing torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _get_wrapped_exception_from_torch_runtime_error(exc: BaseException) -> Optional[BaseException]: # noqa: FA100 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Extract the underlying exception wrapped by TorchRuntimeError. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TorchRuntimeError from torch.compile/Dynamo wraps the original exception | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| using Python's exception chaining (__cause__). | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not _is_torch_runtime_error(exc): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # First try to get the chained exception via __cause__ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if exc.__cause__ is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return exc.__cause__ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _exceptions_are_equivalent(orig: BaseException, new: BaseException) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Check if two exceptions are semantically equivalent. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| This handles the case where torch.compile wraps exceptions in TorchRuntimeError | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| while the original code raises the underlying exception directly. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # If types match, they're potentially equivalent (will be compared further) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if type(orig).__name__ == type(new).__name__: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Check if one is a TorchRuntimeError wrapping the other's type | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if _is_torch_runtime_error(new): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| wrapped = _get_wrapped_exception_from_torch_runtime_error(new) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if wrapped is not None and type(wrapped).__name__ == type(orig).__name__: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Also check the error message for the wrapped exception type | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # TorchRuntimeError message contains "got ExceptionType('...')" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| error_msg = str(new) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| orig_type_name = type(orig).__name__ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if _is_torch_runtime_error(orig): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| wrapped = _get_wrapped_exception_from_torch_runtime_error(orig) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if wrapped is not None and type(wrapped).__name__ == type(new).__name__: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| error_msg = str(orig) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| new_type_name = type(new).__name__ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+49
to
+70
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ⚡️Codeflash found 16% (0.16x) speedup for
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 2938 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | ✅ 1 Passed |
| 📊 Tests Coverage | 78.9% |
🌀 Click to see Generated Regression Tests
from typing import Optional
# imports
from codeflash.verification.comparator import _exceptions_are_equivalent
def _is_torch_runtime_error(exc: BaseException) -> bool:
"""Check if an exception is a TorchRuntimeError from torch.compile/Dynamo."""
# Check by class name to avoid importing torch
return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__
def _get_wrapped_exception_from_torch_runtime_error(exc: BaseException) -> Optional[BaseException]: # noqa: FA100
"""Extract the underlying exception wrapped by TorchRuntimeError.
TorchRuntimeError from torch.compile/Dynamo wraps the original exception
using Python's exception chaining (__cause__).
"""
if not _is_torch_runtime_error(exc):
return None
# First try to get the chained exception via __cause__
if exc.__cause__ is not None:
return exc.__cause__
return None
def test_same_builtin_type_is_equivalent():
# Basic case: identical built-in exception types should be equivalent
orig = ValueError("original")
new = ValueError("new")
# types have the same __name__ so should be True
codeflash_output = _exceptions_are_equivalent(orig, new) # 882ns -> 861ns (2.44% faster)
def test_same_type_name_different_classes_equivalent():
# If types have the same name (even if different classes), equivalence is based on name
# Create a custom exception class named 'ValueError' (different object than built-in)
CustomValueError = type("ValueError", (Exception,), {}) # class name matches built-in
orig = ValueError("built-in") # built-in ValueError
new = CustomValueError("custom") # custom class with the same name
# Should be True because only the __name__ is compared in the function's first check
codeflash_output = _exceptions_are_equivalent(orig, new) # 791ns -> 802ns (1.37% slower)
def test_torch_runtime_error_with_cause_matches_orig():
# Simulate a TorchRuntimeError wrapper by creating a class with the same name
# and a module path containing 'torch._dynamo'
TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorFake.__module__ = "torch._dynamo.fake_module" # include required substring
# Underlying exception that was originally raised
underlying = TypeError("underlying problem")
# Create the fake TorchRuntimeError and attach the underlying exception via __cause__
wrapped = TorchRuntimeErrorFake("wrapper message")
wrapped.__cause__ = underlying
# orig is the underlying exception, new is the wrapper: should be considered equivalent
codeflash_output = _exceptions_are_equivalent(underlying, wrapped) # 2.23μs -> 1.95μs (14.3% faster)
def test_torch_runtime_error_wraps_new_side_equivalence():
# Reverse of previous: orig is a wrapper, new is the underlying exception
TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorFake.__module__ = "torch._dynamo.other"
underlying = KeyError("missing")
wrapper = TorchRuntimeErrorFake("wrapped key error")
wrapper.__cause__ = underlying
# orig is wrapper, new is underlying -> should be True
codeflash_output = _exceptions_are_equivalent(wrapper, underlying) # 2.37μs -> 1.88μs (26.0% faster)
def test_torch_runtime_error_message_parenthesis_format_matches():
# TorchRuntimeError without a __cause__ but with a message containing "got TypeError('...')"
TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorFake.__module__ = "torch._dynamo.something"
# Craft message that matches the "got <TypeName>(" pattern
msg = "some runtime wrapper; got TypeError('oops') while compiling"
wrapper = TorchRuntimeErrorFake(msg)
orig = TypeError("oops")
# Since the message contains "got TypeError(" the function should report equivalence
codeflash_output = _exceptions_are_equivalent(orig, wrapper) # 2.50μs -> 2.38μs (5.08% faster)
def test_torch_runtime_error_message_colon_format_matches():
# TorchRuntimeError message with the "got TypeError:" pattern should match as well
TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorFake.__module__ = "torch._dynamo"
msg = "error occurred; got TypeError: details follow"
wrapper = TorchRuntimeErrorFake(msg)
orig = TypeError("details follow")
# Should be True because of the "got TypeError:" pattern
codeflash_output = _exceptions_are_equivalent(orig, wrapper) # 2.48μs -> 2.31μs (7.81% faster)
def test_non_equivalent_different_types():
# Completely different exception types without any torch wrapper should not be equivalent
orig = IndexError("index fail")
new = ValueError("value fail")
codeflash_output = _exceptions_are_equivalent(orig, new) # 1.38μs -> 942ns (46.8% faster)
def test_torch_runtime_error_with_incorrect_module_not_recognized():
# Create a class named TorchRuntimeError but with module that does NOT contain 'torch._dynamo'
TorchRuntimeErrorNotDynamo = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorNotDynamo.__module__ = "some.other.module"
underlying = RuntimeError("bad")
wrapper = TorchRuntimeErrorNotDynamo("wrapper")
wrapper.__cause__ = underlying
# Because the module doesn't include 'torch._dynamo', the wrapper should NOT be recognized,
# and since names differ (TorchRuntimeError vs RuntimeError), result should be False
codeflash_output = _exceptions_are_equivalent(underlying, wrapper) # 1.55μs -> 1.14μs (36.1% faster)
def test_torch_runtime_error_message_partial_no_false_positive():
# Ensure that similar-but-not-matching messages do not produce false positives.
TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorFake.__module__ = "torch._dynamo.sub"
# Message contains 'got Value(' which is not the exact searched pattern 'got ValueError('
msg = "wrapper says got Value('x') somewhere"
wrapper = TorchRuntimeErrorFake(msg)
orig = ValueError("x")
# Should be False because "got Value(" does not match "got ValueError(" or "got ValueError:"
codeflash_output = _exceptions_are_equivalent(orig, wrapper) # 2.85μs -> 2.60μs (10.0% faster)
def test_large_scale_many_name_based_equivalences():
# Large-scale test: ensure performance & correctness over many name-equality cases.
# We'll create 500 pairs (well under the 1000 element limit) where the type names match.
pairs = []
for i in range(500): # 500 iterations is under the "do not exceed 1000 steps" guidance
# Create two distinct classes that share the same __name__
name = f"BulkExc{i}"
A = type(name, (Exception,), {})
B = type(name, (Exception,), {})
# instantiate exceptions
a_inst = A(f"a {i}")
b_inst = B(f"b {i}")
pairs.append((a_inst, b_inst))
# All pairs should be equivalent because only the __name__ is checked in the first step
for a_inst, b_inst in pairs:
codeflash_output = _exceptions_are_equivalent(a_inst, b_inst) # 109μs -> 112μs (2.28% slower)
def test_torch_runtime_error_message_matches_custom_exception_name():
# Message-based match should work for non-builtins too: orig is custom-named exception
CustomErr = type("CustomErr", (Exception,), {})
TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorFake.__module__ = "torch._dynamo.extra"
# Message contains the custom exception name in the expected pattern
msg = "compilation failed; got CustomErr('x') at runtime"
wrapper = TorchRuntimeErrorFake(msg)
orig = CustomErr("x")
# Should be True because message contains "got CustomErr("
codeflash_output = _exceptions_are_equivalent(orig, wrapper) # 2.52μs -> 2.46μs (2.48% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.from codeflash.verification.comparator import _exceptions_are_equivalent
def _is_torch_runtime_error(exc: BaseException) -> bool:
"""Check if an exception is a TorchRuntimeError from torch.compile/Dynamo."""
# Check by class name to avoid importing torch
return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__
def _get_wrapped_exception_from_torch_runtime_error(exc: BaseException):
"""Extract the underlying exception wrapped by TorchRuntimeError.
TorchRuntimeError from torch.compile/Dynamo wraps the original exception
using Python's exception chaining (__cause__).
"""
if not _is_torch_runtime_error(exc):
return None
# First try to get the chained exception via __cause__
if exc.__cause__ is not None:
return exc.__cause__
return None
class TestBasicFunctionality:
"""Test cases for basic equivalence checking of standard exceptions."""
def test_identical_exception_types(self):
"""Test that two exceptions of the same type are considered equivalent."""
exc1 = ValueError("test error")
exc2 = ValueError("different message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 922ns -> 852ns (8.22% faster)
def test_identical_standard_exceptions(self):
"""Test equivalence with various standard Python exceptions."""
pairs = [
(TypeError("msg1"), TypeError("msg2")),
(RuntimeError("msg1"), RuntimeError("msg2")),
(KeyError("msg1"), KeyError("msg2")),
(IndexError("msg1"), IndexError("msg2")),
(AttributeError("msg1"), AttributeError("msg2")),
(ZeroDivisionError("msg1"), ZeroDivisionError("msg2")),
]
for exc1, exc2 in pairs:
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 2.79μs -> 2.80μs (0.036% slower)
def test_different_exception_types(self):
"""Test that different exception types are not equivalent."""
exc1 = ValueError("error")
exc2 = TypeError("error")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 1.35μs -> 811ns (66.8% faster)
def test_different_exception_types_all_combinations(self):
"""Test equivalence returns False for various mismatched exception types."""
pairs = [
(ValueError("msg"), TypeError("msg")),
(RuntimeError("msg"), KeyError("msg")),
(IndexError("msg"), AttributeError("msg")),
(ZeroDivisionError("msg"), Exception("msg")),
]
for exc1, exc2 in pairs:
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 3.33μs -> 2.25μs (48.1% faster)
def test_custom_exception_same_type(self):
"""Test equivalence for custom exception classes."""
class CustomError(Exception):
pass
exc1 = CustomError("message1")
exc2 = CustomError("message2")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 601ns -> 611ns (1.64% slower)
def test_custom_exception_different_types(self):
"""Test non-equivalence for different custom exception types."""
class CustomError1(Exception):
pass
class CustomError2(Exception):
pass
exc1 = CustomError1("message")
exc2 = CustomError2("message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 1.00μs -> 691ns (45.0% faster)
def test_base_exception_type(self):
"""Test with base Exception type."""
exc1 = Exception("message1")
exc2 = Exception("message2")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 902ns -> 801ns (12.6% faster)
def test_exception_with_no_message(self):
"""Test equivalence for exceptions without messages."""
exc1 = ValueError()
exc2 = ValueError()
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 851ns -> 752ns (13.2% faster)
def test_exception_with_empty_vs_nonempty_message(self):
"""Test equivalence is independent of message content."""
exc1 = ValueError("")
exc2 = ValueError("detailed error message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 851ns -> 731ns (16.4% faster)
class TestEdgeCases:
"""Test cases for edge conditions and unusual scenarios."""
def test_exception_with_special_characters_in_message(self):
"""Test exceptions with special characters don't affect type comparison."""
exc1 = ValueError("Error with @#$%^&*()")
exc2 = ValueError("Error with \n\t\r")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 901ns -> 752ns (19.8% faster)
def test_exception_with_unicode_characters(self):
"""Test exceptions with unicode characters."""
exc1 = ValueError("Error: \u2764\u2665\u263a")
exc2 = ValueError("Error: \u00e9\u00e8\u00ea")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 832ns -> 791ns (5.18% faster)
def test_exception_with_very_long_message(self):
"""Test exceptions with very long messages."""
long_msg1 = "x" * 10000
long_msg2 = "y" * 10000
exc1 = ValueError(long_msg1)
exc2 = ValueError(long_msg2)
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 872ns -> 732ns (19.1% faster)
def test_exception_with_numeric_message(self):
"""Test exceptions with numeric content in messages."""
exc1 = ValueError(12345)
exc2 = ValueError(67890)
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 801ns -> 701ns (14.3% faster)
def test_exception_inheritance_not_equivalent(self):
"""Test that exception inheritance hierarchy doesn't make them equivalent."""
class BaseCustom(Exception):
pass
class DerivedCustom(BaseCustom):
pass
exc1 = BaseCustom("message")
exc2 = DerivedCustom("message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 1.07μs -> 731ns (46.5% faster)
def test_exception_with_args_tuple(self):
"""Test exceptions constructed with multiple arguments."""
exc1 = ValueError("arg1", "arg2", "arg3")
exc2 = ValueError("different1", "different2")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 832ns -> 752ns (10.6% faster)
def test_exception_with_none_as_message(self):
"""Test exception with None as message."""
exc1 = ValueError(None)
exc2 = ValueError("actual message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 882ns -> 732ns (20.5% faster)
def test_exception_same_instance_compared(self):
"""Test comparing the same exception instance with itself."""
exc = ValueError("message")
codeflash_output = _exceptions_are_equivalent(exc, exc) # 851ns -> 712ns (19.5% faster)
def test_exception_from_different_modules_same_name(self):
"""Test handling of different classes with same name but different modules."""
# Create mock exception classes with same name but different identities
class ValueError1(BaseException):
pass
class ValueError2(BaseException):
pass
# Rename them to have the same __name__ but different __module__
ValueError1.__name__ = "SameName"
ValueError2.__name__ = "SameName"
ValueError1.__module__ = "module1"
ValueError2.__module__ = "module2"
exc1 = ValueError1("msg")
exc2 = ValueError2("msg")
# They should be equivalent because type names match
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 601ns -> 631ns (4.75% slower)
def test_exception_with_multiline_message(self):
"""Test exceptions with multiline error messages."""
msg1 = "Line1\nLine2\nLine3\nLine4"
msg2 = "Different\nMultiline\nMessage\nHere"
exc1 = RuntimeError(msg1)
exc2 = RuntimeError(msg2)
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 892ns -> 802ns (11.2% faster)
def test_exception_message_containing_type_name(self):
"""Test message that contains other exception type names."""
exc1 = ValueError("This is a ValueError")
exc2 = ValueError("This is a TypeError or KeyError")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 872ns -> 831ns (4.93% faster)
class TestTorchRuntimeErrorHandling:
"""Test cases for TorchRuntimeError wrapping behavior."""
def test_torch_runtime_error_detection(self):
"""Test that TorchRuntimeError can be detected by type name."""
# Create a mock TorchRuntimeError-like exception
class TorchRuntimeError(BaseException):
pass
# Set the module to match torch._dynamo
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
exc = TorchRuntimeError("wrapped error")
def test_torch_runtime_error_false_for_different_module(self):
"""Test that TorchRuntimeError from different module is not detected."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "some_other_module"
TorchRuntimeError.__name__ = "TorchRuntimeError"
exc = TorchRuntimeError("wrapped error")
def test_get_wrapped_exception_with_cause(self):
"""Test extracting wrapped exception from __cause__."""
# Create a TorchRuntimeError-like class
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Create the wrapped exception
original_exc = ValueError("original error")
torch_exc = TorchRuntimeError("torch wrapped error")
torch_exc.__cause__ = original_exc
wrapped = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
def test_get_wrapped_exception_returns_none_for_non_torch(self):
"""Test that non-TorchRuntimeError returns None."""
exc = ValueError("regular error")
wrapped = _get_wrapped_exception_from_torch_runtime_error(exc)
def test_get_wrapped_exception_no_cause_returns_none(self):
"""Test TorchRuntimeError without __cause__ returns None."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
torch_exc = TorchRuntimeError("no cause")
wrapped = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
def test_torch_runtime_error_wrapping_via_cause_equivalent(self):
"""Test equivalence when TorchRuntimeError wraps exception via __cause__."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Original exception
exc1 = ValueError("original")
# TorchRuntimeError wrapping the original
exc2 = TorchRuntimeError("wrapped")
exc2.__cause__ = ValueError("wrapped original")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 2.23μs -> 2.17μs (2.76% faster)
def test_torch_runtime_error_wrapping_different_type_not_equivalent(self):
"""Test non-equivalence when TorchRuntimeError wraps different type."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Original exception
exc1 = ValueError("original")
# TorchRuntimeError wrapping a different type
exc2 = TorchRuntimeError("wrapped")
exc2.__cause__ = TypeError("wrapped different")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 2.92μs -> 2.48μs (18.2% faster)
def test_torch_runtime_error_message_pattern_matching(self):
"""Test TorchRuntimeError message pattern matching for type name."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Original exception
exc1 = ValueError("original")
# TorchRuntimeError with message containing the wrapped type
exc2 = TorchRuntimeError("got ValueError('wrapped error')")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 2.38μs -> 2.13μs (11.7% faster)
def test_torch_runtime_error_message_pattern_colon_format(self):
"""Test TorchRuntimeError message pattern with colon format."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Original exception
exc1 = RuntimeError("original")
# TorchRuntimeError with message containing the wrapped type (colon format)
exc2 = TorchRuntimeError("got RuntimeError: wrapped error")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 2.44μs -> 2.29μs (6.96% faster)
def test_torch_runtime_error_reverse_wrapping(self):
"""Test equivalence when original is TorchRuntimeError wrapping new type."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Original as TorchRuntimeError
exc1 = TorchRuntimeError("wrapped")
exc1.__cause__ = ValueError("wrapped original")
# New exception
exc2 = ValueError("new")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 2.33μs -> 1.91μs (21.5% faster)
def test_torch_runtime_error_both_sides(self):
"""Test when both exceptions are TorchRuntimeError with different wrapped types."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Both as TorchRuntimeError wrapping different types
exc1 = TorchRuntimeError("wrapped1")
exc1.__cause__ = ValueError("val error")
exc2 = TorchRuntimeError("wrapped2")
exc2.__cause__ = TypeError("type error")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 530ns -> 651ns (18.6% slower)
class TestComplexScenarios:
"""Test cases combining multiple features and complex scenarios."""
def test_exception_chain_with_context(self):
"""Test exceptions that are part of an exception chain."""
try:
try:
raise ValueError("original")
except ValueError as e:
raise TypeError("new") from e
except TypeError as e:
exc1 = e
exc2 = TypeError("different message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 942ns -> 921ns (2.28% faster)
def test_multiple_exception_types_in_sequence(self):
"""Test multiple different exception types don't become equivalent."""
exceptions = [ValueError("msg"), TypeError("msg"), RuntimeError("msg"), KeyError("msg"), IndexError("msg")]
for i in range(len(exceptions)):
for j in range(len(exceptions)):
if i == j:
codeflash_output = _exceptions_are_equivalent(exceptions[i], exceptions[j])
else:
codeflash_output = _exceptions_are_equivalent(exceptions[i], exceptions[j])
def test_exception_with_complex_args(self):
"""Test exceptions with complex argument structures."""
exc1 = ValueError({"key": "value"}, [1, 2, 3], (4, 5, 6))
exc2 = ValueError({"different": "dict"}, ["a", "b"], ("x", "y"))
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 871ns -> 872ns (0.115% slower)
def test_exception_subclass_comparison(self):
"""Test comparison of exception subclasses."""
class CustomValueError(ValueError):
pass
exc1 = ValueError("base")
exc2 = CustomValueError("subclass")
# Different type names, so not equivalent
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 1.24μs -> 821ns (51.3% faster)
def test_exception_type_name_only_comparison(self):
"""Test that comparison is based on type name, not identity."""
# Create two different exception classes with the same name
exec_code = """
class DynamicException(Exception):
pass
exc1 = DynamicException("msg1")
"""
exec_globals1 = {}
exec(exec_code, exec_globals1)
exc1 = exec_globals1["exc1"]
exec_globals2 = {}
exec(exec_code, exec_globals2)
exc2 = exec_globals2["exc1"]
# Even though they're different classes, they have the same name
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 792ns -> 811ns (2.34% slower)
class TestLargeScale:
"""Test cases for large scale data and performance scenarios."""
def test_large_number_of_exception_comparisons(self):
"""Test comparing a large number of exception pairs."""
# Create 500 ValueError instances
exceptions = [ValueError(f"error_{i}") for i in range(500)]
# Compare each with first exception
for exc in exceptions:
codeflash_output = _exceptions_are_equivalent(exceptions[0], exc) # 132μs -> 140μs (5.94% slower)
def test_large_number_of_different_exception_types(self):
"""Test with many different exception types."""
exception_types = [
ValueError,
TypeError,
RuntimeError,
KeyError,
IndexError,
AttributeError,
ZeroDivisionError,
IOError,
OSError,
ImportError,
NotImplementedError,
StopIteration,
ArithmeticError,
BufferError,
LookupError,
MemoryError,
]
# Test each type is only equivalent to itself
for i, exc_type1 in enumerate(exception_types):
exc1 = exc_type1("message")
for j, exc_type2 in enumerate(exception_types):
exc2 = exc_type2("message")
if i == j:
codeflash_output = _exceptions_are_equivalent(exc1, exc2)
else:
codeflash_output = _exceptions_are_equivalent(exc1, exc2)
def test_very_large_exception_messages(self):
"""Test with exceptions having extremely large messages."""
# Create a 1MB string
large_msg1 = "x" * (1024 * 1024)
large_msg2 = "y" * (1024 * 1024)
exc1 = ValueError(large_msg1)
exc2 = ValueError(large_msg2)
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 1.06μs -> 922ns (15.2% faster)
def test_exception_comparison_performance(self):
"""Test performance of exception comparison with various inputs."""
import time
# Create many exception pairs
pairs = []
for i in range(1000):
if i % 2 == 0:
pairs.append((ValueError(f"msg_{i}"), ValueError(f"msg_{i + 1}")))
else:
pairs.append((ValueError(f"msg_{i}"), TypeError(f"msg_{i + 1}")))
# Perform comparisons and measure time
start_time = time.time()
for exc1, exc2 in pairs:
_exceptions_are_equivalent(exc1, exc2) # 353μs -> 289μs (22.3% faster)
end_time = time.time()
def test_many_torch_runtime_errors(self):
"""Test handling of many TorchRuntimeError instances."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Create 500 TorchRuntimeError instances wrapping different base exceptions
torch_errors = []
base_errors = []
for i in range(500):
base_exc = ValueError(f"base_{i}")
torch_exc = TorchRuntimeError(f"wrapped_{i}")
torch_exc.__cause__ = base_exc
torch_errors.append(torch_exc)
base_errors.append(base_exc)
# All should be equivalent to their corresponding base exceptions
for i in range(500):
codeflash_output = _exceptions_are_equivalent(
base_errors[i], torch_errors[i]
) # 351μs -> 299μs (17.4% faster)
def test_exception_types_with_long_names(self):
"""Test exception type comparison with very long type names."""
# Create exception classes with long names
long_name1 = "VeryLongExceptionNameWith" + "x" * 500
long_name2 = "VeryLongExceptionNameWith" + "y" * 500
ExceptionClass1 = type(long_name1, (Exception,), {})
ExceptionClass2 = type(long_name2, (Exception,), {})
exc1 = ExceptionClass1("message")
exc2 = ExceptionClass2("message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 1.09μs -> 791ns (38.1% faster)
def test_exception_comparison_with_circular_references(self):
"""Test exception comparison with objects containing circular references."""
# Create a circular reference in exception args
circular_list = [1, 2, 3]
circular_list.append(circular_list)
exc1 = ValueError(circular_list)
exc2 = ValueError([4, 5, 6])
# Should still compare by type name successfully
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 861ns -> 892ns (3.48% slower)
def test_batched_torch_error_comparison(self):
"""Test batch processing of many torch runtime error comparisons."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Create 100 base exceptions and 100 torch wrapped versions
base_exceptions = [ValueError(f"base_{i}") for i in range(100)]
torch_exceptions = []
for i in range(100):
torch_exc = TorchRuntimeError(f"wrapped_{i}")
torch_exc.__cause__ = ValueError(f"wrapped_base_{i}")
torch_exceptions.append(torch_exc)
# Compare each base with each torch exception of same index
results = []
for i in range(100):
codeflash_output = _exceptions_are_equivalent(base_exceptions[i], torch_exceptions[i])
result = codeflash_output # 72.4μs -> 62.3μs (16.2% faster)
results.append(result)
def test_stress_test_many_exception_instantiations(self):
"""Stress test creating and comparing many exception instances."""
# Create 1000 exception instances
exceptions = []
for i in range(1000):
exc_type = ValueError if i % 2 == 0 else TypeError
exceptions.append(exc_type(f"message_{i}"))
# Compare each exception with a few others
equivalences = 0
non_equivalences = 0
for i in range(0, 1000, 10):
for j in range(i, min(i + 5, 1000)):
if _exceptions_are_equivalent(exceptions[i], exceptions[j]):
equivalences += 1
else:
non_equivalences += 1
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.from codeflash.verification.comparator import _exceptions_are_equivalent
def test__exceptions_are_equivalent():
_exceptions_are_equivalent(BaseException(), BaseException())🔎 Click to see Concolic Coverage Tests
| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|---|---|---|---|
codeflash_concolic_g5uubgk6/tmpkbpq1ztw/test_concolic_coverage.py::test__exceptions_are_equivalent |
952ns | 832ns | 14.4%✅ |
To test or edit this optimization locally git merge codeflash/optimize-pr1119-2026-01-20T05.30.28
Click to see suggested changes
| # If types match, they're potentially equivalent (will be compared further) | |
| if type(orig).__name__ == type(new).__name__: | |
| return True | |
| # Check if one is a TorchRuntimeError wrapping the other's type | |
| if _is_torch_runtime_error(new): | |
| wrapped = _get_wrapped_exception_from_torch_runtime_error(new) | |
| if wrapped is not None and type(wrapped).__name__ == type(orig).__name__: | |
| return True | |
| # Also check the error message for the wrapped exception type | |
| # TorchRuntimeError message contains "got ExceptionType('...')" | |
| error_msg = str(new) | |
| orig_type_name = type(orig).__name__ | |
| if f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg: | |
| return True | |
| if _is_torch_runtime_error(orig): | |
| wrapped = _get_wrapped_exception_from_torch_runtime_error(orig) | |
| if wrapped is not None and type(wrapped).__name__ == type(new).__name__: | |
| return True | |
| error_msg = str(orig) | |
| new_type_name = type(new).__name__ | |
| # Cache type lookups to avoid redundant attribute access | |
| orig_type = type(orig) | |
| new_type = type(new) | |
| orig_type_name = orig_type.__name__ | |
| new_type_name = new_type.__name__ | |
| # If types match, they're potentially equivalent (will be compared further) | |
| if orig_type_name == new_type_name: | |
| return True | |
| # Check if new is a TorchRuntimeError wrapping the other's type | |
| if new_type_name == "TorchRuntimeError" and "torch._dynamo" in new_type.__module__: | |
| wrapped = _get_wrapped_exception_from_torch_runtime_error(new) | |
| if wrapped is not None and type(wrapped).__name__ == orig_type_name: | |
| return True | |
| # Also check the error message for the wrapped exception type | |
| # TorchRuntimeError message contains "got ExceptionType('...')" | |
| error_msg = str(new) | |
| if f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg: | |
| return True | |
| # Check if orig is a TorchRuntimeError wrapping the other's type | |
| if orig_type_name == "TorchRuntimeError" and "torch._dynamo" in orig_type.__module__: | |
| wrapped = _get_wrapped_exception_from_torch_runtime_error(orig) | |
| if wrapped is not None and type(wrapped).__name__ == new_type_name: | |
| return True | |
| error_msg = str(orig) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
⚡️Codeflash found 24% (0.24x) speedup for
_get_wrapped_exception_from_torch_runtime_errorincodeflash/verification/comparator.py⏱️ Runtime :
603 microseconds→486 microseconds(best of241runs)📝 Explanation and details
The optimized code achieves a 24% speedup by eliminating redundant function calls and attribute lookups through inlining.
Key Optimizations
1. Function Call Elimination
The original code calls
_is_torch_runtime_error(exc)as a separate function, which incurs overhead for:The optimized version inlines this check directly, removing this overhead entirely.
2. Reduced Attribute Lookups
The original code calls
type(exc)twice within_is_torch_runtime_error():The optimized version caches
type(exc)inexc_typeand reuses it:This eliminates one redundant
type()call per invocation.3. Streamlined Control Flow
The original version had nested conditionals checking the same condition. The optimized version combines the type check with the action (returning
exc.__cause__) in a single conditional block, reducing branching overhead.Performance Impact
Based on
function_references, this function is called from_exceptions_are_equivalent()andcomparator(), which are in the exception comparison hot path during verification. Thecomparator()function handles exception objects during test execution, meaning this optimization directly benefits:The annotated tests show consistent speedups across all test cases (15-31% faster), with particularly strong gains in:
This optimization is most valuable when the function is called frequently in loops or exception-heavy code paths, which aligns with its usage in the verification/comparison pipeline.
✅ Correctness verification report:
🌀 Click to see Generated Regression Tests
🔎 Click to see Concolic Coverage Tests
codeflash_concolic_g5uubgk6/tmpbtersvva/test_concolic_coverage.py::test__get_wrapped_exception_from_torch_runtime_errorTo test or edit this optimization locally
git merge codeflash/optimize-pr1119-2026-01-20T05.21.22