-
Notifications
You must be signed in to change notification settings - Fork 21
torch compile exception wrapper handling for better comparator #1119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@claude review this |
|
Claude finished @aseembits93's task in 1m 20s —— View job Code Review Complete
SummaryThis PR adds special handling for Detailed Review✅ Strengths
🔍 Issues and Concerns1. Redundant Logic in Exception Comparison (codeflash/verification/comparator.py:143-166)The exception comparison logic at lines 143-166 largely duplicates the logic already performed in Current flow:
Recommendation: Consolidate this logic. The unwrapping in lines 143-166 could be simplified or the logic in 2. Potential Issue with Early Return (codeflash/verification/comparator.py:152-155, 162-165)When the wrapped exception type matches but can't be extracted, the code returns elif f"got {type(orig).__name__}(" in str(new) or f"got {type(orig).__name__}:" in str(new):
# The wrapped exception type matches, but we can't extract it
# Consider them equivalent since the same error occurred
return TrueThis returns Example scenario: orig = IndexError("index 0 out of bounds")
new = TorchRuntimeError("got IndexError('index 5 out of bounds')") # Different index!
# These would be considered equivalent even though they represent different errorsRecommendation: Either:
3. String Matching Fragility (codeflash/verification/comparator.py:62-63, 71-72, 152-153, 162-163)The error message parsing relies on specific string patterns: if f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg:This is fragile because:
Recommendation: Add a comment explaining that this is a fallback heuristic for when 4. Missing Edge Cases in TestsThe test coverage is good but could include:
5. Type Annotation Issue (codeflash/verification/comparator.py:27)def _get_wrapped_exception_from_torch_runtime_error(exc: BaseException) -> Optional[BaseException]: # noqa: FA100The 6. Test File Change Unexplained (code_to_optimize/tests/pytest/test_tensorflow_jit_code.py:12)The change from: tf = pytest.importorskip("tensorflow")to: import tensorflow as tfThis removes the conditional import skip, which means tests will fail if TensorFlow is not installed instead of being skipped. This seems unrelated to the torch exception handling and could be problematic. Recommendation: Either:
📊 Code Quality
🎯 Recommendations
VerdictThe core functionality is solid and solves a real problem. The main concerns are around edge cases and code duplication. I recommend addressing the TensorFlow import change and considering the early return behavior before merging. |
| if not _is_torch_runtime_error(exc): | ||
| return None | ||
|
|
||
| # First try to get the chained exception via __cause__ | ||
| if exc.__cause__ is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
⚡️Codeflash found 24% (0.24x) speedup for _get_wrapped_exception_from_torch_runtime_error in codeflash/verification/comparator.py
⏱️ Runtime : 603 microseconds → 486 microseconds (best of 241 runs)
📝 Explanation and details
The optimized code achieves a 24% speedup by eliminating redundant function calls and attribute lookups through inlining.
Key Optimizations
1. Function Call Elimination
The original code calls _is_torch_runtime_error(exc) as a separate function, which incurs overhead for:
- Function call/return mechanics
- Parameter passing
- Stack frame setup/teardown
The optimized version inlines this check directly, removing this overhead entirely.
2. Reduced Attribute Lookups
The original code calls type(exc) twice within _is_torch_runtime_error():
type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__The optimized version caches type(exc) in exc_type and reuses it:
exc_type = type(exc)
if exc_type.__name__ == "TorchRuntimeError" and "torch._dynamo" in exc_type.__module__:This eliminates one redundant type() call per invocation.
3. Streamlined Control Flow
The original version had nested conditionals checking the same condition. The optimized version combines the type check with the action (returning exc.__cause__) in a single conditional block, reducing branching overhead.
Performance Impact
Based on function_references, this function is called from _exceptions_are_equivalent() and comparator(), which are in the exception comparison hot path during verification. The comparator() function handles exception objects during test execution, meaning this optimization directly benefits:
- Exception handling during test verification - where many exceptions may be checked rapidly
- TorchRuntimeError unwrapping scenarios - common when using torch.compile/Dynamo
The annotated tests show consistent speedups across all test cases (15-31% faster), with particularly strong gains in:
- Large-scale batch operations (500 exceptions: 30.7% faster)
- Rapid successive calls (500 iterations: 22.8% faster)
- Mixed exception type batches (21.6% faster)
This optimization is most valuable when the function is called frequently in loops or exception-heavy code paths, which aligns with its usage in the verification/comparison pipeline.
✅ Correctness verification report:
| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 1982 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | ✅ 1 Passed |
| 📊 Tests Coverage | 100.0% |
🌀 Click to see Generated Regression Tests
# imports
from codeflash.verification.comparator import _get_wrapped_exception_from_torch_runtime_error
def _is_torch_runtime_error(exc: BaseException) -> bool:
"""Check if an exception is a TorchRuntimeError from torch.compile/Dynamo."""
# Check by class name to avoid importing torch
return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__
def test_non_torch_exception_returns_none_even_with_cause():
# Create a normal ValueError (not TorchRuntimeError) and set its __cause__.
base = ValueError("underlying")
exc = ValueError("wrapper")
# Manually attach a cause as would be done by "raise ... from ..."
exc.__cause__ = base
# Since this is not a TorchRuntimeError (name/module don't match), function should return None
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc) # 832ns -> 701ns (18.7% faster)
def test_torch_runtime_error_with_cause_returns_cause():
# Dynamically create a class named TorchRuntimeError with a module indicating torch._dynamo
TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch._dynamo"})
# Instantiate and attach a cause exception
underlying = RuntimeError("original problem")
tr_exc = TorchRuntimeError("wrapped by torch")
tr_exc.__cause__ = underlying
# The function must detect the TorchRuntimeError (by name and module substring) and return the cause
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(tr_exc)
result = codeflash_output # 1.23μs -> 982ns (25.5% faster)
def test_torch_runtime_error_without_cause_returns_none():
# TorchRuntimeError-like instance but without a __cause__ should yield None
TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch._dynamo"})
tr_exc = TorchRuntimeError("no cause attached")
# Explicitly ensure no cause
tr_exc.__cause__ = None
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(tr_exc) # 1.13μs -> 911ns (24.3% faster)
def test_torch_runtime_error_with_context_but_no_cause_returns_none():
# If only __context__ is set (implicit exception chaining) but __cause__ is None, return None
TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch._dynamo"})
tr_exc = TorchRuntimeError("context but no explicit cause")
# Set context (what happens when an exception occurs during exception handling)
tr_exc.__context__ = ValueError("contextual")
tr_exc.__cause__ = None
# The function explicitly checks only __cause__
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(tr_exc) # 1.10μs -> 862ns (27.8% faster)
def test_similar_name_but_wrong_module_returns_none():
# Same class name but module does not contain "torch._dynamo" -> should not be treated as torch runtime error
TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "not_torch.dynamo"})
underlying = KeyError("oops")
tr_exc = TorchRuntimeError("looks similar")
tr_exc.__cause__ = underlying
# Because the module string doesn't include the expected substring, result must be None
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(tr_exc) # 931ns -> 801ns (16.2% faster)
def test_module_contains_substring_matches_even_with_extra_path():
# Module containing the substring should match, even if it's a longer path
TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "some.prefix.torch._dynamo.suffix"})
underlying = IndexError("index problem")
tr_exc = TorchRuntimeError("wrapped")
tr_exc.__cause__ = underlying
# Should detect by substring and return the underlying exception
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(tr_exc) # 1.21μs -> 922ns (31.5% faster)
def test_nested_cause_returns_only_first_level_cause():
# If the wrapped cause itself has a __cause__, the function should only return the first-level __cause__
TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch._dynamo"})
inner = ValueError("inner-most")
middle = TypeError("middle")
middle.__cause__ = inner # middle chains to inner
outer = TorchRuntimeError("outer wrapper")
outer.__cause__ = middle # outer chains to middle
# The function should return 'middle', not 'inner'
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(outer)
result = codeflash_output # 1.33μs -> 1.06μs (25.5% faster)
def test_large_scale_many_wrapped_errors_iterable():
# Large scale test: create a sizeable collection (but < 1000) of TorchRuntimeError-like instances
# Each should unwrap to its respective cause; this checks for consistent behavior across many instances.
count = 500 # below the 1000-step loop constraint
TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch._dynamo"})
wrappers = []
causes = []
for i in range(count):
cause = RuntimeError(f"cause-{i}") # distinct cause objects
tr_exc = TorchRuntimeError(f"wrapper-{i}")
tr_exc.__cause__ = cause
wrappers.append(tr_exc)
causes.append(cause)
# Verify each wrapper returns the exact corresponding cause
for idx, w in enumerate(wrappers):
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(w) # 165μs -> 126μs (30.7% faster)
def test_is_torch_runtime_error_helper_identifies_only_matching_types():
# Validate helper behavior for a mix of objects to ensure the gatekeeping logic is correct.
TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch._dynamo"})
other_named = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch.something_else"})
other_module = type("OtherError", (RuntimeError,), {"__module__": "torch._dynamo"})
t1 = TorchRuntimeError("match")
t2 = other_named("wrong module")
t3 = other_module("wrong name")
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.from codeflash.verification.comparator import _get_wrapped_exception_from_torch_runtime_error
# Function to test
def _is_torch_runtime_error(exc: BaseException) -> bool:
"""Check if an exception is a TorchRuntimeError from torch.compile/Dynamo."""
# Check by class name to avoid importing torch
return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__
# Helper to create a mock TorchRuntimeError with the correct module and class name
class TorchRuntimeError(Exception):
"""Mock TorchRuntimeError that simulates torch._dynamo.TorchRuntimeError."""
def test_non_torch_exception_returns_none():
"""Test that non-TorchRuntimeError exceptions return None."""
# Standard ValueError should not be considered a TorchRuntimeError
exc = ValueError("Standard error")
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
result = codeflash_output # 922ns -> 701ns (31.5% faster)
def test_torch_error_with_cause_returns_cause():
"""Test that TorchRuntimeError with __cause__ returns the wrapped exception."""
# Create a wrapped exception
original_exc = RuntimeError("Original error")
# Create a TorchRuntimeError with the original as __cause__
torch_exc = TorchRuntimeError("Torch runtime error")
torch_exc.__cause__ = original_exc
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 1.08μs -> 862ns (25.5% faster)
def test_torch_error_without_cause_returns_none():
"""Test that TorchRuntimeError without __cause__ returns None."""
# Create a TorchRuntimeError without a cause
torch_exc = TorchRuntimeError("Torch runtime error")
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 911ns -> 752ns (21.1% faster)
def test_torch_error_with_none_cause_returns_none():
"""Test that TorchRuntimeError with explicitly None __cause__ returns None."""
# Create a TorchRuntimeError with explicit None cause
torch_exc = TorchRuntimeError("Torch runtime error")
torch_exc.__cause__ = None
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 902ns -> 781ns (15.5% faster)
def test_standard_exception_types_return_none():
"""Test that various standard exception types return None."""
# Test with different standard exception types
exceptions = [
TypeError("type error"),
AttributeError("attribute error"),
KeyError("key error"),
IndexError("index error"),
ZeroDivisionError("division error"),
]
for exc in exceptions:
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
result = codeflash_output # 2.08μs -> 1.63μs (27.6% faster)
def test_torch_error_wraps_value_error():
"""Test extraction of ValueError from TorchRuntimeError."""
# Create a ValueError and wrap it
original_exc = ValueError("Wrapped value error")
torch_exc = TorchRuntimeError("Torch error")
torch_exc.__cause__ = original_exc
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 1.00μs -> 812ns (23.3% faster)
def test_torch_error_with_nested_cause_chain():
"""Test extraction when exception has nested cause chains."""
# Create a chain: original -> wrapped1 -> wrapped2 (torch)
original = RuntimeError("Original")
wrapped1 = TypeError("Wrapped once")
wrapped1.__cause__ = original
torch_exc = TorchRuntimeError("Torch error")
torch_exc.__cause__ = wrapped1
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 892ns -> 711ns (25.5% faster)
def test_exception_with_empty_message():
"""Test handling of exceptions with empty messages."""
original_exc = RuntimeError("")
torch_exc = TorchRuntimeError("")
torch_exc.__cause__ = original_exc
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 882ns -> 812ns (8.62% faster)
def test_exception_with_special_characters_in_message():
"""Test exceptions with special characters in messages."""
original_exc = RuntimeError("Error\nwith\nnewlines\tand\ttabs")
torch_exc = TorchRuntimeError("Torch error")
torch_exc.__cause__ = original_exc
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 941ns -> 812ns (15.9% faster)
def test_exception_with_unicode_message():
"""Test exceptions with unicode characters."""
original_exc = RuntimeError("Error with unicode: \u00e9\u00e8\u00ea \u4e2d\u6587")
torch_exc = TorchRuntimeError("Torch error")
torch_exc.__cause__ = original_exc
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 912ns -> 801ns (13.9% faster)
def test_torch_error_with_context_vs_cause():
"""Test that only __cause__ is extracted, not __context__."""
# Create exceptions with context but no cause
original_exc = RuntimeError("Original")
torch_exc = TorchRuntimeError("Torch error")
# Simulate an implicit exception context
torch_exc.__context__ = original_exc
torch_exc.__cause__ = None
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 922ns -> 862ns (6.96% faster)
def test_wrong_module_name_returns_none():
"""Test that exception with wrong module name returns None."""
# Create an exception with correct class name but wrong module
exc = TorchRuntimeError("Error")
exc.__module__ = "wrong_module"
# This should return None because module is not torch._dynamo
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
result = codeflash_output # 911ns -> 781ns (16.6% faster)
def test_wrong_class_name_returns_none():
"""Test that exception with wrong class name returns None."""
# Create a custom exception with torch._dynamo module but wrong name
class WrongError(Exception):
pass
WrongError.__module__ = "torch._dynamo"
exc = WrongError("Error")
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
result = codeflash_output # 642ns -> 552ns (16.3% faster)
def test_exception_object_identity_preserved():
"""Test that the returned exception is the same object, not a copy."""
original_exc = RuntimeError("Original")
torch_exc = TorchRuntimeError("Torch error")
torch_exc.__cause__ = original_exc
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 892ns -> 851ns (4.82% faster)
def test_torch_error_with_custom_exception_subclass():
"""Test extraction of custom exception subclass from TorchRuntimeError."""
class CustomError(Exception):
"""Custom exception for testing."""
original_exc = CustomError("Custom error message")
torch_exc = TorchRuntimeError("Torch error")
torch_exc.__cause__ = original_exc
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 952ns -> 822ns (15.8% faster)
def test_exception_with_traceback():
"""Test handling of exceptions with traceback information."""
try:
1 / 0
except ZeroDivisionError as e:
original_exc = e
torch_exc = TorchRuntimeError("Torch error")
torch_exc.__cause__ = original_exc
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 911ns -> 802ns (13.6% faster)
def test_none_input():
"""Test that None input raises appropriate error or returns None."""
# The function signature expects BaseException, but test robustness
try:
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(None)
result = codeflash_output
except (AttributeError, TypeError):
# AttributeError is expected when calling methods on None
pass
def test_torch_error_module_substring_not_matching():
"""Test that partial module name matches are not accepted."""
# Create exception where torch._dynamo is a substring but not exact
exc = TorchRuntimeError("Error")
exc.__module__ = "notorch._dynamo"
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
result = codeflash_output # 932ns -> 821ns (13.5% faster)
def test_multiple_sequential_torch_errors():
"""Test handling of multiple TorchRuntimeError instances sequentially."""
# Create 100 different wrapped exceptions
results = []
for i in range(100):
original_exc = RuntimeError(f"Original error {i}")
torch_exc = TorchRuntimeError(f"Torch error {i}")
torch_exc.__cause__ = original_exc
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 29.7μs -> 24.2μs (22.6% faster)
results.append(result)
for i, result in enumerate(results):
pass
def test_large_batch_mixed_exception_types():
"""Test function with large batch of mixed exception types."""
# Create 500 exceptions: mix of TorchRuntimeError and other types
exceptions = []
for i in range(500):
if i % 2 == 0:
# Create TorchRuntimeError
original = RuntimeError(f"Error {i}")
torch_exc = TorchRuntimeError(f"Torch {i}")
torch_exc.__cause__ = original
exceptions.append(torch_exc)
else:
# Create standard exception
exceptions.append(ValueError(f"Standard {i}"))
results = []
torch_count = 0
for exc in exceptions:
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
result = codeflash_output # 141μs -> 116μs (21.6% faster)
results.append(result)
if result is not None:
torch_count += 1
def test_deep_cause_chain_extraction():
"""Test performance with deep exception chains (only direct cause extracted)."""
# Create a deep chain: a0 -> a1 -> a2 -> ... -> torch_exc
torch_exc = TorchRuntimeError("Torch error")
original = RuntimeError("Original")
torch_exc.__cause__ = original
# The function should only extract the direct cause
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 971ns -> 911ns (6.59% faster)
def test_large_exception_messages():
"""Test handling of exceptions with very large messages."""
# Create an exception with a large message (but < 1MB)
large_message = "x" * 100000 # 100KB message
original_exc = RuntimeError(large_message)
torch_exc = TorchRuntimeError("Torch error")
torch_exc.__cause__ = original_exc
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 992ns -> 892ns (11.2% faster)
def test_performance_many_checks_non_torch_exceptions():
"""Test performance checking many non-TorchRuntimeError exceptions."""
# Generate 200 standard exceptions
exceptions = [ValueError(f"Error {i}") for i in range(200)]
results = []
for exc in exceptions:
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
result = codeflash_output # 50.7μs -> 42.4μs (19.5% faster)
results.append(result)
def test_torch_errors_with_same_wrapped_exception():
"""Test multiple TorchRuntimeErrors wrapping the same exception."""
# Create one original exception
shared_original = RuntimeError("Shared original")
# Wrap it in multiple TorchRuntimeErrors
torch_errors = []
for i in range(50):
torch_exc = TorchRuntimeError(f"Torch error {i}")
torch_exc.__cause__ = shared_original
torch_errors.append(torch_exc)
results = []
for torch_exc in torch_errors:
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 15.4μs -> 12.5μs (23.1% faster)
results.append(result)
def test_rapid_successive_calls():
"""Test rapid successive calls to verify no state corruption."""
original_exc = RuntimeError("Original")
torch_exc = TorchRuntimeError("Torch error")
torch_exc.__cause__ = original_exc
# Call function 500 times rapidly
results = []
for _ in range(500):
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 142μs -> 116μs (22.8% faster)
results.append(result)
def test_varied_wrapped_exception_types_at_scale():
"""Test at scale with many different wrapped exception types."""
exception_types = [
RuntimeError,
ValueError,
TypeError,
AttributeError,
KeyError,
IndexError,
ZeroDivisionError,
NotImplementedError,
AssertionError,
OSError,
]
results = []
for i in range(100):
exc_type = exception_types[i % len(exception_types)]
original_exc = exc_type(f"Message {i}")
torch_exc = TorchRuntimeError(f"Torch error {i}")
torch_exc.__cause__ = original_exc
codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
result = codeflash_output # 30.0μs -> 24.3μs (23.6% faster)
results.append(result)
# Verify types match
for i, result in enumerate(results):
expected_type = exception_types[i % len(exception_types)]
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.from codeflash.verification.comparator import _get_wrapped_exception_from_torch_runtime_error
def test__get_wrapped_exception_from_torch_runtime_error():
_get_wrapped_exception_from_torch_runtime_error(BaseException())🔎 Click to see Concolic Coverage Tests
| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|---|---|---|---|
codeflash_concolic_g5uubgk6/tmpbtersvva/test_concolic_coverage.py::test__get_wrapped_exception_from_torch_runtime_error |
842ns | 651ns | 29.3%✅ |
To test or edit this optimization locally git merge codeflash/optimize-pr1119-2026-01-20T05.21.22
| if not _is_torch_runtime_error(exc): | |
| return None | |
| # First try to get the chained exception via __cause__ | |
| if exc.__cause__ is not None: | |
| # Inline the check to avoid redundant function call and attribute lookups | |
| exc_type = type(exc) | |
| if exc_type.__name__ == "TorchRuntimeError" and "torch._dynamo" in exc_type.__module__: |
| # If types match, they're potentially equivalent (will be compared further) | ||
| if type(orig).__name__ == type(new).__name__: | ||
| return True | ||
|
|
||
| # Check if one is a TorchRuntimeError wrapping the other's type | ||
| if _is_torch_runtime_error(new): | ||
| wrapped = _get_wrapped_exception_from_torch_runtime_error(new) | ||
| if wrapped is not None and type(wrapped).__name__ == type(orig).__name__: | ||
| return True | ||
| # Also check the error message for the wrapped exception type | ||
| # TorchRuntimeError message contains "got ExceptionType('...')" | ||
| error_msg = str(new) | ||
| orig_type_name = type(orig).__name__ | ||
| if f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg: | ||
| return True | ||
|
|
||
| if _is_torch_runtime_error(orig): | ||
| wrapped = _get_wrapped_exception_from_torch_runtime_error(orig) | ||
| if wrapped is not None and type(wrapped).__name__ == type(new).__name__: | ||
| return True | ||
| error_msg = str(orig) | ||
| new_type_name = type(new).__name__ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
⚡️Codeflash found 16% (0.16x) speedup for _exceptions_are_equivalent in codeflash/verification/comparator.py
⏱️ Runtime : 1.22 milliseconds → 1.05 milliseconds (best of 146 runs)
📝 Explanation and details
The optimized code achieves a 15% speedup by reducing redundant attribute lookups through strategic caching of type() results and their attributes.
Key Optimizations
1. Cached Type Lookups in _is_torch_runtime_error
The original code called type(exc) twice per invocation:
return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__The optimized version caches the result:
exc_type = type(exc)
return exc_type.__name__ == "TorchRuntimeError" and "torch._dynamo" in exc_type.__module__This eliminates one type() call per check, which is significant since this function is called frequently (3,164 hits in the original, 611 in the optimization test).
2. Upfront Type Attribute Caching in _exceptions_are_equivalent
The most impactful optimization occurs in _exceptions_are_equivalent, where type information is cached at the function start:
orig_type = type(orig)
new_type = type(new)
orig_type_name = orig_type.__name__
new_type_name = new_type.__name__This eliminates repeated type() calls throughout the function. The line profiler shows this reduces time spent on type checking from ~34% to ~4.7% of total execution time.
3. Inlined TorchRuntimeError Checks
Instead of calling _is_torch_runtime_error() (which itself performs type lookups), the optimized code directly checks the cached type attributes:
if new_type_name == "TorchRuntimeError" and "torch._dynamo" in new_type.__module__:This avoids the overhead of function calls and redundant type operations.
Performance Impact by Test Case
The optimization particularly excels in scenarios involving:
- Different exception types (46.8-66.8% faster): Cached type names enable quick inequality checks without repeated lookups
- TorchRuntimeError unwrapping (14.3-26% faster): Direct type checks bypass function call overhead
- Large-scale comparisons (17.4-22.3% faster on 500-1000 iterations): Reduced per-comparison overhead compounds significantly
The optimization shows minimal regression (~2-5% slower) only in edge cases with very few comparisons where the upfront caching overhead slightly exceeds savings.
Context and Importance
Based on function_references, _exceptions_are_equivalent is called from the hot path comparator() function, which is a recursive comparison utility handling various data types. Since comparator() may be called frequently during verification workflows (especially with large data structures or exception chains), even a 15% improvement translates to meaningful time savings in production validation scenarios.
The optimization maintains correctness by preserving all logical branches while simply reusing computed type information instead of repeatedly querying the Python object model.
✅ Correctness verification report:
| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 2938 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | ✅ 1 Passed |
| 📊 Tests Coverage | 78.9% |
🌀 Click to see Generated Regression Tests
from typing import Optional
# imports
from codeflash.verification.comparator import _exceptions_are_equivalent
def _is_torch_runtime_error(exc: BaseException) -> bool:
"""Check if an exception is a TorchRuntimeError from torch.compile/Dynamo."""
# Check by class name to avoid importing torch
return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__
def _get_wrapped_exception_from_torch_runtime_error(exc: BaseException) -> Optional[BaseException]: # noqa: FA100
"""Extract the underlying exception wrapped by TorchRuntimeError.
TorchRuntimeError from torch.compile/Dynamo wraps the original exception
using Python's exception chaining (__cause__).
"""
if not _is_torch_runtime_error(exc):
return None
# First try to get the chained exception via __cause__
if exc.__cause__ is not None:
return exc.__cause__
return None
def test_same_builtin_type_is_equivalent():
# Basic case: identical built-in exception types should be equivalent
orig = ValueError("original")
new = ValueError("new")
# types have the same __name__ so should be True
codeflash_output = _exceptions_are_equivalent(orig, new) # 882ns -> 861ns (2.44% faster)
def test_same_type_name_different_classes_equivalent():
# If types have the same name (even if different classes), equivalence is based on name
# Create a custom exception class named 'ValueError' (different object than built-in)
CustomValueError = type("ValueError", (Exception,), {}) # class name matches built-in
orig = ValueError("built-in") # built-in ValueError
new = CustomValueError("custom") # custom class with the same name
# Should be True because only the __name__ is compared in the function's first check
codeflash_output = _exceptions_are_equivalent(orig, new) # 791ns -> 802ns (1.37% slower)
def test_torch_runtime_error_with_cause_matches_orig():
# Simulate a TorchRuntimeError wrapper by creating a class with the same name
# and a module path containing 'torch._dynamo'
TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorFake.__module__ = "torch._dynamo.fake_module" # include required substring
# Underlying exception that was originally raised
underlying = TypeError("underlying problem")
# Create the fake TorchRuntimeError and attach the underlying exception via __cause__
wrapped = TorchRuntimeErrorFake("wrapper message")
wrapped.__cause__ = underlying
# orig is the underlying exception, new is the wrapper: should be considered equivalent
codeflash_output = _exceptions_are_equivalent(underlying, wrapped) # 2.23μs -> 1.95μs (14.3% faster)
def test_torch_runtime_error_wraps_new_side_equivalence():
# Reverse of previous: orig is a wrapper, new is the underlying exception
TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorFake.__module__ = "torch._dynamo.other"
underlying = KeyError("missing")
wrapper = TorchRuntimeErrorFake("wrapped key error")
wrapper.__cause__ = underlying
# orig is wrapper, new is underlying -> should be True
codeflash_output = _exceptions_are_equivalent(wrapper, underlying) # 2.37μs -> 1.88μs (26.0% faster)
def test_torch_runtime_error_message_parenthesis_format_matches():
# TorchRuntimeError without a __cause__ but with a message containing "got TypeError('...')"
TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorFake.__module__ = "torch._dynamo.something"
# Craft message that matches the "got <TypeName>(" pattern
msg = "some runtime wrapper; got TypeError('oops') while compiling"
wrapper = TorchRuntimeErrorFake(msg)
orig = TypeError("oops")
# Since the message contains "got TypeError(" the function should report equivalence
codeflash_output = _exceptions_are_equivalent(orig, wrapper) # 2.50μs -> 2.38μs (5.08% faster)
def test_torch_runtime_error_message_colon_format_matches():
# TorchRuntimeError message with the "got TypeError:" pattern should match as well
TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorFake.__module__ = "torch._dynamo"
msg = "error occurred; got TypeError: details follow"
wrapper = TorchRuntimeErrorFake(msg)
orig = TypeError("details follow")
# Should be True because of the "got TypeError:" pattern
codeflash_output = _exceptions_are_equivalent(orig, wrapper) # 2.48μs -> 2.31μs (7.81% faster)
def test_non_equivalent_different_types():
# Completely different exception types without any torch wrapper should not be equivalent
orig = IndexError("index fail")
new = ValueError("value fail")
codeflash_output = _exceptions_are_equivalent(orig, new) # 1.38μs -> 942ns (46.8% faster)
def test_torch_runtime_error_with_incorrect_module_not_recognized():
# Create a class named TorchRuntimeError but with module that does NOT contain 'torch._dynamo'
TorchRuntimeErrorNotDynamo = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorNotDynamo.__module__ = "some.other.module"
underlying = RuntimeError("bad")
wrapper = TorchRuntimeErrorNotDynamo("wrapper")
wrapper.__cause__ = underlying
# Because the module doesn't include 'torch._dynamo', the wrapper should NOT be recognized,
# and since names differ (TorchRuntimeError vs RuntimeError), result should be False
codeflash_output = _exceptions_are_equivalent(underlying, wrapper) # 1.55μs -> 1.14μs (36.1% faster)
def test_torch_runtime_error_message_partial_no_false_positive():
# Ensure that similar-but-not-matching messages do not produce false positives.
TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorFake.__module__ = "torch._dynamo.sub"
# Message contains 'got Value(' which is not the exact searched pattern 'got ValueError('
msg = "wrapper says got Value('x') somewhere"
wrapper = TorchRuntimeErrorFake(msg)
orig = ValueError("x")
# Should be False because "got Value(" does not match "got ValueError(" or "got ValueError:"
codeflash_output = _exceptions_are_equivalent(orig, wrapper) # 2.85μs -> 2.60μs (10.0% faster)
def test_large_scale_many_name_based_equivalences():
# Large-scale test: ensure performance & correctness over many name-equality cases.
# We'll create 500 pairs (well under the 1000 element limit) where the type names match.
pairs = []
for i in range(500): # 500 iterations is under the "do not exceed 1000 steps" guidance
# Create two distinct classes that share the same __name__
name = f"BulkExc{i}"
A = type(name, (Exception,), {})
B = type(name, (Exception,), {})
# instantiate exceptions
a_inst = A(f"a {i}")
b_inst = B(f"b {i}")
pairs.append((a_inst, b_inst))
# All pairs should be equivalent because only the __name__ is checked in the first step
for a_inst, b_inst in pairs:
codeflash_output = _exceptions_are_equivalent(a_inst, b_inst) # 109μs -> 112μs (2.28% slower)
def test_torch_runtime_error_message_matches_custom_exception_name():
# Message-based match should work for non-builtins too: orig is custom-named exception
CustomErr = type("CustomErr", (Exception,), {})
TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
TorchRuntimeErrorFake.__module__ = "torch._dynamo.extra"
# Message contains the custom exception name in the expected pattern
msg = "compilation failed; got CustomErr('x') at runtime"
wrapper = TorchRuntimeErrorFake(msg)
orig = CustomErr("x")
# Should be True because message contains "got CustomErr("
codeflash_output = _exceptions_are_equivalent(orig, wrapper) # 2.52μs -> 2.46μs (2.48% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.from codeflash.verification.comparator import _exceptions_are_equivalent
def _is_torch_runtime_error(exc: BaseException) -> bool:
"""Check if an exception is a TorchRuntimeError from torch.compile/Dynamo."""
# Check by class name to avoid importing torch
return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__
def _get_wrapped_exception_from_torch_runtime_error(exc: BaseException):
"""Extract the underlying exception wrapped by TorchRuntimeError.
TorchRuntimeError from torch.compile/Dynamo wraps the original exception
using Python's exception chaining (__cause__).
"""
if not _is_torch_runtime_error(exc):
return None
# First try to get the chained exception via __cause__
if exc.__cause__ is not None:
return exc.__cause__
return None
class TestBasicFunctionality:
"""Test cases for basic equivalence checking of standard exceptions."""
def test_identical_exception_types(self):
"""Test that two exceptions of the same type are considered equivalent."""
exc1 = ValueError("test error")
exc2 = ValueError("different message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 922ns -> 852ns (8.22% faster)
def test_identical_standard_exceptions(self):
"""Test equivalence with various standard Python exceptions."""
pairs = [
(TypeError("msg1"), TypeError("msg2")),
(RuntimeError("msg1"), RuntimeError("msg2")),
(KeyError("msg1"), KeyError("msg2")),
(IndexError("msg1"), IndexError("msg2")),
(AttributeError("msg1"), AttributeError("msg2")),
(ZeroDivisionError("msg1"), ZeroDivisionError("msg2")),
]
for exc1, exc2 in pairs:
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 2.79μs -> 2.80μs (0.036% slower)
def test_different_exception_types(self):
"""Test that different exception types are not equivalent."""
exc1 = ValueError("error")
exc2 = TypeError("error")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 1.35μs -> 811ns (66.8% faster)
def test_different_exception_types_all_combinations(self):
"""Test equivalence returns False for various mismatched exception types."""
pairs = [
(ValueError("msg"), TypeError("msg")),
(RuntimeError("msg"), KeyError("msg")),
(IndexError("msg"), AttributeError("msg")),
(ZeroDivisionError("msg"), Exception("msg")),
]
for exc1, exc2 in pairs:
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 3.33μs -> 2.25μs (48.1% faster)
def test_custom_exception_same_type(self):
"""Test equivalence for custom exception classes."""
class CustomError(Exception):
pass
exc1 = CustomError("message1")
exc2 = CustomError("message2")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 601ns -> 611ns (1.64% slower)
def test_custom_exception_different_types(self):
"""Test non-equivalence for different custom exception types."""
class CustomError1(Exception):
pass
class CustomError2(Exception):
pass
exc1 = CustomError1("message")
exc2 = CustomError2("message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 1.00μs -> 691ns (45.0% faster)
def test_base_exception_type(self):
"""Test with base Exception type."""
exc1 = Exception("message1")
exc2 = Exception("message2")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 902ns -> 801ns (12.6% faster)
def test_exception_with_no_message(self):
"""Test equivalence for exceptions without messages."""
exc1 = ValueError()
exc2 = ValueError()
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 851ns -> 752ns (13.2% faster)
def test_exception_with_empty_vs_nonempty_message(self):
"""Test equivalence is independent of message content."""
exc1 = ValueError("")
exc2 = ValueError("detailed error message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 851ns -> 731ns (16.4% faster)
class TestEdgeCases:
"""Test cases for edge conditions and unusual scenarios."""
def test_exception_with_special_characters_in_message(self):
"""Test exceptions with special characters don't affect type comparison."""
exc1 = ValueError("Error with @#$%^&*()")
exc2 = ValueError("Error with \n\t\r")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 901ns -> 752ns (19.8% faster)
def test_exception_with_unicode_characters(self):
"""Test exceptions with unicode characters."""
exc1 = ValueError("Error: \u2764\u2665\u263a")
exc2 = ValueError("Error: \u00e9\u00e8\u00ea")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 832ns -> 791ns (5.18% faster)
def test_exception_with_very_long_message(self):
"""Test exceptions with very long messages."""
long_msg1 = "x" * 10000
long_msg2 = "y" * 10000
exc1 = ValueError(long_msg1)
exc2 = ValueError(long_msg2)
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 872ns -> 732ns (19.1% faster)
def test_exception_with_numeric_message(self):
"""Test exceptions with numeric content in messages."""
exc1 = ValueError(12345)
exc2 = ValueError(67890)
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 801ns -> 701ns (14.3% faster)
def test_exception_inheritance_not_equivalent(self):
"""Test that exception inheritance hierarchy doesn't make them equivalent."""
class BaseCustom(Exception):
pass
class DerivedCustom(BaseCustom):
pass
exc1 = BaseCustom("message")
exc2 = DerivedCustom("message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 1.07μs -> 731ns (46.5% faster)
def test_exception_with_args_tuple(self):
"""Test exceptions constructed with multiple arguments."""
exc1 = ValueError("arg1", "arg2", "arg3")
exc2 = ValueError("different1", "different2")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 832ns -> 752ns (10.6% faster)
def test_exception_with_none_as_message(self):
"""Test exception with None as message."""
exc1 = ValueError(None)
exc2 = ValueError("actual message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 882ns -> 732ns (20.5% faster)
def test_exception_same_instance_compared(self):
"""Test comparing the same exception instance with itself."""
exc = ValueError("message")
codeflash_output = _exceptions_are_equivalent(exc, exc) # 851ns -> 712ns (19.5% faster)
def test_exception_from_different_modules_same_name(self):
"""Test handling of different classes with same name but different modules."""
# Create mock exception classes with same name but different identities
class ValueError1(BaseException):
pass
class ValueError2(BaseException):
pass
# Rename them to have the same __name__ but different __module__
ValueError1.__name__ = "SameName"
ValueError2.__name__ = "SameName"
ValueError1.__module__ = "module1"
ValueError2.__module__ = "module2"
exc1 = ValueError1("msg")
exc2 = ValueError2("msg")
# They should be equivalent because type names match
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 601ns -> 631ns (4.75% slower)
def test_exception_with_multiline_message(self):
"""Test exceptions with multiline error messages."""
msg1 = "Line1\nLine2\nLine3\nLine4"
msg2 = "Different\nMultiline\nMessage\nHere"
exc1 = RuntimeError(msg1)
exc2 = RuntimeError(msg2)
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 892ns -> 802ns (11.2% faster)
def test_exception_message_containing_type_name(self):
"""Test message that contains other exception type names."""
exc1 = ValueError("This is a ValueError")
exc2 = ValueError("This is a TypeError or KeyError")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 872ns -> 831ns (4.93% faster)
class TestTorchRuntimeErrorHandling:
"""Test cases for TorchRuntimeError wrapping behavior."""
def test_torch_runtime_error_detection(self):
"""Test that TorchRuntimeError can be detected by type name."""
# Create a mock TorchRuntimeError-like exception
class TorchRuntimeError(BaseException):
pass
# Set the module to match torch._dynamo
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
exc = TorchRuntimeError("wrapped error")
def test_torch_runtime_error_false_for_different_module(self):
"""Test that TorchRuntimeError from different module is not detected."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "some_other_module"
TorchRuntimeError.__name__ = "TorchRuntimeError"
exc = TorchRuntimeError("wrapped error")
def test_get_wrapped_exception_with_cause(self):
"""Test extracting wrapped exception from __cause__."""
# Create a TorchRuntimeError-like class
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Create the wrapped exception
original_exc = ValueError("original error")
torch_exc = TorchRuntimeError("torch wrapped error")
torch_exc.__cause__ = original_exc
wrapped = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
def test_get_wrapped_exception_returns_none_for_non_torch(self):
"""Test that non-TorchRuntimeError returns None."""
exc = ValueError("regular error")
wrapped = _get_wrapped_exception_from_torch_runtime_error(exc)
def test_get_wrapped_exception_no_cause_returns_none(self):
"""Test TorchRuntimeError without __cause__ returns None."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
torch_exc = TorchRuntimeError("no cause")
wrapped = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
def test_torch_runtime_error_wrapping_via_cause_equivalent(self):
"""Test equivalence when TorchRuntimeError wraps exception via __cause__."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Original exception
exc1 = ValueError("original")
# TorchRuntimeError wrapping the original
exc2 = TorchRuntimeError("wrapped")
exc2.__cause__ = ValueError("wrapped original")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 2.23μs -> 2.17μs (2.76% faster)
def test_torch_runtime_error_wrapping_different_type_not_equivalent(self):
"""Test non-equivalence when TorchRuntimeError wraps different type."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Original exception
exc1 = ValueError("original")
# TorchRuntimeError wrapping a different type
exc2 = TorchRuntimeError("wrapped")
exc2.__cause__ = TypeError("wrapped different")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 2.92μs -> 2.48μs (18.2% faster)
def test_torch_runtime_error_message_pattern_matching(self):
"""Test TorchRuntimeError message pattern matching for type name."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Original exception
exc1 = ValueError("original")
# TorchRuntimeError with message containing the wrapped type
exc2 = TorchRuntimeError("got ValueError('wrapped error')")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 2.38μs -> 2.13μs (11.7% faster)
def test_torch_runtime_error_message_pattern_colon_format(self):
"""Test TorchRuntimeError message pattern with colon format."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Original exception
exc1 = RuntimeError("original")
# TorchRuntimeError with message containing the wrapped type (colon format)
exc2 = TorchRuntimeError("got RuntimeError: wrapped error")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 2.44μs -> 2.29μs (6.96% faster)
def test_torch_runtime_error_reverse_wrapping(self):
"""Test equivalence when original is TorchRuntimeError wrapping new type."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Original as TorchRuntimeError
exc1 = TorchRuntimeError("wrapped")
exc1.__cause__ = ValueError("wrapped original")
# New exception
exc2 = ValueError("new")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 2.33μs -> 1.91μs (21.5% faster)
def test_torch_runtime_error_both_sides(self):
"""Test when both exceptions are TorchRuntimeError with different wrapped types."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Both as TorchRuntimeError wrapping different types
exc1 = TorchRuntimeError("wrapped1")
exc1.__cause__ = ValueError("val error")
exc2 = TorchRuntimeError("wrapped2")
exc2.__cause__ = TypeError("type error")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 530ns -> 651ns (18.6% slower)
class TestComplexScenarios:
"""Test cases combining multiple features and complex scenarios."""
def test_exception_chain_with_context(self):
"""Test exceptions that are part of an exception chain."""
try:
try:
raise ValueError("original")
except ValueError as e:
raise TypeError("new") from e
except TypeError as e:
exc1 = e
exc2 = TypeError("different message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 942ns -> 921ns (2.28% faster)
def test_multiple_exception_types_in_sequence(self):
"""Test multiple different exception types don't become equivalent."""
exceptions = [ValueError("msg"), TypeError("msg"), RuntimeError("msg"), KeyError("msg"), IndexError("msg")]
for i in range(len(exceptions)):
for j in range(len(exceptions)):
if i == j:
codeflash_output = _exceptions_are_equivalent(exceptions[i], exceptions[j])
else:
codeflash_output = _exceptions_are_equivalent(exceptions[i], exceptions[j])
def test_exception_with_complex_args(self):
"""Test exceptions with complex argument structures."""
exc1 = ValueError({"key": "value"}, [1, 2, 3], (4, 5, 6))
exc2 = ValueError({"different": "dict"}, ["a", "b"], ("x", "y"))
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 871ns -> 872ns (0.115% slower)
def test_exception_subclass_comparison(self):
"""Test comparison of exception subclasses."""
class CustomValueError(ValueError):
pass
exc1 = ValueError("base")
exc2 = CustomValueError("subclass")
# Different type names, so not equivalent
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 1.24μs -> 821ns (51.3% faster)
def test_exception_type_name_only_comparison(self):
"""Test that comparison is based on type name, not identity."""
# Create two different exception classes with the same name
exec_code = """
class DynamicException(Exception):
pass
exc1 = DynamicException("msg1")
"""
exec_globals1 = {}
exec(exec_code, exec_globals1)
exc1 = exec_globals1["exc1"]
exec_globals2 = {}
exec(exec_code, exec_globals2)
exc2 = exec_globals2["exc1"]
# Even though they're different classes, they have the same name
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 792ns -> 811ns (2.34% slower)
class TestLargeScale:
"""Test cases for large scale data and performance scenarios."""
def test_large_number_of_exception_comparisons(self):
"""Test comparing a large number of exception pairs."""
# Create 500 ValueError instances
exceptions = [ValueError(f"error_{i}") for i in range(500)]
# Compare each with first exception
for exc in exceptions:
codeflash_output = _exceptions_are_equivalent(exceptions[0], exc) # 132μs -> 140μs (5.94% slower)
def test_large_number_of_different_exception_types(self):
"""Test with many different exception types."""
exception_types = [
ValueError,
TypeError,
RuntimeError,
KeyError,
IndexError,
AttributeError,
ZeroDivisionError,
IOError,
OSError,
ImportError,
NotImplementedError,
StopIteration,
ArithmeticError,
BufferError,
LookupError,
MemoryError,
]
# Test each type is only equivalent to itself
for i, exc_type1 in enumerate(exception_types):
exc1 = exc_type1("message")
for j, exc_type2 in enumerate(exception_types):
exc2 = exc_type2("message")
if i == j:
codeflash_output = _exceptions_are_equivalent(exc1, exc2)
else:
codeflash_output = _exceptions_are_equivalent(exc1, exc2)
def test_very_large_exception_messages(self):
"""Test with exceptions having extremely large messages."""
# Create a 1MB string
large_msg1 = "x" * (1024 * 1024)
large_msg2 = "y" * (1024 * 1024)
exc1 = ValueError(large_msg1)
exc2 = ValueError(large_msg2)
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 1.06μs -> 922ns (15.2% faster)
def test_exception_comparison_performance(self):
"""Test performance of exception comparison with various inputs."""
import time
# Create many exception pairs
pairs = []
for i in range(1000):
if i % 2 == 0:
pairs.append((ValueError(f"msg_{i}"), ValueError(f"msg_{i + 1}")))
else:
pairs.append((ValueError(f"msg_{i}"), TypeError(f"msg_{i + 1}")))
# Perform comparisons and measure time
start_time = time.time()
for exc1, exc2 in pairs:
_exceptions_are_equivalent(exc1, exc2) # 353μs -> 289μs (22.3% faster)
end_time = time.time()
def test_many_torch_runtime_errors(self):
"""Test handling of many TorchRuntimeError instances."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Create 500 TorchRuntimeError instances wrapping different base exceptions
torch_errors = []
base_errors = []
for i in range(500):
base_exc = ValueError(f"base_{i}")
torch_exc = TorchRuntimeError(f"wrapped_{i}")
torch_exc.__cause__ = base_exc
torch_errors.append(torch_exc)
base_errors.append(base_exc)
# All should be equivalent to their corresponding base exceptions
for i in range(500):
codeflash_output = _exceptions_are_equivalent(
base_errors[i], torch_errors[i]
) # 351μs -> 299μs (17.4% faster)
def test_exception_types_with_long_names(self):
"""Test exception type comparison with very long type names."""
# Create exception classes with long names
long_name1 = "VeryLongExceptionNameWith" + "x" * 500
long_name2 = "VeryLongExceptionNameWith" + "y" * 500
ExceptionClass1 = type(long_name1, (Exception,), {})
ExceptionClass2 = type(long_name2, (Exception,), {})
exc1 = ExceptionClass1("message")
exc2 = ExceptionClass2("message")
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 1.09μs -> 791ns (38.1% faster)
def test_exception_comparison_with_circular_references(self):
"""Test exception comparison with objects containing circular references."""
# Create a circular reference in exception args
circular_list = [1, 2, 3]
circular_list.append(circular_list)
exc1 = ValueError(circular_list)
exc2 = ValueError([4, 5, 6])
# Should still compare by type name successfully
codeflash_output = _exceptions_are_equivalent(exc1, exc2) # 861ns -> 892ns (3.48% slower)
def test_batched_torch_error_comparison(self):
"""Test batch processing of many torch runtime error comparisons."""
class TorchRuntimeError(BaseException):
pass
TorchRuntimeError.__module__ = "torch._dynamo.exc"
TorchRuntimeError.__name__ = "TorchRuntimeError"
# Create 100 base exceptions and 100 torch wrapped versions
base_exceptions = [ValueError(f"base_{i}") for i in range(100)]
torch_exceptions = []
for i in range(100):
torch_exc = TorchRuntimeError(f"wrapped_{i}")
torch_exc.__cause__ = ValueError(f"wrapped_base_{i}")
torch_exceptions.append(torch_exc)
# Compare each base with each torch exception of same index
results = []
for i in range(100):
codeflash_output = _exceptions_are_equivalent(base_exceptions[i], torch_exceptions[i])
result = codeflash_output # 72.4μs -> 62.3μs (16.2% faster)
results.append(result)
def test_stress_test_many_exception_instantiations(self):
"""Stress test creating and comparing many exception instances."""
# Create 1000 exception instances
exceptions = []
for i in range(1000):
exc_type = ValueError if i % 2 == 0 else TypeError
exceptions.append(exc_type(f"message_{i}"))
# Compare each exception with a few others
equivalences = 0
non_equivalences = 0
for i in range(0, 1000, 10):
for j in range(i, min(i + 5, 1000)):
if _exceptions_are_equivalent(exceptions[i], exceptions[j]):
equivalences += 1
else:
non_equivalences += 1
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.from codeflash.verification.comparator import _exceptions_are_equivalent
def test__exceptions_are_equivalent():
_exceptions_are_equivalent(BaseException(), BaseException())🔎 Click to see Concolic Coverage Tests
| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|---|---|---|---|
codeflash_concolic_g5uubgk6/tmpkbpq1ztw/test_concolic_coverage.py::test__exceptions_are_equivalent |
952ns | 832ns | 14.4%✅ |
To test or edit this optimization locally git merge codeflash/optimize-pr1119-2026-01-20T05.30.28
Click to see suggested changes
| # If types match, they're potentially equivalent (will be compared further) | |
| if type(orig).__name__ == type(new).__name__: | |
| return True | |
| # Check if one is a TorchRuntimeError wrapping the other's type | |
| if _is_torch_runtime_error(new): | |
| wrapped = _get_wrapped_exception_from_torch_runtime_error(new) | |
| if wrapped is not None and type(wrapped).__name__ == type(orig).__name__: | |
| return True | |
| # Also check the error message for the wrapped exception type | |
| # TorchRuntimeError message contains "got ExceptionType('...')" | |
| error_msg = str(new) | |
| orig_type_name = type(orig).__name__ | |
| if f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg: | |
| return True | |
| if _is_torch_runtime_error(orig): | |
| wrapped = _get_wrapped_exception_from_torch_runtime_error(orig) | |
| if wrapped is not None and type(wrapped).__name__ == type(new).__name__: | |
| return True | |
| error_msg = str(orig) | |
| new_type_name = type(new).__name__ | |
| # Cache type lookups to avoid redundant attribute access | |
| orig_type = type(orig) | |
| new_type = type(new) | |
| orig_type_name = orig_type.__name__ | |
| new_type_name = new_type.__name__ | |
| # If types match, they're potentially equivalent (will be compared further) | |
| if orig_type_name == new_type_name: | |
| return True | |
| # Check if new is a TorchRuntimeError wrapping the other's type | |
| if new_type_name == "TorchRuntimeError" and "torch._dynamo" in new_type.__module__: | |
| wrapped = _get_wrapped_exception_from_torch_runtime_error(new) | |
| if wrapped is not None and type(wrapped).__name__ == orig_type_name: | |
| return True | |
| # Also check the error message for the wrapped exception type | |
| # TorchRuntimeError message contains "got ExceptionType('...')" | |
| error_msg = str(new) | |
| if f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg: | |
| return True | |
| # Check if orig is a TorchRuntimeError wrapping the other's type | |
| if orig_type_name == "TorchRuntimeError" and "torch._dynamo" in orig_type.__module__: | |
| wrapped = _get_wrapped_exception_from_torch_runtime_error(orig) | |
| if wrapped is not None and type(wrapped).__name__ == new_type_name: | |
| return True | |
| error_msg = str(orig) |
|
in favor of a more general solution in a different pr |
No description provided.