
Conversation

@aseembits93
Contributor

No description provided.

@aseembits93
Contributor Author

@claude review this

@claude

claude bot commented Jan 20, 2026

Claude finished @aseembits93's task in 1m 20s


Code Review Complete

  • Read repository guidelines (CLAUDE.md)
  • Read changed files
  • Analyze changes against origin/main
  • Review code quality and identify issues
  • Post comprehensive review feedback

Summary

This PR adds special handling for TorchRuntimeError exceptions in the comparator to correctly compare exceptions when one is wrapped by torch.compile's TorchRuntimeError and the other is the original unwrapped exception. The implementation is well-structured with comprehensive test coverage.
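For context, here is a minimal sketch of the wrapping pattern described above, using a stand-in class so it runs without torch installed (TorchRuntimeErrorStub and run_compiled are illustrative names, not part of the PR; the real exception lives in torch._dynamo):

```python
# Stand-in for torch._dynamo's TorchRuntimeError so the example runs without torch.
class TorchRuntimeErrorStub(RuntimeError):
    __module__ = "torch._dynamo"


def run_compiled():
    try:
        raise IndexError("index 0 out of bounds")
    except IndexError as e:
        # "raise ... from e" is what sets __cause__, which the new helper unwraps
        raise TorchRuntimeErrorStub(f"got {type(e).__name__}('{e}')") from e


try:
    run_compiled()
except RuntimeError as wrapped:
    print(type(wrapped).__name__)            # TorchRuntimeErrorStub
    print(type(wrapped.__cause__).__name__)  # IndexError
```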

Detailed Review

✅ Strengths

  1. Well-designed solution: The approach of unwrapping TorchRuntimeError to compare the underlying exception is sound and handles the real-world use case of torch.compile wrapping exceptions.

  2. Comprehensive test coverage: The new test function test_torch_runtime_error_wrapping() covers seven scenarios, including:

    • TorchRuntimeError with __cause__ chain
    • TorchRuntimeError without __cause__ but with matching error type in message
    • Different exception types (negative cases)
    • Both directions of comparison (orig vs new, new vs orig)
    • Nested structures (tuples containing exceptions)
  3. Non-invasive implementation: The changes don't affect existing comparator logic for non-exception types.

  4. Good documentation: Helper functions have clear docstrings explaining their purpose.

🔍 Issues and Concerns

1. Redundant Logic in Exception Comparison (codeflash/verification/comparator.py:143-166)

The exception comparison logic at lines 143-166 largely duplicates the logic already implemented in _exceptions_are_equivalent() at lines 43-74. This duplication creates a maintenance burden and a risk of the two code paths drifting out of sync.

Current flow:

  1. Lines 96-101: Call _exceptions_are_equivalent() to check if exceptions match
  2. Lines 143-166: Re-implement similar unwrapping logic using _is_torch_runtime_error() and _get_wrapped_exception_from_torch_runtime_error()

Recommendation: Consolidate this logic. The unwrapping in lines 143-166 could be simplified or the logic in _exceptions_are_equivalent() could be enhanced to handle the unwrapping and attribute comparison together.
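One possible shape for that consolidation, assuming the helper added by this PR; _unwrap_either and _exceptions_match are hypothetical names for illustration, and the attribute comparison in the real comparator is likely richer than a plain args check:

```python
from codeflash.verification.comparator import _get_wrapped_exception_from_torch_runtime_error


def _unwrap_either(orig: BaseException, new: BaseException) -> tuple:
    """Unwrap whichever side is a TorchRuntimeError so a single path compares the same objects."""
    unwrapped_orig = _get_wrapped_exception_from_torch_runtime_error(orig) or orig
    unwrapped_new = _get_wrapped_exception_from_torch_runtime_error(new) or new
    return unwrapped_orig, unwrapped_new


def _exceptions_match(orig: BaseException, new: BaseException) -> bool:
    orig, new = _unwrap_either(orig, new)
    if type(orig).__name__ != type(new).__name__:
        return False
    # Attribute comparison happens here once, instead of being repeated later in comparator()
    return orig.args == new.args
```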

2. Potential Issue with Early Return (codeflash/verification/comparator.py:152-155, 162-165)

When the wrapped exception type matches but can't be extracted, the code returns True immediately:

elif f"got {type(orig).__name__}(" in str(new) or f"got {type(orig).__name__}:" in str(new):
    # The wrapped exception type matches, but we can't extract it
    # Consider them equivalent since the same error occurred
    return True

This returns True based solely on exception type matching in the error message, without comparing exception attributes. This could lead to false positives where two exceptions of the same type but with different attributes (e.g., different error messages or data) are considered equivalent.

Example scenario:

orig = IndexError("index 0 out of bounds")
new = TorchRuntimeError("got IndexError('index 5 out of bounds')")  # Different index!
# These would be considered equivalent even though they represent different errors

Recommendation: Either:

  • Document this trade-off explicitly (why early return is acceptable)
  • Or enhance the logic to extract and compare exception details from the error message string (a rough sketch follows this list)
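A hypothetical refinement along those lines: instead of returning True on a bare type-name match, pull the wrapped message out of the "got ExceptionType('...')" pattern and compare it against the other exception's message. The helper name and regex below are illustrative only:

```python
import re


def _wrapped_message_matches(orig: BaseException, torch_err_text: str) -> bool:
    # Extract the message quoted inside "got ExceptionType('...')" and compare it
    # to str(orig); fall back to a non-match if the pattern is absent.
    pattern = rf"got {re.escape(type(orig).__name__)}\('(?P<msg>.*?)'\)"
    match = re.search(pattern, torch_err_text)
    if match is None:
        return False
    return match.group("msg") == str(orig)


orig = IndexError("index 0 out of bounds")
# Different index in the wrapped message, so it is no longer treated as equivalent
print(_wrapped_message_matches(orig, "got IndexError('index 5 out of bounds')"))  # False
print(_wrapped_message_matches(orig, "got IndexError('index 0 out of bounds')"))  # True
```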

3. String Matching Fragility (codeflash/verification/comparator.py:62-63, 71-72, 152-153, 162-163)

The error message parsing relies on specific string patterns:

if f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg:

This is fragile because:

  • If PyTorch changes the TorchRuntimeError message format, this will break silently
  • The patterns "got ExceptionType(" and "got ExceptionType:" may not cover all cases
  • There's no validation that the matched string actually represents the wrapped exception

Recommendation: Add a comment explaining that this is a fallback heuristic for when __cause__ is unavailable, and consider logging a warning when falling back to string matching.
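A sketch of how that comment and warning could look, assuming the stdlib logging module (the project may route logs differently; the function name is hypothetical):

```python
import logging

logger = logging.getLogger(__name__)


def _message_mentions_wrapped_type(orig: BaseException, error_msg: str) -> bool:
    # Fallback heuristic: __cause__ was unavailable, so infer the wrapped exception
    # type from TorchRuntimeError's message. This depends on PyTorch's current
    # "got ExceptionType(...)" / "got ExceptionType: ..." format and may break silently
    # if that format changes.
    orig_type_name = type(orig).__name__
    matched = f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg
    if matched:
        logger.warning(
            "Falling back to string matching to unwrap TorchRuntimeError; "
            "__cause__ was not set on the wrapping exception."
        )
    return matched
```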

4. Missing Edge Cases in Tests

The test coverage is good but could include:

  • Testing when exception attributes differ (to verify attribute comparison works correctly)
  • Testing when both exceptions are TorchRuntimeError wrapping different exception types
  • Testing with exception types that have special characters in their names

5. Type Annotation Issue (codeflash/verification/comparator.py:27)

def _get_wrapped_exception_from_torch_runtime_error(exc: BaseException) -> Optional[BaseException]:  # noqa: FA100

The # noqa: FA100 disables the flake8-future-annotations check. This suggests the codebase may want from __future__ import annotations for proper type hint handling. Consider whether this should be addressed globally rather than suppressing it.
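For reference, the pattern the FA100 rule nudges toward, shown on a standalone stub of the signature rather than the project file itself:

```python
# With the future import, annotations are never evaluated at runtime, so the
# PEP 604 "X | None" form can replace typing.Optional even on older Python versions.
from __future__ import annotations


def _get_wrapped_exception_from_torch_runtime_error(exc: BaseException) -> BaseException | None:
    ...
```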

6. Test File Change Unexplained (code_to_optimize/tests/pytest/test_tensorflow_jit_code.py:12)

The change from:

tf = pytest.importorskip("tensorflow")

to:

import tensorflow as tf

This removes the conditional import skip, which means tests will fail if TensorFlow is not installed instead of being skipped. This seems unrelated to the torch exception handling and could be problematic.
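For comparison, the skip-based form behaves roughly like this (a generic pytest sketch, not the project's file; the test name is illustrative):

```python
import pytest

# importorskip imports the module if it is available; otherwise every test in this
# module is reported as skipped instead of erroring out at import time.
tf = pytest.importorskip("tensorflow")


def test_constant_roundtrip():
    # Only runs when TensorFlow is importable.
    assert tf.constant([1, 2, 3]).shape == (3,)
```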

Recommendation: Either:

  • Revert this change if it's unintentional
  • Or explain why it's needed in the PR description

📊 Code Quality

  • Readability: Good, functions are well-named and documented
  • Maintainability: Fair, some code duplication between early type checking and later exception handling
  • Testability: Excellent, comprehensive test coverage
  • Performance: Negligible impact, only affects exception comparison code path

🎯 Recommendations

  1. High Priority: Clarify or fix the TensorFlow import change in test_tensorflow_jit_code.py:12
  2. Medium Priority: Address the early return issue that skips attribute comparison (lines 152-155, 162-165)
  3. Low Priority: Consolidate the duplicate exception unwrapping logic
  4. Low Priority: Add comments about string matching fragility

Verdict

The core functionality is solid and solves a real problem. The main concerns are around edge cases and code duplication. I recommend addressing the TensorFlow import change and considering the early return behavior before merging.


Comment on lines +33 to +37
    if not _is_torch_runtime_error(exc):
        return None

    # First try to get the chained exception via __cause__
    if exc.__cause__ is not None:

⚡️Codeflash found 24% (0.24x) speedup for _get_wrapped_exception_from_torch_runtime_error in codeflash/verification/comparator.py

⏱️ Runtime : 603 microseconds → 486 microseconds (best of 241 runs)

📝 Explanation and details

The optimized code achieves a 24% speedup by eliminating redundant function calls and attribute lookups through inlining.

Key Optimizations

1. Function Call Elimination
The original code calls _is_torch_runtime_error(exc) as a separate function, which incurs overhead for:

  • Function call/return mechanics
  • Parameter passing
  • Stack frame setup/teardown

The optimized version inlines this check directly, removing this overhead entirely.

2. Reduced Attribute Lookups
The original code calls type(exc) twice within _is_torch_runtime_error():

type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__

The optimized version caches type(exc) in exc_type and reuses it:

exc_type = type(exc)
if exc_type.__name__ == "TorchRuntimeError" and "torch._dynamo" in exc_type.__module__:

This eliminates one redundant type() call per invocation.

3. Streamlined Control Flow
The original version had nested conditionals checking the same condition. The optimized version combines the type check with the action (returning exc.__cause__) in a single conditional block, reducing branching overhead.
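Putting these pieces together, the optimized helper looks roughly like this (reconstructed from the suggested change later in this thread; the merged code may differ slightly):

```python
from typing import Optional


def _get_wrapped_exception_from_torch_runtime_error(exc: BaseException) -> Optional[BaseException]:
    """Extract the underlying exception wrapped by TorchRuntimeError."""
    # Inline the check to avoid a redundant function call and repeated attribute lookups
    exc_type = type(exc)
    if exc_type.__name__ == "TorchRuntimeError" and "torch._dynamo" in exc_type.__module__:
        # __cause__ carries the original exception when torch chains it explicitly
        if exc.__cause__ is not None:
            return exc.__cause__
    return None
```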

Performance Impact

Based on function_references, this function is called from _exceptions_are_equivalent() and comparator(), which are in the exception comparison hot path during verification. The comparator() function handles exception objects during test execution, meaning this optimization directly benefits:

  • Exception handling during test verification - where many exceptions may be checked rapidly
  • TorchRuntimeError unwrapping scenarios - common when using torch.compile/Dynamo

The annotated tests show consistent speedups across all test cases (15-31% faster), with particularly strong gains in:

  • Large-scale batch operations (500 exceptions: 30.7% faster)
  • Rapid successive calls (500 iterations: 22.8% faster)
  • Mixed exception type batches (21.6% faster)

This optimization is most valuable when the function is called frequently in loops or exception-heavy code paths, which aligns with its usage in the verification/comparison pipeline.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests: 🔘 None Found
🌀 Generated Regression Tests: 1982 Passed
⏪ Replay Tests: 🔘 None Found
🔎 Concolic Coverage Tests: 1 Passed
📊 Tests Coverage: 100.0%
🌀 Click to see Generated Regression Tests
# imports
from codeflash.verification.comparator import _get_wrapped_exception_from_torch_runtime_error


def _is_torch_runtime_error(exc: BaseException) -> bool:
    """Check if an exception is a TorchRuntimeError from torch.compile/Dynamo."""
    # Check by class name to avoid importing torch
    return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__


def test_non_torch_exception_returns_none_even_with_cause():
    # Create a normal ValueError (not TorchRuntimeError) and set its __cause__.
    base = ValueError("underlying")
    exc = ValueError("wrapper")
    # Manually attach a cause as would be done by "raise ... from ..."
    exc.__cause__ = base
    # Since this is not a TorchRuntimeError (name/module don't match), function should return None
    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)  # 832ns -> 701ns (18.7% faster)


def test_torch_runtime_error_with_cause_returns_cause():
    # Dynamically create a class named TorchRuntimeError with a module indicating torch._dynamo
    TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch._dynamo"})
    # Instantiate and attach a cause exception
    underlying = RuntimeError("original problem")
    tr_exc = TorchRuntimeError("wrapped by torch")
    tr_exc.__cause__ = underlying
    # The function must detect the TorchRuntimeError (by name and module substring) and return the cause
    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(tr_exc)
    result = codeflash_output  # 1.23μs -> 982ns (25.5% faster)


def test_torch_runtime_error_without_cause_returns_none():
    # TorchRuntimeError-like instance but without a __cause__ should yield None
    TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch._dynamo"})
    tr_exc = TorchRuntimeError("no cause attached")
    # Explicitly ensure no cause
    tr_exc.__cause__ = None
    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(tr_exc)  # 1.13μs -> 911ns (24.3% faster)


def test_torch_runtime_error_with_context_but_no_cause_returns_none():
    # If only __context__ is set (implicit exception chaining) but __cause__ is None, return None
    TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch._dynamo"})
    tr_exc = TorchRuntimeError("context but no explicit cause")
    # Set context (what happens when an exception occurs during exception handling)
    tr_exc.__context__ = ValueError("contextual")
    tr_exc.__cause__ = None
    # The function explicitly checks only __cause__
    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(tr_exc)  # 1.10μs -> 862ns (27.8% faster)


def test_similar_name_but_wrong_module_returns_none():
    # Same class name but module does not contain "torch._dynamo" -> should not be treated as torch runtime error
    TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "not_torch.dynamo"})
    underlying = KeyError("oops")
    tr_exc = TorchRuntimeError("looks similar")
    tr_exc.__cause__ = underlying
    # Because the module string doesn't include the expected substring, result must be None
    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(tr_exc)  # 931ns -> 801ns (16.2% faster)


def test_module_contains_substring_matches_even_with_extra_path():
    # Module containing the substring should match, even if it's a longer path
    TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "some.prefix.torch._dynamo.suffix"})
    underlying = IndexError("index problem")
    tr_exc = TorchRuntimeError("wrapped")
    tr_exc.__cause__ = underlying
    # Should detect by substring and return the underlying exception
    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(tr_exc)  # 1.21μs -> 922ns (31.5% faster)


def test_nested_cause_returns_only_first_level_cause():
    # If the wrapped cause itself has a __cause__, the function should only return the first-level __cause__
    TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch._dynamo"})
    inner = ValueError("inner-most")
    middle = TypeError("middle")
    middle.__cause__ = inner  # middle chains to inner
    outer = TorchRuntimeError("outer wrapper")
    outer.__cause__ = middle  # outer chains to middle
    # The function should return 'middle', not 'inner'
    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(outer)
    result = codeflash_output  # 1.33μs -> 1.06μs (25.5% faster)


def test_large_scale_many_wrapped_errors_iterable():
    # Large scale test: create a sizeable collection (but < 1000) of TorchRuntimeError-like instances
    # Each should unwrap to its respective cause; this checks for consistent behavior across many instances.
    count = 500  # below the 1000-step loop constraint
    TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch._dynamo"})
    wrappers = []
    causes = []
    for i in range(count):
        cause = RuntimeError(f"cause-{i}")  # distinct cause objects
        tr_exc = TorchRuntimeError(f"wrapper-{i}")
        tr_exc.__cause__ = cause
        wrappers.append(tr_exc)
        causes.append(cause)

    # Verify each wrapper returns the exact corresponding cause
    for idx, w in enumerate(wrappers):
        codeflash_output = _get_wrapped_exception_from_torch_runtime_error(w)  # 165μs -> 126μs (30.7% faster)


def test_is_torch_runtime_error_helper_identifies_only_matching_types():
    # Validate helper behavior for a mix of objects to ensure the gatekeeping logic is correct.
    TorchRuntimeError = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch._dynamo"})
    other_named = type("TorchRuntimeError", (RuntimeError,), {"__module__": "torch.something_else"})
    other_module = type("OtherError", (RuntimeError,), {"__module__": "torch._dynamo"})

    t1 = TorchRuntimeError("match")
    t2 = other_named("wrong module")
    t3 = other_module("wrong name")


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
from codeflash.verification.comparator import _get_wrapped_exception_from_torch_runtime_error


# Function to test
def _is_torch_runtime_error(exc: BaseException) -> bool:
    """Check if an exception is a TorchRuntimeError from torch.compile/Dynamo."""
    # Check by class name to avoid importing torch
    return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__


# Helper to create a mock TorchRuntimeError with the correct module and class name
class TorchRuntimeError(Exception):
    """Mock TorchRuntimeError that simulates torch._dynamo.TorchRuntimeError."""


def test_non_torch_exception_returns_none():
    """Test that non-TorchRuntimeError exceptions return None."""
    # Standard ValueError should not be considered a TorchRuntimeError
    exc = ValueError("Standard error")
    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
    result = codeflash_output  # 922ns -> 701ns (31.5% faster)


def test_torch_error_with_cause_returns_cause():
    """Test that TorchRuntimeError with __cause__ returns the wrapped exception."""
    # Create a wrapped exception
    original_exc = RuntimeError("Original error")

    # Create a TorchRuntimeError with the original as __cause__
    torch_exc = TorchRuntimeError("Torch runtime error")
    torch_exc.__cause__ = original_exc

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 1.08μs -> 862ns (25.5% faster)


def test_torch_error_without_cause_returns_none():
    """Test that TorchRuntimeError without __cause__ returns None."""
    # Create a TorchRuntimeError without a cause
    torch_exc = TorchRuntimeError("Torch runtime error")

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 911ns -> 752ns (21.1% faster)


def test_torch_error_with_none_cause_returns_none():
    """Test that TorchRuntimeError with explicitly None __cause__ returns None."""
    # Create a TorchRuntimeError with explicit None cause
    torch_exc = TorchRuntimeError("Torch runtime error")
    torch_exc.__cause__ = None

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 902ns -> 781ns (15.5% faster)


def test_standard_exception_types_return_none():
    """Test that various standard exception types return None."""
    # Test with different standard exception types
    exceptions = [
        TypeError("type error"),
        AttributeError("attribute error"),
        KeyError("key error"),
        IndexError("index error"),
        ZeroDivisionError("division error"),
    ]

    for exc in exceptions:
        codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
        result = codeflash_output  # 2.08μs -> 1.63μs (27.6% faster)


def test_torch_error_wraps_value_error():
    """Test extraction of ValueError from TorchRuntimeError."""
    # Create a ValueError and wrap it
    original_exc = ValueError("Wrapped value error")
    torch_exc = TorchRuntimeError("Torch error")
    torch_exc.__cause__ = original_exc

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 1.00μs -> 812ns (23.3% faster)


def test_torch_error_with_nested_cause_chain():
    """Test extraction when exception has nested cause chains."""
    # Create a chain: original -> wrapped1 -> wrapped2 (torch)
    original = RuntimeError("Original")
    wrapped1 = TypeError("Wrapped once")
    wrapped1.__cause__ = original

    torch_exc = TorchRuntimeError("Torch error")
    torch_exc.__cause__ = wrapped1

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 892ns -> 711ns (25.5% faster)


def test_exception_with_empty_message():
    """Test handling of exceptions with empty messages."""
    original_exc = RuntimeError("")
    torch_exc = TorchRuntimeError("")
    torch_exc.__cause__ = original_exc

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 882ns -> 812ns (8.62% faster)


def test_exception_with_special_characters_in_message():
    """Test exceptions with special characters in messages."""
    original_exc = RuntimeError("Error\nwith\nnewlines\tand\ttabs")
    torch_exc = TorchRuntimeError("Torch error")
    torch_exc.__cause__ = original_exc

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 941ns -> 812ns (15.9% faster)


def test_exception_with_unicode_message():
    """Test exceptions with unicode characters."""
    original_exc = RuntimeError("Error with unicode: \u00e9\u00e8\u00ea \u4e2d\u6587")
    torch_exc = TorchRuntimeError("Torch error")
    torch_exc.__cause__ = original_exc

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 912ns -> 801ns (13.9% faster)


def test_torch_error_with_context_vs_cause():
    """Test that only __cause__ is extracted, not __context__."""
    # Create exceptions with context but no cause
    original_exc = RuntimeError("Original")
    torch_exc = TorchRuntimeError("Torch error")

    # Simulate an implicit exception context
    torch_exc.__context__ = original_exc
    torch_exc.__cause__ = None

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 922ns -> 862ns (6.96% faster)


def test_wrong_module_name_returns_none():
    """Test that exception with wrong module name returns None."""
    # Create an exception with correct class name but wrong module
    exc = TorchRuntimeError("Error")
    exc.__module__ = "wrong_module"

    # This should return None because module is not torch._dynamo
    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
    result = codeflash_output  # 911ns -> 781ns (16.6% faster)


def test_wrong_class_name_returns_none():
    """Test that exception with wrong class name returns None."""

    # Create a custom exception with torch._dynamo module but wrong name
    class WrongError(Exception):
        pass

    WrongError.__module__ = "torch._dynamo"
    exc = WrongError("Error")

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
    result = codeflash_output  # 642ns -> 552ns (16.3% faster)


def test_exception_object_identity_preserved():
    """Test that the returned exception is the same object, not a copy."""
    original_exc = RuntimeError("Original")
    torch_exc = TorchRuntimeError("Torch error")
    torch_exc.__cause__ = original_exc

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 892ns -> 851ns (4.82% faster)


def test_torch_error_with_custom_exception_subclass():
    """Test extraction of custom exception subclass from TorchRuntimeError."""

    class CustomError(Exception):
        """Custom exception for testing."""

    original_exc = CustomError("Custom error message")
    torch_exc = TorchRuntimeError("Torch error")
    torch_exc.__cause__ = original_exc

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 952ns -> 822ns (15.8% faster)


def test_exception_with_traceback():
    """Test handling of exceptions with traceback information."""
    try:
        1 / 0
    except ZeroDivisionError as e:
        original_exc = e

    torch_exc = TorchRuntimeError("Torch error")
    torch_exc.__cause__ = original_exc

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 911ns -> 802ns (13.6% faster)


def test_none_input():
    """Test that None input raises appropriate error or returns None."""
    # The function signature expects BaseException, but test robustness
    try:
        codeflash_output = _get_wrapped_exception_from_torch_runtime_error(None)
        result = codeflash_output
    except (AttributeError, TypeError):
        # AttributeError is expected when calling methods on None
        pass


def test_torch_error_module_substring_not_matching():
    """Test that partial module name matches are not accepted."""
    # Create exception where torch._dynamo is a substring but not exact
    exc = TorchRuntimeError("Error")
    exc.__module__ = "notorch._dynamo"

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
    result = codeflash_output  # 932ns -> 821ns (13.5% faster)


def test_multiple_sequential_torch_errors():
    """Test handling of multiple TorchRuntimeError instances sequentially."""
    # Create 100 different wrapped exceptions
    results = []
    for i in range(100):
        original_exc = RuntimeError(f"Original error {i}")
        torch_exc = TorchRuntimeError(f"Torch error {i}")
        torch_exc.__cause__ = original_exc

        codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
        result = codeflash_output  # 29.7μs -> 24.2μs (22.6% faster)
        results.append(result)
    for i, result in enumerate(results):
        pass


def test_large_batch_mixed_exception_types():
    """Test function with large batch of mixed exception types."""
    # Create 500 exceptions: mix of TorchRuntimeError and other types
    exceptions = []
    for i in range(500):
        if i % 2 == 0:
            # Create TorchRuntimeError
            original = RuntimeError(f"Error {i}")
            torch_exc = TorchRuntimeError(f"Torch {i}")
            torch_exc.__cause__ = original
            exceptions.append(torch_exc)
        else:
            # Create standard exception
            exceptions.append(ValueError(f"Standard {i}"))

    results = []
    torch_count = 0
    for exc in exceptions:
        codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
        result = codeflash_output  # 141μs -> 116μs (21.6% faster)
        results.append(result)
        if result is not None:
            torch_count += 1


def test_deep_cause_chain_extraction():
    """Test performance with deep exception chains (only direct cause extracted)."""
    # Create a deep chain: a0 -> a1 -> a2 -> ... -> torch_exc
    torch_exc = TorchRuntimeError("Torch error")
    original = RuntimeError("Original")
    torch_exc.__cause__ = original

    # The function should only extract the direct cause
    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 971ns -> 911ns (6.59% faster)


def test_large_exception_messages():
    """Test handling of exceptions with very large messages."""
    # Create an exception with a large message (but < 1MB)
    large_message = "x" * 100000  # 100KB message
    original_exc = RuntimeError(large_message)
    torch_exc = TorchRuntimeError("Torch error")
    torch_exc.__cause__ = original_exc

    codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
    result = codeflash_output  # 992ns -> 892ns (11.2% faster)


def test_performance_many_checks_non_torch_exceptions():
    """Test performance checking many non-TorchRuntimeError exceptions."""
    # Generate 200 standard exceptions
    exceptions = [ValueError(f"Error {i}") for i in range(200)]

    results = []
    for exc in exceptions:
        codeflash_output = _get_wrapped_exception_from_torch_runtime_error(exc)
        result = codeflash_output  # 50.7μs -> 42.4μs (19.5% faster)
        results.append(result)


def test_torch_errors_with_same_wrapped_exception():
    """Test multiple TorchRuntimeErrors wrapping the same exception."""
    # Create one original exception
    shared_original = RuntimeError("Shared original")

    # Wrap it in multiple TorchRuntimeErrors
    torch_errors = []
    for i in range(50):
        torch_exc = TorchRuntimeError(f"Torch error {i}")
        torch_exc.__cause__ = shared_original
        torch_errors.append(torch_exc)

    results = []
    for torch_exc in torch_errors:
        codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
        result = codeflash_output  # 15.4μs -> 12.5μs (23.1% faster)
        results.append(result)


def test_rapid_successive_calls():
    """Test rapid successive calls to verify no state corruption."""
    original_exc = RuntimeError("Original")
    torch_exc = TorchRuntimeError("Torch error")
    torch_exc.__cause__ = original_exc

    # Call function 500 times rapidly
    results = []
    for _ in range(500):
        codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
        result = codeflash_output  # 142μs -> 116μs (22.8% faster)
        results.append(result)


def test_varied_wrapped_exception_types_at_scale():
    """Test at scale with many different wrapped exception types."""
    exception_types = [
        RuntimeError,
        ValueError,
        TypeError,
        AttributeError,
        KeyError,
        IndexError,
        ZeroDivisionError,
        NotImplementedError,
        AssertionError,
        OSError,
    ]

    results = []
    for i in range(100):
        exc_type = exception_types[i % len(exception_types)]
        original_exc = exc_type(f"Message {i}")
        torch_exc = TorchRuntimeError(f"Torch error {i}")
        torch_exc.__cause__ = original_exc

        codeflash_output = _get_wrapped_exception_from_torch_runtime_error(torch_exc)
        result = codeflash_output  # 30.0μs -> 24.3μs (23.6% faster)
        results.append(result)

    # Verify types match
    for i, result in enumerate(results):
        expected_type = exception_types[i % len(exception_types)]


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
from codeflash.verification.comparator import _get_wrapped_exception_from_torch_runtime_error


def test__get_wrapped_exception_from_torch_runtime_error():
    _get_wrapped_exception_from_torch_runtime_error(BaseException())
🔎 Click to see Concolic Coverage Tests
Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup
codeflash_concolic_g5uubgk6/tmpbtersvva/test_concolic_coverage.py::test__get_wrapped_exception_from_torch_runtime_error | 842ns | 651ns | 29.3% ✅

To test or edit this optimization locally: git merge codeflash/optimize-pr1119-2026-01-20T05.21.22

Suggested change
-    if not _is_torch_runtime_error(exc):
-        return None
-    # First try to get the chained exception via __cause__
-    if exc.__cause__ is not None:
+    # Inline the check to avoid redundant function call and attribute lookups
+    exc_type = type(exc)
+    if exc_type.__name__ == "TorchRuntimeError" and "torch._dynamo" in exc_type.__module__:


Comment on lines +49 to +70
    # If types match, they're potentially equivalent (will be compared further)
    if type(orig).__name__ == type(new).__name__:
        return True

    # Check if one is a TorchRuntimeError wrapping the other's type
    if _is_torch_runtime_error(new):
        wrapped = _get_wrapped_exception_from_torch_runtime_error(new)
        if wrapped is not None and type(wrapped).__name__ == type(orig).__name__:
            return True
        # Also check the error message for the wrapped exception type
        # TorchRuntimeError message contains "got ExceptionType('...')"
        error_msg = str(new)
        orig_type_name = type(orig).__name__
        if f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg:
            return True

    if _is_torch_runtime_error(orig):
        wrapped = _get_wrapped_exception_from_torch_runtime_error(orig)
        if wrapped is not None and type(wrapped).__name__ == type(new).__name__:
            return True
        error_msg = str(orig)
        new_type_name = type(new).__name__

⚡️Codeflash found 16% (0.16x) speedup for _exceptions_are_equivalent in codeflash/verification/comparator.py

⏱️ Runtime : 1.22 milliseconds → 1.05 milliseconds (best of 146 runs)

📝 Explanation and details

The optimized code achieves a 15% speedup by reducing redundant attribute lookups through strategic caching of type() results and their attributes.

Key Optimizations

1. Cached Type Lookups in _is_torch_runtime_error

The original code called type(exc) twice per invocation:

return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__

The optimized version caches the result:

exc_type = type(exc)
return exc_type.__name__ == "TorchRuntimeError" and "torch._dynamo" in exc_type.__module__

This eliminates one type() call per check, which is significant since this function is called frequently (3,164 hits in the original, 611 in the optimization test).

2. Upfront Type Attribute Caching in _exceptions_are_equivalent

The most impactful optimization occurs in _exceptions_are_equivalent, where type information is cached at the function start:

orig_type = type(orig)
new_type = type(new)
orig_type_name = orig_type.__name__
new_type_name = new_type.__name__

This eliminates repeated type() calls throughout the function. The line profiler shows this reduces time spent on type checking from ~34% to ~4.7% of total execution time.

3. Inlined TorchRuntimeError Checks

Instead of calling _is_torch_runtime_error() (which itself performs type lookups), the optimized code directly checks the cached type attributes:

if new_type_name == "TorchRuntimeError" and "torch._dynamo" in new_type.__module__:

This avoids the overhead of function calls and redundant type operations.
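Combining the quoted fragments, the optimized entry to _exceptions_are_equivalent looks roughly like this (a reconstruction for illustration, not the exact merged code; it assumes the _get_wrapped_exception_from_torch_runtime_error helper defined in the same module):

```python
def _exceptions_are_equivalent(orig: BaseException, new: BaseException) -> bool:
    # Cache type objects and names once; every branch below reuses them
    orig_type = type(orig)
    new_type = type(new)
    orig_type_name = orig_type.__name__
    new_type_name = new_type.__name__

    # If the type names match, they're potentially equivalent (compared further by the caller)
    if orig_type_name == new_type_name:
        return True

    # Inlined TorchRuntimeError check: unwrap the new-side exception if it is a wrapper
    if new_type_name == "TorchRuntimeError" and "torch._dynamo" in new_type.__module__:
        wrapped = _get_wrapped_exception_from_torch_runtime_error(new)
        if wrapped is not None and type(wrapped).__name__ == orig_type_name:
            return True
        error_msg = str(new)
        if f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg:
            return True

    # Mirror of the check above for the original-side exception
    if orig_type_name == "TorchRuntimeError" and "torch._dynamo" in orig_type.__module__:
        wrapped = _get_wrapped_exception_from_torch_runtime_error(orig)
        if wrapped is not None and type(wrapped).__name__ == new_type_name:
            return True
        error_msg = str(orig)
        if f"got {new_type_name}(" in error_msg or f"got {new_type_name}:" in error_msg:
            return True

    return False
```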

Performance Impact by Test Case

The optimization particularly excels in scenarios involving:

  • Different exception types (46.8-66.8% faster): Cached type names enable quick inequality checks without repeated lookups
  • TorchRuntimeError unwrapping (14.3-26% faster): Direct type checks bypass function call overhead
  • Large-scale comparisons (17.4-22.3% faster on 500-1000 iterations): Reduced per-comparison overhead compounds significantly

The optimization shows minimal regression (~2-5% slower) only in edge cases with very few comparisons where the upfront caching overhead slightly exceeds savings.

Context and Importance

Based on function_references, _exceptions_are_equivalent is called from the hot path comparator() function, which is a recursive comparison utility handling various data types. Since comparator() may be called frequently during verification workflows (especially with large data structures or exception chains), even a 15% improvement translates to meaningful time savings in production validation scenarios.

The optimization maintains correctness by preserving all logical branches while simply reusing computed type information instead of repeatedly querying the Python object model.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests: 🔘 None Found
🌀 Generated Regression Tests: 2938 Passed
⏪ Replay Tests: 🔘 None Found
🔎 Concolic Coverage Tests: 1 Passed
📊 Tests Coverage: 78.9%
🌀 Click to see Generated Regression Tests
from typing import Optional

# imports
from codeflash.verification.comparator import _exceptions_are_equivalent


def _is_torch_runtime_error(exc: BaseException) -> bool:
    """Check if an exception is a TorchRuntimeError from torch.compile/Dynamo."""
    # Check by class name to avoid importing torch
    return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__


def _get_wrapped_exception_from_torch_runtime_error(exc: BaseException) -> Optional[BaseException]:  # noqa: FA100
    """Extract the underlying exception wrapped by TorchRuntimeError.

    TorchRuntimeError from torch.compile/Dynamo wraps the original exception
    using Python's exception chaining (__cause__).
    """
    if not _is_torch_runtime_error(exc):
        return None

    # First try to get the chained exception via __cause__
    if exc.__cause__ is not None:
        return exc.__cause__

    return None


def test_same_builtin_type_is_equivalent():
    # Basic case: identical built-in exception types should be equivalent
    orig = ValueError("original")
    new = ValueError("new")
    # types have the same __name__ so should be True
    codeflash_output = _exceptions_are_equivalent(orig, new)  # 882ns -> 861ns (2.44% faster)


def test_same_type_name_different_classes_equivalent():
    # If types have the same name (even if different classes), equivalence is based on name
    # Create a custom exception class named 'ValueError' (different object than built-in)
    CustomValueError = type("ValueError", (Exception,), {})  # class name matches built-in
    orig = ValueError("built-in")  # built-in ValueError
    new = CustomValueError("custom")  # custom class with the same name
    # Should be True because only the __name__ is compared in the function's first check
    codeflash_output = _exceptions_are_equivalent(orig, new)  # 791ns -> 802ns (1.37% slower)


def test_torch_runtime_error_with_cause_matches_orig():
    # Simulate a TorchRuntimeError wrapper by creating a class with the same name
    # and a module path containing 'torch._dynamo'
    TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
    TorchRuntimeErrorFake.__module__ = "torch._dynamo.fake_module"  # include required substring

    # Underlying exception that was originally raised
    underlying = TypeError("underlying problem")
    # Create the fake TorchRuntimeError and attach the underlying exception via __cause__
    wrapped = TorchRuntimeErrorFake("wrapper message")
    wrapped.__cause__ = underlying

    # orig is the underlying exception, new is the wrapper: should be considered equivalent
    codeflash_output = _exceptions_are_equivalent(underlying, wrapped)  # 2.23μs -> 1.95μs (14.3% faster)


def test_torch_runtime_error_wraps_new_side_equivalence():
    # Reverse of previous: orig is a wrapper, new is the underlying exception
    TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
    TorchRuntimeErrorFake.__module__ = "torch._dynamo.other"

    underlying = KeyError("missing")
    wrapper = TorchRuntimeErrorFake("wrapped key error")
    wrapper.__cause__ = underlying

    # orig is wrapper, new is underlying -> should be True
    codeflash_output = _exceptions_are_equivalent(wrapper, underlying)  # 2.37μs -> 1.88μs (26.0% faster)


def test_torch_runtime_error_message_parenthesis_format_matches():
    # TorchRuntimeError without a __cause__ but with a message containing "got TypeError('...')"
    TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
    TorchRuntimeErrorFake.__module__ = "torch._dynamo.something"

    # Craft message that matches the "got <TypeName>(" pattern
    msg = "some runtime wrapper; got TypeError('oops') while compiling"
    wrapper = TorchRuntimeErrorFake(msg)

    orig = TypeError("oops")
    # Since the message contains "got TypeError(" the function should report equivalence
    codeflash_output = _exceptions_are_equivalent(orig, wrapper)  # 2.50μs -> 2.38μs (5.08% faster)


def test_torch_runtime_error_message_colon_format_matches():
    # TorchRuntimeError message with the "got TypeError:" pattern should match as well
    TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
    TorchRuntimeErrorFake.__module__ = "torch._dynamo"

    msg = "error occurred; got TypeError: details follow"
    wrapper = TorchRuntimeErrorFake(msg)

    orig = TypeError("details follow")
    # Should be True because of the "got TypeError:" pattern
    codeflash_output = _exceptions_are_equivalent(orig, wrapper)  # 2.48μs -> 2.31μs (7.81% faster)


def test_non_equivalent_different_types():
    # Completely different exception types without any torch wrapper should not be equivalent
    orig = IndexError("index fail")
    new = ValueError("value fail")
    codeflash_output = _exceptions_are_equivalent(orig, new)  # 1.38μs -> 942ns (46.8% faster)


def test_torch_runtime_error_with_incorrect_module_not_recognized():
    # Create a class named TorchRuntimeError but with module that does NOT contain 'torch._dynamo'
    TorchRuntimeErrorNotDynamo = type("TorchRuntimeError", (Exception,), {})
    TorchRuntimeErrorNotDynamo.__module__ = "some.other.module"

    underlying = RuntimeError("bad")
    wrapper = TorchRuntimeErrorNotDynamo("wrapper")
    wrapper.__cause__ = underlying

    # Because the module doesn't include 'torch._dynamo', the wrapper should NOT be recognized,
    # and since names differ (TorchRuntimeError vs RuntimeError), result should be False
    codeflash_output = _exceptions_are_equivalent(underlying, wrapper)  # 1.55μs -> 1.14μs (36.1% faster)


def test_torch_runtime_error_message_partial_no_false_positive():
    # Ensure that similar-but-not-matching messages do not produce false positives.
    TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
    TorchRuntimeErrorFake.__module__ = "torch._dynamo.sub"

    # Message contains 'got Value(' which is not the exact searched pattern 'got ValueError('
    msg = "wrapper says got Value('x') somewhere"
    wrapper = TorchRuntimeErrorFake(msg)

    orig = ValueError("x")
    # Should be False because "got Value(" does not match "got ValueError(" or "got ValueError:"
    codeflash_output = _exceptions_are_equivalent(orig, wrapper)  # 2.85μs -> 2.60μs (10.0% faster)


def test_large_scale_many_name_based_equivalences():
    # Large-scale test: ensure performance & correctness over many name-equality cases.
    # We'll create 500 pairs (well under the 1000 element limit) where the type names match.
    pairs = []
    for i in range(500):  # 500 iterations is under the "do not exceed 1000 steps" guidance
        # Create two distinct classes that share the same __name__
        name = f"BulkExc{i}"
        A = type(name, (Exception,), {})
        B = type(name, (Exception,), {})
        # instantiate exceptions
        a_inst = A(f"a {i}")
        b_inst = B(f"b {i}")
        pairs.append((a_inst, b_inst))

    # All pairs should be equivalent because only the __name__ is checked in the first step
    for a_inst, b_inst in pairs:
        codeflash_output = _exceptions_are_equivalent(a_inst, b_inst)  # 109μs -> 112μs (2.28% slower)


def test_torch_runtime_error_message_matches_custom_exception_name():
    # Message-based match should work for non-builtins too: orig is custom-named exception
    CustomErr = type("CustomErr", (Exception,), {})
    TorchRuntimeErrorFake = type("TorchRuntimeError", (Exception,), {})
    TorchRuntimeErrorFake.__module__ = "torch._dynamo.extra"

    # Message contains the custom exception name in the expected pattern
    msg = "compilation failed; got CustomErr('x') at runtime"
    wrapper = TorchRuntimeErrorFake(msg)

    orig = CustomErr("x")
    # Should be True because message contains "got CustomErr("
    codeflash_output = _exceptions_are_equivalent(orig, wrapper)  # 2.52μs -> 2.46μs (2.48% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
from codeflash.verification.comparator import _exceptions_are_equivalent


def _is_torch_runtime_error(exc: BaseException) -> bool:
    """Check if an exception is a TorchRuntimeError from torch.compile/Dynamo."""
    # Check by class name to avoid importing torch
    return type(exc).__name__ == "TorchRuntimeError" and "torch._dynamo" in type(exc).__module__


def _get_wrapped_exception_from_torch_runtime_error(exc: BaseException):
    """Extract the underlying exception wrapped by TorchRuntimeError.

    TorchRuntimeError from torch.compile/Dynamo wraps the original exception
    using Python's exception chaining (__cause__).
    """
    if not _is_torch_runtime_error(exc):
        return None

    # First try to get the chained exception via __cause__
    if exc.__cause__ is not None:
        return exc.__cause__

    return None


class TestBasicFunctionality:
    """Test cases for basic equivalence checking of standard exceptions."""

    def test_identical_exception_types(self):
        """Test that two exceptions of the same type are considered equivalent."""
        exc1 = ValueError("test error")
        exc2 = ValueError("different message")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 922ns -> 852ns (8.22% faster)

    def test_identical_standard_exceptions(self):
        """Test equivalence with various standard Python exceptions."""
        pairs = [
            (TypeError("msg1"), TypeError("msg2")),
            (RuntimeError("msg1"), RuntimeError("msg2")),
            (KeyError("msg1"), KeyError("msg2")),
            (IndexError("msg1"), IndexError("msg2")),
            (AttributeError("msg1"), AttributeError("msg2")),
            (ZeroDivisionError("msg1"), ZeroDivisionError("msg2")),
        ]
        for exc1, exc2 in pairs:
            codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 2.79μs -> 2.80μs (0.036% slower)

    def test_different_exception_types(self):
        """Test that different exception types are not equivalent."""
        exc1 = ValueError("error")
        exc2 = TypeError("error")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 1.35μs -> 811ns (66.8% faster)

    def test_different_exception_types_all_combinations(self):
        """Test equivalence returns False for various mismatched exception types."""
        pairs = [
            (ValueError("msg"), TypeError("msg")),
            (RuntimeError("msg"), KeyError("msg")),
            (IndexError("msg"), AttributeError("msg")),
            (ZeroDivisionError("msg"), Exception("msg")),
        ]
        for exc1, exc2 in pairs:
            codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 3.33μs -> 2.25μs (48.1% faster)

    def test_custom_exception_same_type(self):
        """Test equivalence for custom exception classes."""

        class CustomError(Exception):
            pass

        exc1 = CustomError("message1")
        exc2 = CustomError("message2")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 601ns -> 611ns (1.64% slower)

    def test_custom_exception_different_types(self):
        """Test non-equivalence for different custom exception types."""

        class CustomError1(Exception):
            pass

        class CustomError2(Exception):
            pass

        exc1 = CustomError1("message")
        exc2 = CustomError2("message")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 1.00μs -> 691ns (45.0% faster)

    def test_base_exception_type(self):
        """Test with base Exception type."""
        exc1 = Exception("message1")
        exc2 = Exception("message2")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 902ns -> 801ns (12.6% faster)

    def test_exception_with_no_message(self):
        """Test equivalence for exceptions without messages."""
        exc1 = ValueError()
        exc2 = ValueError()
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 851ns -> 752ns (13.2% faster)

    def test_exception_with_empty_vs_nonempty_message(self):
        """Test equivalence is independent of message content."""
        exc1 = ValueError("")
        exc2 = ValueError("detailed error message")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 851ns -> 731ns (16.4% faster)


class TestEdgeCases:
    """Test cases for edge conditions and unusual scenarios."""

    def test_exception_with_special_characters_in_message(self):
        """Test exceptions with special characters don't affect type comparison."""
        exc1 = ValueError("Error with @#$%^&*()")
        exc2 = ValueError("Error with \n\t\r")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 901ns -> 752ns (19.8% faster)

    def test_exception_with_unicode_characters(self):
        """Test exceptions with unicode characters."""
        exc1 = ValueError("Error: \u2764\u2665\u263a")
        exc2 = ValueError("Error: \u00e9\u00e8\u00ea")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 832ns -> 791ns (5.18% faster)

    def test_exception_with_very_long_message(self):
        """Test exceptions with very long messages."""
        long_msg1 = "x" * 10000
        long_msg2 = "y" * 10000
        exc1 = ValueError(long_msg1)
        exc2 = ValueError(long_msg2)
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 872ns -> 732ns (19.1% faster)

    def test_exception_with_numeric_message(self):
        """Test exceptions with numeric content in messages."""
        exc1 = ValueError(12345)
        exc2 = ValueError(67890)
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 801ns -> 701ns (14.3% faster)

    def test_exception_inheritance_not_equivalent(self):
        """Test that exception inheritance hierarchy doesn't make them equivalent."""

        class BaseCustom(Exception):
            pass

        class DerivedCustom(BaseCustom):
            pass

        exc1 = BaseCustom("message")
        exc2 = DerivedCustom("message")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 1.07μs -> 731ns (46.5% faster)

    def test_exception_with_args_tuple(self):
        """Test exceptions constructed with multiple arguments."""
        exc1 = ValueError("arg1", "arg2", "arg3")
        exc2 = ValueError("different1", "different2")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 832ns -> 752ns (10.6% faster)

    def test_exception_with_none_as_message(self):
        """Test exception with None as message."""
        exc1 = ValueError(None)
        exc2 = ValueError("actual message")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 882ns -> 732ns (20.5% faster)

    def test_exception_same_instance_compared(self):
        """Test comparing the same exception instance with itself."""
        exc = ValueError("message")
        codeflash_output = _exceptions_are_equivalent(exc, exc)  # 851ns -> 712ns (19.5% faster)

    def test_exception_from_different_modules_same_name(self):
        """Test handling of different classes with same name but different modules."""

        # Create mock exception classes with same name but different identities
        class ValueError1(BaseException):
            pass

        class ValueError2(BaseException):
            pass

        # Rename them to have the same __name__ but different __module__
        ValueError1.__name__ = "SameName"
        ValueError2.__name__ = "SameName"
        ValueError1.__module__ = "module1"
        ValueError2.__module__ = "module2"

        exc1 = ValueError1("msg")
        exc2 = ValueError2("msg")
        # They should be equivalent because type names match
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 601ns -> 631ns (4.75% slower)

    def test_exception_with_multiline_message(self):
        """Test exceptions with multiline error messages."""
        msg1 = "Line1\nLine2\nLine3\nLine4"
        msg2 = "Different\nMultiline\nMessage\nHere"
        exc1 = RuntimeError(msg1)
        exc2 = RuntimeError(msg2)
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 892ns -> 802ns (11.2% faster)

    def test_exception_message_containing_type_name(self):
        """Test message that contains other exception type names."""
        exc1 = ValueError("This is a ValueError")
        exc2 = ValueError("This is a TypeError or KeyError")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 872ns -> 831ns (4.93% faster)


class TestTorchRuntimeErrorHandling:
    """Test cases for TorchRuntimeError wrapping behavior."""

    def test_torch_runtime_error_detection(self):
        """Test that TorchRuntimeError can be detected by type name."""

        # Create a mock TorchRuntimeError-like exception
        class TorchRuntimeError(BaseException):
            pass

        # Set the module to match torch._dynamo
        TorchRuntimeError.__module__ = "torch._dynamo.exc"
        TorchRuntimeError.__name__ = "TorchRuntimeError"

        exc = TorchRuntimeError("wrapped error")

    def test_torch_runtime_error_false_for_different_module(self):
        """Test that TorchRuntimeError from different module is not detected."""

        class TorchRuntimeError(BaseException):
            pass

        TorchRuntimeError.__module__ = "some_other_module"
        TorchRuntimeError.__name__ = "TorchRuntimeError"

        exc = TorchRuntimeError("wrapped error")

    def test_get_wrapped_exception_with_cause(self):
        """Test extracting wrapped exception from __cause__."""

        # Create a TorchRuntimeError-like class
        class TorchRuntimeError(BaseException):
            pass

        TorchRuntimeError.__module__ = "torch._dynamo.exc"
        TorchRuntimeError.__name__ = "TorchRuntimeError"

        # Create the wrapped exception
        original_exc = ValueError("original error")
        torch_exc = TorchRuntimeError("torch wrapped error")
        torch_exc.__cause__ = original_exc

        wrapped = _get_wrapped_exception_from_torch_runtime_error(torch_exc)

    def test_get_wrapped_exception_returns_none_for_non_torch(self):
        """Test that non-TorchRuntimeError returns None."""
        exc = ValueError("regular error")
        wrapped = _get_wrapped_exception_from_torch_runtime_error(exc)

    def test_get_wrapped_exception_no_cause_returns_none(self):
        """Test TorchRuntimeError without __cause__ returns None."""

        class TorchRuntimeError(BaseException):
            pass

        TorchRuntimeError.__module__ = "torch._dynamo.exc"
        TorchRuntimeError.__name__ = "TorchRuntimeError"

        torch_exc = TorchRuntimeError("no cause")
        wrapped = _get_wrapped_exception_from_torch_runtime_error(torch_exc)

    def test_torch_runtime_error_wrapping_via_cause_equivalent(self):
        """Test equivalence when TorchRuntimeError wraps exception via __cause__."""

        class TorchRuntimeError(BaseException):
            pass

        TorchRuntimeError.__module__ = "torch._dynamo.exc"
        TorchRuntimeError.__name__ = "TorchRuntimeError"

        # Original exception
        exc1 = ValueError("original")

        # TorchRuntimeError wrapping the original
        exc2 = TorchRuntimeError("wrapped")
        exc2.__cause__ = ValueError("wrapped original")

        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 2.23μs -> 2.17μs (2.76% faster)

    def test_torch_runtime_error_wrapping_different_type_not_equivalent(self):
        """Test non-equivalence when TorchRuntimeError wraps different type."""

        class TorchRuntimeError(BaseException):
            pass

        TorchRuntimeError.__module__ = "torch._dynamo.exc"
        TorchRuntimeError.__name__ = "TorchRuntimeError"

        # Original exception
        exc1 = ValueError("original")

        # TorchRuntimeError wrapping a different type
        exc2 = TorchRuntimeError("wrapped")
        exc2.__cause__ = TypeError("wrapped different")

        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 2.92μs -> 2.48μs (18.2% faster)

    def test_torch_runtime_error_message_pattern_matching(self):
        """Test TorchRuntimeError message pattern matching for type name."""

        class TorchRuntimeError(BaseException):
            pass

        TorchRuntimeError.__module__ = "torch._dynamo.exc"
        TorchRuntimeError.__name__ = "TorchRuntimeError"

        # Original exception
        exc1 = ValueError("original")

        # TorchRuntimeError with message containing the wrapped type
        exc2 = TorchRuntimeError("got ValueError('wrapped error')")

        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 2.38μs -> 2.13μs (11.7% faster)

    def test_torch_runtime_error_message_pattern_colon_format(self):
        """Test TorchRuntimeError message pattern with colon format."""

        class TorchRuntimeError(BaseException):
            pass

        TorchRuntimeError.__module__ = "torch._dynamo.exc"
        TorchRuntimeError.__name__ = "TorchRuntimeError"

        # Original exception
        exc1 = RuntimeError("original")

        # TorchRuntimeError with message containing the wrapped type (colon format)
        exc2 = TorchRuntimeError("got RuntimeError: wrapped error")

        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 2.44μs -> 2.29μs (6.96% faster)

    def test_torch_runtime_error_reverse_wrapping(self):
        """Test equivalence when original is TorchRuntimeError wrapping new type."""

        class TorchRuntimeError(BaseException):
            pass

        TorchRuntimeError.__module__ = "torch._dynamo.exc"
        TorchRuntimeError.__name__ = "TorchRuntimeError"

        # Original as TorchRuntimeError
        exc1 = TorchRuntimeError("wrapped")
        exc1.__cause__ = ValueError("wrapped original")

        # New exception
        exc2 = ValueError("new")

        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 2.33μs -> 1.91μs (21.5% faster)

    def test_torch_runtime_error_both_sides(self):
        """Test when both exceptions are TorchRuntimeError with different wrapped types."""

        class TorchRuntimeError(BaseException):
            pass

        TorchRuntimeError.__module__ = "torch._dynamo.exc"
        TorchRuntimeError.__name__ = "TorchRuntimeError"

        # Both as TorchRuntimeError wrapping different types
        exc1 = TorchRuntimeError("wrapped1")
        exc1.__cause__ = ValueError("val error")

        exc2 = TorchRuntimeError("wrapped2")
        exc2.__cause__ = TypeError("type error")

        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 530ns -> 651ns (18.6% slower)


class TestComplexScenarios:
    """Test cases combining multiple features and complex scenarios."""

    def test_exception_chain_with_context(self):
        """Test exceptions that are part of an exception chain."""
        try:
            try:
                raise ValueError("original")
            except ValueError as e:
                raise TypeError("new") from e
        except TypeError as e:
            exc1 = e

        exc2 = TypeError("different message")
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 942ns -> 921ns (2.28% faster)

    def test_multiple_exception_types_in_sequence(self):
        """Test multiple different exception types don't become equivalent."""
        exceptions = [ValueError("msg"), TypeError("msg"), RuntimeError("msg"), KeyError("msg"), IndexError("msg")]

        for i in range(len(exceptions)):
            for j in range(len(exceptions)):
                codeflash_output = _exceptions_are_equivalent(exceptions[i], exceptions[j])
                if i == j:
                    assert codeflash_output
                else:
                    assert not codeflash_output

    def test_exception_with_complex_args(self):
        """Test exceptions with complex argument structures."""
        exc1 = ValueError({"key": "value"}, [1, 2, 3], (4, 5, 6))
        exc2 = ValueError({"different": "dict"}, ["a", "b"], ("x", "y"))
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 871ns -> 872ns (0.115% slower)

    def test_exception_subclass_comparison(self):
        """Test comparison of exception subclasses."""

        class CustomValueError(ValueError):
            pass

        exc1 = ValueError("base")
        exc2 = CustomValueError("subclass")
        # Different type names, so not equivalent
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 1.24μs -> 821ns (51.3% faster)

    def test_exception_type_name_only_comparison(self):
        """Test that comparison is based on type name, not identity."""
        # Create two different exception classes with the same name
        exec_code = """
class DynamicException(Exception):
    pass
exc1 = DynamicException("msg1")
"""
        exec_globals1 = {}
        exec(exec_code, exec_globals1)
        exc1 = exec_globals1["exc1"]

        exec_globals2 = {}
        exec(exec_code, exec_globals2)
        exc2 = exec_globals2["exc1"]

        # Even though they're different classes, they have the same name
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 792ns -> 811ns (2.34% slower)


class TestLargeScale:
    """Test cases for large scale data and performance scenarios."""

    def test_large_number_of_exception_comparisons(self):
        """Test comparing a large number of exception pairs."""
        # Create 500 ValueError instances
        exceptions = [ValueError(f"error_{i}") for i in range(500)]

        # Compare each with first exception
        for exc in exceptions:
            codeflash_output = _exceptions_are_equivalent(exceptions[0], exc)  # 132μs -> 140μs (5.94% slower)

    def test_large_number_of_different_exception_types(self):
        """Test with many different exception types."""
        exception_types = [
            ValueError,
            TypeError,
            RuntimeError,
            KeyError,
            IndexError,
            AttributeError,
            ZeroDivisionError,
            IOError,
            OSError,
            ImportError,
            NotImplementedError,
            StopIteration,
            ArithmeticError,
            BufferError,
            LookupError,
            MemoryError,
        ]

        # Each type should be equivalent only to exceptions sharing its type name
        # (note: IOError is an alias of OSError, so that pair shares the name "OSError")
        for i, exc_type1 in enumerate(exception_types):
            exc1 = exc_type1("message")
            for j, exc_type2 in enumerate(exception_types):
                exc2 = exc_type2("message")
                codeflash_output = _exceptions_are_equivalent(exc1, exc2)
                if type(exc1).__name__ == type(exc2).__name__:
                    assert codeflash_output
                else:
                    assert not codeflash_output

    def test_very_large_exception_messages(self):
        """Test with exceptions having extremely large messages."""
        # Create a 1MB string
        large_msg1 = "x" * (1024 * 1024)
        large_msg2 = "y" * (1024 * 1024)

        exc1 = ValueError(large_msg1)
        exc2 = ValueError(large_msg2)

        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 1.06μs -> 922ns (15.2% faster)

    def test_exception_comparison_performance(self):
        """Test performance of exception comparison with various inputs."""
        import time

        # Create many exception pairs
        pairs = []
        for i in range(1000):
            if i % 2 == 0:
                pairs.append((ValueError(f"msg_{i}"), ValueError(f"msg_{i + 1}")))
            else:
                pairs.append((ValueError(f"msg_{i}"), TypeError(f"msg_{i + 1}")))

        # Perform comparisons and measure time
        start_time = time.time()
        for exc1, exc2 in pairs:
            _exceptions_are_equivalent(exc1, exc2)  # 353μs -> 289μs (22.3% faster)
        end_time = time.time()

    def test_many_torch_runtime_errors(self):
        """Test handling of many TorchRuntimeError instances."""

        class TorchRuntimeError(BaseException):
            pass

        TorchRuntimeError.__module__ = "torch._dynamo.exc"
        TorchRuntimeError.__name__ = "TorchRuntimeError"

        # Create 500 TorchRuntimeError instances wrapping different base exceptions
        torch_errors = []
        base_errors = []
        for i in range(500):
            base_exc = ValueError(f"base_{i}")
            torch_exc = TorchRuntimeError(f"wrapped_{i}")
            torch_exc.__cause__ = base_exc
            torch_errors.append(torch_exc)
            base_errors.append(base_exc)

        # All should be equivalent to their corresponding base exceptions
        for i in range(500):
            codeflash_output = _exceptions_are_equivalent(
                base_errors[i], torch_errors[i]
            )  # 351μs -> 299μs (17.4% faster)

    def test_exception_types_with_long_names(self):
        """Test exception type comparison with very long type names."""
        # Create exception classes with long names
        long_name1 = "VeryLongExceptionNameWith" + "x" * 500
        long_name2 = "VeryLongExceptionNameWith" + "y" * 500

        ExceptionClass1 = type(long_name1, (Exception,), {})
        ExceptionClass2 = type(long_name2, (Exception,), {})

        exc1 = ExceptionClass1("message")
        exc2 = ExceptionClass2("message")

        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 1.09μs -> 791ns (38.1% faster)

    def test_exception_comparison_with_circular_references(self):
        """Test exception comparison with objects containing circular references."""
        # Create a circular reference in exception args
        circular_list = [1, 2, 3]
        circular_list.append(circular_list)

        exc1 = ValueError(circular_list)
        exc2 = ValueError([4, 5, 6])

        # Should still compare by type name successfully
        codeflash_output = _exceptions_are_equivalent(exc1, exc2)  # 861ns -> 892ns (3.48% slower)

    def test_batched_torch_error_comparison(self):
        """Test batch processing of many torch runtime error comparisons."""

        class TorchRuntimeError(BaseException):
            pass

        TorchRuntimeError.__module__ = "torch._dynamo.exc"
        TorchRuntimeError.__name__ = "TorchRuntimeError"

        # Create 100 base exceptions and 100 torch wrapped versions
        base_exceptions = [ValueError(f"base_{i}") for i in range(100)]
        torch_exceptions = []

        for i in range(100):
            torch_exc = TorchRuntimeError(f"wrapped_{i}")
            torch_exc.__cause__ = ValueError(f"wrapped_base_{i}")
            torch_exceptions.append(torch_exc)

        # Compare each base with each torch exception of same index
        results = []
        for i in range(100):
            codeflash_output = _exceptions_are_equivalent(base_exceptions[i], torch_exceptions[i])
            result = codeflash_output  # 72.4μs -> 62.3μs (16.2% faster)
            results.append(result)

    def test_stress_test_many_exception_instantiations(self):
        """Stress test creating and comparing many exception instances."""
        # Create 1000 exception instances
        exceptions = []
        for i in range(1000):
            exc_type = ValueError if i % 2 == 0 else TypeError
            exceptions.append(exc_type(f"message_{i}"))

        # Compare each exception with a few others
        equivalences = 0
        non_equivalences = 0

        for i in range(0, 1000, 10):
            for j in range(i, min(i + 5, 1000)):
                if _exceptions_are_equivalent(exceptions[i], exceptions[j]):
                    equivalences += 1
                else:
                    non_equivalences += 1


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
from codeflash.verification.comparator import _exceptions_are_equivalent


def test__exceptions_are_equivalent():
    _exceptions_are_equivalent(BaseException(), BaseException())
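
To make the `codeflash_output` convention above concrete, here is a hypothetical sketch of the idea; codeflash's actual instrumentation captures and compares these values itself, so nothing below is taken from the library:

```python
# Hypothetical illustration of the codeflash_output convention; not codeflash's
# real instrumentation. The same call is run against the original and the
# optimized code, and the captured outputs must match.
from codeflash.verification.comparator import _exceptions_are_equivalent


def run_case() -> bool:
    # A representative call whose result codeflash would capture as codeflash_output.
    return _exceptions_are_equivalent(ValueError("a"), ValueError("b"))


output_on_original_code = run_case()
output_on_optimized_code = run_case()

# The optimization is only accepted if the captured outputs are equivalent.
assert output_on_original_code == output_on_optimized_code
```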
🔎 Concolic coverage tests

| Test File::Test Function | Original ⏱️ | Optimized ⏱️ | Speedup |
|---|---|---|---|
| `codeflash_concolic_g5uubgk6/tmpkbpq1ztw/test_concolic_coverage.py::test__exceptions_are_equivalent` | 952ns | 832ns | 14.4% ✅ |

To test or edit this optimization locally, run `git merge codeflash/optimize-pr1119-2026-01-20T05.30.28`.

Suggested change:

```diff
-# If types match, they're potentially equivalent (will be compared further)
-if type(orig).__name__ == type(new).__name__:
-    return True
-# Check if one is a TorchRuntimeError wrapping the other's type
-if _is_torch_runtime_error(new):
-    wrapped = _get_wrapped_exception_from_torch_runtime_error(new)
-    if wrapped is not None and type(wrapped).__name__ == type(orig).__name__:
-        return True
-    # Also check the error message for the wrapped exception type
-    # TorchRuntimeError message contains "got ExceptionType('...')"
-    error_msg = str(new)
-    orig_type_name = type(orig).__name__
-    if f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg:
-        return True
-if _is_torch_runtime_error(orig):
-    wrapped = _get_wrapped_exception_from_torch_runtime_error(orig)
-    if wrapped is not None and type(wrapped).__name__ == type(new).__name__:
-        return True
-    error_msg = str(orig)
-    new_type_name = type(new).__name__
+# Cache type lookups to avoid redundant attribute access
+orig_type = type(orig)
+new_type = type(new)
+orig_type_name = orig_type.__name__
+new_type_name = new_type.__name__
+# If types match, they're potentially equivalent (will be compared further)
+if orig_type_name == new_type_name:
+    return True
+# Check if new is a TorchRuntimeError wrapping the other's type
+if new_type_name == "TorchRuntimeError" and "torch._dynamo" in new_type.__module__:
+    wrapped = _get_wrapped_exception_from_torch_runtime_error(new)
+    if wrapped is not None and type(wrapped).__name__ == orig_type_name:
+        return True
+    # Also check the error message for the wrapped exception type
+    # TorchRuntimeError message contains "got ExceptionType('...')"
+    error_msg = str(new)
+    if f"got {orig_type_name}(" in error_msg or f"got {orig_type_name}:" in error_msg:
+        return True
+# Check if orig is a TorchRuntimeError wrapping the other's type
+if orig_type_name == "TorchRuntimeError" and "torch._dynamo" in orig_type.__module__:
+    wrapped = _get_wrapped_exception_from_torch_runtime_error(orig)
+    if wrapped is not None and type(wrapped).__name__ == new_type_name:
+        return True
+    error_msg = str(orig)
```
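
To make the type-and-module check in the suggestion above concrete, here is a minimal, self-contained sketch. It is not the codeflash implementation: the helper bodies are inferred from the tests and the diff (recognize `TorchRuntimeError` by class name and module so torch never needs to be imported, then unwrap `__cause__`), and the stub class mirrors the one used throughout the test file.

```python
# Illustration only: approximates the module/name check from the suggested change
# and the __cause__ unwrapping the tests exercise. Not the actual codeflash helpers.
from typing import Optional


def looks_like_torch_runtime_error(exc: BaseException) -> bool:
    # Recognize torch.compile's TorchRuntimeError by name and module, so torch
    # itself never has to be imported (the tests stub it the same way).
    exc_type = type(exc)
    return exc_type.__name__ == "TorchRuntimeError" and "torch._dynamo" in exc_type.__module__


def unwrap_torch_runtime_error(exc: BaseException) -> Optional[BaseException]:
    # The wrapped exception, when present, is carried on __cause__.
    return exc.__cause__


# Stub mirroring the one used throughout the tests above.
class TorchRuntimeError(BaseException):
    pass


TorchRuntimeError.__module__ = "torch._dynamo.exc"

wrapped = TorchRuntimeError("wrapped")
wrapped.__cause__ = ValueError("boom")
assert looks_like_torch_runtime_error(wrapped)
assert type(unwrap_torch_runtime_error(wrapped)).__name__ == "ValueError"

# With no __cause__ set, the "got <TypeName>(" message pattern is the fallback
# signal used by the comparison code in the diff above.
bare = TorchRuntimeError("got ValueError('boom')")
assert unwrap_torch_runtime_error(bare) is None
assert "got ValueError(" in str(bare)
```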


@aseembits93
Contributor Author

Closing this in favor of a more general solution in a different PR.
