Skip to content

⚡️ Speed up method InitDecorator.visit_ClassDef by 10% in PR #1860 (fix/attrs-init-instrumentation)#1867

Open
codeflash-ai[bot] wants to merge 1 commit intofix/attrs-init-instrumentationfrom
codeflash/optimize-pr1860-2026-03-18T10.39.52
Open

⚡️ Speed up method InitDecorator.visit_ClassDef by 10% in PR #1860 (fix/attrs-init-instrumentation)#1867
codeflash-ai[bot] wants to merge 1 commit intofix/attrs-init-instrumentationfrom
codeflash/optimize-pr1860-2026-03-18T10.39.52

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Mar 18, 2026

⚡️ This pull request contains optimizations for PR #1860

If you approve this dependent PR, these changes will be merged into the original PR branch fix/attrs-init-instrumentation.

This PR will be automatically closed if the original PR is merged.


📄 10% (0.10x) speedup for InitDecorator.visit_ClassDef in codeflash/languages/python/instrument_codeflash_capture.py

⏱️ Runtime : 405 microseconds 367 microseconds (best of 250 runs)

📝 Explanation and details

The optimization eliminates redundant iterations through node.body by adding a break statement immediately after finding and decorating the __init__ method (when has_init=True). The profiler shows the outer body loop dropped from 392 hits to 376 hits (~4% fewer), and the inner decorator-list loop dropped from 18 hits to 18 hits but now exits cleanly via break instead of continuing to scan remaining body items. Additionally, the if not has_init: branch now consolidates dataclass/attrs/NamedTuple checks in a single decorator loop instead of three separate passes, reducing _expr_name calls from 471 total hits to 263 (~44% fewer) and cutting that function's time from 391 µs to 218 µs. Runtime improved from 405 µs to 367 µs (10% faster) with no correctness regressions across all test cases.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 137 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
import ast
from pathlib import Path

# imports
from codeflash.languages.python.instrument_codeflash_capture import InitDecorator


def test_visit_classdef_non_target_class_unchanged():
    """Test that classes not in target_classes are returned unchanged."""
    # Create a simple class AST node
    class_node = ast.parse("class MyClass:\n    pass").body[0]

    # Create InitDecorator with empty target_classes
    decorator = InitDecorator(target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"))

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 601ns -> 631ns (4.75% slower)

    # The node should be returned unchanged (same reference)
    assert result is class_node
    assert decorator.inserted_decorator is False


def test_visit_classdef_target_class_with_existing_init():
    """Test that decorator is added to existing __init__ in target class."""
    # Create a class with __init__
    code = """
class MyClass:
    def __init__(self):
        pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 3.48μs -> 3.48μs (0.000% faster)

    # Find the __init__ method
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Should have decorator added
    assert len(init_method.decorator_list) > 0
    assert decorator.inserted_decorator is True


def test_visit_classdef_target_class_creates_init():
    """Test that __init__ is created for target class without one."""
    # Create a class without __init__
    code = """
class MyClass:
    pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 4.36μs -> 4.30μs (1.42% faster)

    # Find the __init__ method (should be created)
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Verify it's the first item
    assert result.body[0] is init_method
    # Should have decorator
    assert len(init_method.decorator_list) > 0
    assert decorator.inserted_decorator is True


def test_visit_classdef_no_duplicate_decorator_on_existing_init():
    """Test that codeflash_capture decorator is not added if already present."""
    # Create a class with __init__ that already has codeflash_capture decorator
    code = """
class MyClass:
    @codeflash_capture(function_name="MyClass.__init__", tmp_dir_path="/tmp")
    def __init__(self):
        pass
"""
    class_node = ast.parse(code).body[0]
    initial_decorator_count = len(class_node.body[0].decorator_list)

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 4.01μs -> 3.89μs (3.09% faster)

    # Find the __init__ method
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Decorator count should not increase
    assert len(init_method.decorator_list) == initial_decorator_count
    # inserted_decorator should still be False since we didn't add one
    assert decorator.inserted_decorator is False


def test_visit_classdef_decorator_contains_correct_function_name():
    """Test that decorator contains the correct function_name keyword."""
    # Create a class without __init__
    code = """
class TargetClass:
    pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with TargetClass as target
    decorator = InitDecorator(
        target_classes={"TargetClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 4.39μs -> 4.25μs (3.30% faster)

    # Find the __init__ method
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Get the decorator (should be a Call node)
    decorator_call = init_method.decorator_list[0]
    assert isinstance(decorator_call, ast.Call)

    # Find function_name keyword argument
    function_name_kw = next((kw for kw in decorator_call.keywords if kw.arg == "function_name"), None)
    assert function_name_kw is not None
    assert isinstance(function_name_kw.value, ast.Constant)
    assert function_name_kw.value.value == "TargetClass.__init__"


def test_visit_classdef_created_init_has_correct_arguments():
    """Test that created __init__ has correct arguments (self, *args, **kwargs)."""
    # Create a class without __init__
    code = """
class MyClass:
    pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 4.31μs -> 4.34μs (0.692% slower)

    # Find the __init__ method
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Check arguments
    assert len(init_method.args.args) == 1
    assert init_method.args.args[0].arg == "self"
    assert init_method.args.vararg.arg == "args"
    assert init_method.args.kwarg.arg == "kwargs"


def test_visit_classdef_created_init_calls_super():
    """Test that created __init__ contains a super().__init__(*args, **kwargs) call."""
    # Create a class without __init__
    code = """
class MyClass:
    pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 4.14μs -> 4.25μs (2.59% slower)

    # Find the __init__ method
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Check body contains Expr with Call
    assert len(init_method.body) > 0
    expr_stmt = init_method.body[0]
    assert isinstance(expr_stmt, ast.Expr)
    assert isinstance(expr_stmt.value, ast.Call)


def test_visit_classdef_multiple_target_classes():
    """Test that only target classes are modified."""
    # Create two classes
    code = """
class ClassA:
    pass

class ClassB:
    pass
"""
    module = ast.parse(code)
    class_a = module.body[0]
    class_b = module.body[1]

    # Create InitDecorator with only ClassA as target
    decorator = InitDecorator(
        target_classes={"ClassA"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit both classes
    result_a = decorator.visit_ClassDef(class_a)  # 4.19μs -> 4.19μs (0.000% faster)
    result_b = decorator.visit_ClassDef(class_b)

    # ClassA should be modified
    init_a = next(
        item for item in result_a.body if isinstance(item, ast.FunctionDef) and item.name == "__init__"
    )  # 370ns -> 360ns (2.78% faster)
    assert len(init_a.decorator_list) > 0

    # ClassB should be unchanged (no __init__ added)
    init_b = next(
        (item for item in result_b.body if isinstance(item, ast.FunctionDef) and item.name == "__init__"), None
    )
    assert init_b is None


def test_visit_classdef_with_dataclass_decorator_skipped():
    """Test that dataclass classes are skipped (no __init__ created)."""
    # Create a dataclass
    code = """
@dataclass
class MyClass:
    field: int
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 3.94μs -> 3.99μs (1.23% slower)

    # Should not create __init__
    init_method = next(
        (item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__"), None
    )
    assert init_method is None
    assert decorator.inserted_decorator is False


def test_visit_classdef_with_attrs_decorator_stored_for_patching():
    """Test that attrs classes are stored for monkey-patching instead of modified."""
    # Create an attrs class
    code = """
@attrs.define
class MyClass:
    field: int = 0
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 6.57μs -> 5.79μs (13.5% faster)

    # Should store in _attrs_classes_to_patch
    assert "MyClass" in decorator._attrs_classes_to_patch
    # Should not create __init__ in the class
    init_methods = [item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__"]
    assert len(init_methods) == 0
    assert decorator.inserted_decorator is True


def test_visit_classdef_with_attrs_init_false_not_patched():
    """Test that attrs classes with init=False are not patched."""
    # Create an attrs class with init=False
    code = """
@attrs.define(init=False)
class MyClass:
    field: int = 0
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 7.35μs -> 6.29μs (16.9% faster)

    # Should not be stored for patching
    assert "MyClass" not in decorator._attrs_classes_to_patch
    assert decorator.inserted_decorator is False


def test_visit_classdef_with_namedtuple_base_skipped():
    """Test that NamedTuple subclasses are skipped."""
    # Create a NamedTuple subclass
    code = """
class MyClass(NamedTuple):
    field: int
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 3.95μs -> 3.98μs (0.754% slower)

    # Should not create __init__
    init_method = next(
        (item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__"), None
    )
    assert init_method is None
    assert decorator.inserted_decorator is False


def test_visit_classdef_with_init_with_self_parameter():
    """Test that __init__ with self parameter is modified."""
    code = """
class MyClass:
    def __init__(self):
        pass
"""
    class_node = ast.parse(code).body[0]

    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    result = decorator.visit_ClassDef(class_node)  # 3.55μs -> 3.52μs (0.853% faster)

    init_methods = [item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__"]
    assert len(init_methods) == 1
    assert len(init_methods[0].decorator_list) > 0
    assert decorator.inserted_decorator is True


def test_visit_classdef_empty_target_classes():
    """Test with empty target_classes set."""
    code = """
class MyClass:
    pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with empty target_classes
    decorator = InitDecorator(target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"))

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 641ns -> 651ns (1.54% slower)

    # Should be unchanged
    assert result is class_node
    assert decorator.inserted_decorator is False


def test_visit_classdef_with_empty_body():
    """Test class with empty body."""
    # Manually create a class with empty body
    class_node = ast.ClassDef(name="MyClass", bases=[], keywords=[], body=[], decorator_list=[])

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 4.24μs -> 4.13μs (2.66% faster)

    # Should create __init__
    assert len(result.body) > 0
    init_method = result.body[0]
    assert isinstance(init_method, ast.FunctionDef)
    assert init_method.name == "__init__"
    assert decorator.inserted_decorator is True


def test_visit_classdef_with_complex_class_name():
    """Test with class names containing underscores and numbers."""
    code = """
class _PrivateClass123:
    pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with complex class name as target
    decorator = InitDecorator(
        target_classes={"_PrivateClass123"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 4.31μs -> 4.27μs (0.937% faster)

    # Should modify the class
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")
    assert len(init_method.decorator_list) > 0
    # Verify decorator function_name has correct class name
    decorator_call = init_method.decorator_list[0]
    function_name_kw = next((kw for kw in decorator_call.keywords if kw.arg == "function_name"), None)
    assert function_name_kw.value.value == "_PrivateClass123.__init__"


def test_visit_classdef_with_multiple_existing_methods():
    """Test class with multiple existing methods."""
    code = """
class MyClass:
    def method1(self):
        pass
    
    def __init__(self):
        pass
    
    def method2(self):
        pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 3.79μs -> 3.64μs (4.12% faster)

    # Find __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Should have decorator
    assert len(init_method.decorator_list) > 0
    # Other methods should remain unchanged
    method_count = len([item for item in result.body if isinstance(item, ast.FunctionDef)])
    assert method_count == 3


def test_visit_classdef_with_class_variables():
    """Test class with class variables."""
    code = """
class MyClass:
    class_var = 10
    
    def __init__(self):
        pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 3.58μs -> 3.63μs (1.38% slower)

    # Find __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Should have decorator
    assert len(init_method.decorator_list) > 0
    # Class variable should still exist
    class_var_assign = next(item for item in result.body if isinstance(item, ast.Assign))
    assert class_var_assign is not None


def test_visit_classdef_with_empty_init():
    """Test class with existing empty __init__."""
    code = """
class MyClass:
    def __init__(self):
        ...
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 3.68μs -> 3.63μs (1.41% faster)

    # Find __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Should have decorator
    assert len(init_method.decorator_list) > 0
    # Original body should be preserved
    assert len(init_method.body) > 0


def test_visit_classdef_with_init_no_args():
    """Test __init__ with no arguments (edge case, invalid)."""
    # Create a class with __init__ that has no arguments
    code = """
class MyClass:
    def __init__():
        pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 4.58μs -> 4.40μs (4.09% faster)

    # Should not modify existing __init__ (no self parameter), and should create new one
    init_methods = [item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__"]
    # Should have created a new one
    assert len(init_methods) >= 1
    assert decorator.inserted_decorator is True


def test_visit_classdef_with_typing_namedtuple():
    """Test typing.NamedTuple base class."""
    code = """
from typing import NamedTuple

class MyClass(NamedTuple):
    field: int
"""
    # Parse just the class definition
    class_code = """
class MyClass(NamedTuple):
    field: int
"""
    class_node = ast.parse(class_code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 4.09μs -> 4.16μs (1.71% slower)

    # Should not create __init__ for NamedTuple
    init_method = next(
        (item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__"), None
    )
    assert init_method is None
    assert decorator.inserted_decorator is False


def test_visit_classdef_with_multiple_decorators_on_init():
    """Test __init__ with multiple existing decorators."""
    code = """
class MyClass:
    @classmethod
    @functools.lru_cache
    def __init__(self):
        pass
"""
    class_node = ast.parse(code).body[0]

    initial_decorator_count = len(class_node.body[0].decorator_list)

    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    result = decorator.visit_ClassDef(class_node)  # 3.89μs -> 3.96μs (1.77% slower)

    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")
    assert len(init_method.decorator_list) > initial_decorator_count
    assert decorator.inserted_decorator is True


def test_visit_classdef_with_deeply_nested_attrs():
    """Test attrs classes with namespace like codeflash.attrs."""
    code = """
@codeflash.attrs.define
class MyClass:
    field: int = 0
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 7.34μs -> 6.34μs (15.8% faster)

    # Should recognize attrs and store for patching
    assert "MyClass" in decorator._attrs_classes_to_patch
    assert decorator.inserted_decorator is True


def test_visit_classdef_preserves_node_metadata():
    """Test that AST node metadata is preserved."""
    code = """
class MyClass:
    pass
"""
    class_node = ast.parse(code).body[0]
    original_name = class_node.name
    original_bases = class_node.bases

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 4.38μs -> 4.42μs (0.928% slower)

    # Name and bases should be preserved
    assert result.name == original_name
    assert result.bases == original_bases


def test_visit_classdef_with_bases():
    """Test class with base classes."""
    code = """
class MyClass(BaseClass, Mixin):
    pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 5.67μs -> 5.43μs (4.44% faster)

    # Should create __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")
    assert len(init_method.decorator_list) > 0
    # Bases should be preserved
    assert len(result.bases) == 2
    assert decorator.inserted_decorator is True


def test_visit_classdef_with_instance_variables_in_init():
    """Test class with instance variables assigned in __init__."""
    code = """
class MyClass:
    def __init__(self):
        self.x = 10
        self.y = 20
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 3.57μs -> 3.59μs (0.585% slower)

    # Find __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Should have decorator
    assert len(init_method.decorator_list) > 0
    # Original body should be preserved
    assert len(init_method.body) == 2


def test_visit_classdef_with_annotations():
    """Test class with type annotations."""
    code = """
class MyClass:
    x: int
    y: str
    
    def __init__(self):
        pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 3.52μs -> 3.61μs (2.52% slower)

    # Find __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Should have decorator
    assert len(init_method.decorator_list) > 0
    # Annotations should be preserved
    annassign_count = len([item for item in result.body if isinstance(item, ast.AnnAssign)])
    assert annassign_count == 2


def test_visit_classdef_decorator_order():
    """Test that decorator is inserted at the beginning of decorator_list."""
    code = """
class MyClass:
    @some_decorator
    def __init__(self):
        pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 3.57μs -> 3.55μs (0.536% faster)

    # Find __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # codeflash_capture should be first decorator
    first_decorator = init_method.decorator_list[0]
    assert isinstance(first_decorator, ast.Call)
    assert isinstance(first_decorator.func, ast.Name)
    assert first_decorator.func.id == "codeflash_capture"


def test_visit_classdef_with_init_parameters():
    """Test class with __init__ that has parameters."""
    code = """
class MyClass:
    def __init__(self, x, y=10):
        self.x = x
        self.y = y
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 3.47μs -> 3.49μs (0.545% slower)

    # Find __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Should have decorator
    assert len(init_method.decorator_list) > 0
    # Parameters should be preserved
    assert len(init_method.args.args) == 3  # self, x, y
    assert init_method.args.args[0].arg == "self"
    assert init_method.args.args[1].arg == "x"
    assert init_method.args.args[2].arg == "y"


def test_visit_classdef_many_target_classes():
    """Test with many target classes in the set."""
    # Create InitDecorator with 1000 target classes
    target_classes = {f"Class{i}" for i in range(1000)}

    # Create a class that is in the target set
    code = """
class Class500:
    pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator
    decorator = InitDecorator(
        target_classes=target_classes, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 4.54μs -> 4.51μs (0.665% faster)

    # Should create __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")
    assert len(init_method.decorator_list) > 0
    assert decorator.inserted_decorator is True


def test_visit_classdef_many_methods_in_class():
    """Test class with many methods."""
    # Create a class with 100 methods
    methods = "\n    ".join([f"def method{i}(self):\n        pass" for i in range(100)])
    code = f"""
class MyClass:
    {methods}
    
    def __init__(self):
        pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 12.2μs -> 12.2μs (0.164% faster)

    # Find __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Should have decorator
    assert len(init_method.decorator_list) > 0
    # All methods should still be present
    method_count = len([item for item in result.body if isinstance(item, ast.FunctionDef)])
    assert method_count == 101  # 100 methods + __init__


def test_visit_classdef_many_decorators_in_class():
    """Test class with many decorators."""
    # Create a class with 100 decorators
    decorators = "\n".join([f"@decorator{i}" for i in range(100)])
    code = f"""
{decorators}
class MyClass:
    pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 41.4μs -> 31.5μs (31.3% faster)

    # Should create __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")
    assert len(init_method.decorator_list) > 0
    # Class decorators should be preserved
    assert len(result.decorator_list) == 100


def test_visit_classdef_many_class_variables():
    """Test class with many class variables."""
    # Create a class with 100 class variables
    variables = "\n    ".join([f"var{i} = {i}" for i in range(100)])
    code = f"""
class MyClass:
    {variables}
    
    def __init__(self):
        pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 8.60μs -> 8.68μs (0.922% slower)

    # Find __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Should have decorator
    assert len(init_method.decorator_list) > 0
    # All variables should still be present
    assign_count = len([item for item in result.body if isinstance(item, ast.Assign)])
    assert assign_count == 100


def test_visit_classdef_large_init_body():
    """Test __init__ with large body."""
    # Create __init__ with 100 statements
    statements = "\n        ".join([f"self.x{i} = {i}" for i in range(100)])
    code = f"""
class MyClass:
    def __init__(self):
        {statements}
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 3.65μs -> 3.65μs (0.027% slower)

    # Find __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Should have decorator
    assert len(init_method.decorator_list) > 0
    # All statements should still be present
    assert len(init_method.body) == 100


def test_visit_classdef_large_init_with_many_parameters():
    """Test __init__ with many parameters."""
    # Create __init__ with 100 parameters
    params = ", ".join([f"arg{i}=None" for i in range(100)])
    code = f"""
class MyClass:
    def __init__(self, {params}):
        pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 3.51μs -> 3.58μs (1.96% slower)

    # Find __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")

    # Should have decorator
    assert len(init_method.decorator_list) > 0
    # All parameters should be preserved
    assert len(init_method.args.args) == 101  # self + 100 args


def test_visit_classdef_many_base_classes():
    """Test class with many base classes."""
    # Create a class with 50 base classes
    bases = ", ".join([f"Base{i}" for i in range(50)])
    code = f"""
class MyClass({bases}):
    pass
"""
    class_node = ast.parse(code).body[0]

    # Create InitDecorator with MyClass as target
    decorator = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the class node
    result = decorator.visit_ClassDef(class_node)  # 14.9μs -> 14.7μs (1.02% faster)

    # Should create __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")
    assert len(init_method.decorator_list) > 0
    # All base classes should be preserved
    assert len(result.bases) == 50


def test_visit_classdef_complex_nested_class():
    """Test modification of nested class."""
    # Create a module with a nested class structure
    code = """
class OuterClass:
    class InnerClass:
        pass
"""
    module = ast.parse(code)
    outer_class = module.body[0]
    inner_class = outer_class.body[0]

    # Create InitDecorator with InnerClass as target
    decorator = InitDecorator(
        target_classes={"InnerClass"}, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    # Visit the inner class node
    result = decorator.visit_ClassDef(inner_class)  # 4.24μs -> 4.28μs (0.935% slower)

    # Should create __init__
    init_method = next(item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__")
    assert len(init_method.decorator_list) > 0
    assert decorator.inserted_decorator is True


def test_visit_classdef_multiple_sequential_visits():
    """Test multiple sequential visits with different target classes."""
    decorator = InitDecorator(
        target_classes={"TargetA", "TargetB", "TargetC"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
    )

    class_names = ["TargetA", "TargetB", "TargetC", "TargetA", "TargetB"]
    for class_name in class_names:
        code = f"""
class {class_name}:
    pass
"""
        class_node = ast.parse(code).body[0]
        result = decorator.visit_ClassDef(class_node)  # 15.0μs -> 15.0μs (0.067% slower)

        init_method = next(
            (item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__"), None
        )
        assert init_method is not None
        assert len(init_method.decorator_list) > 0


def test_visit_classdef_stored_attrs_patches():
    """Test that many attrs classes are correctly stored for patching."""
    decorator = InitDecorator(
        target_classes={f"Class{i}" for i in range(100)},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
    )

    # Visit 50 attrs classes
    for i in range(50):
        code = f"""
@attrs.define
class Class{i}:
    field: int = 0
"""
        class_node = ast.parse(code).body[0]
        result = decorator.visit_ClassDef(class_node)  # 164μs -> 139μs (18.1% faster)

        # Should be stored for patching
        assert f"Class{i}" in decorator._attrs_classes_to_patch

    # Verify all 50 were stored
    assert len(decorator._attrs_classes_to_patch) == 50


def test_visit_classdef_performance_with_large_target_set():
    """Test performance with large target class set and correct class matching."""
    target_classes = {f"Class{i}" for i in range(1000)}

    decorator = InitDecorator(
        target_classes=target_classes, fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests")
    )

    class_names = ["Class50", "Class500", "Class750", "Class999"]
    for class_name in class_names:
        code = f"""
class {class_name}:
    pass
"""
        class_node = ast.parse(code).body[0]
        result = decorator.visit_ClassDef(class_node)  # 12.7μs -> 12.9μs (1.87% slower)

        init_method = next(
            item for item in result.body if isinstance(item, ast.FunctionDef) and item.name == "__init__"
        )
        assert len(init_method.decorator_list) > 0
        function_name_kw = next(
            (kw for kw in init_method.decorator_list[0].keywords if kw.arg == "function_name"), None
        )
        assert function_name_kw.value.value == f"{class_name}.__init__"

To edit these changes git checkout codeflash/optimize-pr1860-2026-03-18T10.39.52 and push.

Codeflash Static Badge

The optimization eliminates redundant iterations through `node.body` by adding a `break` statement immediately after finding and decorating the `__init__` method (when `has_init=True`). The profiler shows the outer body loop dropped from 392 hits to 376 hits (~4% fewer), and the inner decorator-list loop dropped from 18 hits to 18 hits but now exits cleanly via `break` instead of continuing to scan remaining body items. Additionally, the `if not has_init:` branch now consolidates dataclass/attrs/NamedTuple checks in a single decorator loop instead of three separate passes, reducing `_expr_name` calls from 471 total hits to 263 (~44% fewer) and cutting that function's time from 391 µs to 218 µs. Runtime improved from 405 µs to 367 µs (10% faster) with no correctness regressions across all test cases.
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Mar 18, 2026
@claude
Copy link
Contributor

claude bot commented Mar 18, 2026

Claude finished @codeflash-ai[bot]'s task in 45s —— View job


PR Review Summary

Todo

  • Triage PR scope
  • Run lint/typecheck (prek)
  • Resolve stale review threads
  • Review code changes
  • Duplicate detection
  • Test coverage analysis
  • Check optimization PRs
  • Post summary

Prek Checks

uv run prek run --from-ref origin/fix/attrs-init-instrumentationpassed (ruff check, ruff format both clean). No auto-fixes needed.

Code Review

The optimization is sound. Two changes are made:

1. break after locating __init__ (line 199)

for item in node.body:
    if isinstance(item, ast.FunctionDef) and item.name == "__init__" ...:
        ...
        break   # ← new

Correct. Python classes can only have one valid __init__ (with a self first arg), so stopping the scan once found is safe.

2. Consolidated decorator loop (dataclass + attrs combined)

Original had two separate passes over node.decorator_list; the optimization merges them into one. The early-return logic is preserved correctly.

Minor issue — duplicate stale comment at line 222:

# Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)  ← misplaced

# Skip NamedTuples ...
for base in node.bases:
    ...

# Create super().__init__(*args, **kwargs) call (use prebuilt AST fragments)  ← correct location
super_call = self._super_call_expr

The first copy of this comment (line 222) is a leftover from the refactor — it appears before the NamedTuple check but has nothing to do with it. The second copy (line 230) is at the correct location. This is cosmetic only and doesn't affect correctness. Fix this →

Ordering change: NamedTuple base-class check is now performed after the decorator loop rather than between the two former decorator loops. Semantically equivalent for all realistic inputs.

No bugs or correctness regressions found.

Duplicate Detection

No duplicates detected.

Test Coverage

instrument_codeflash_capture.py: 99% coverage (129 stmts, 1 miss) from the existing test suite. The 8 test failures in the suite are pre-existing and unrelated to this change (tracer tests). Codeflash also reports 100% coverage via 137 generated regression tests.

Optimization PRs


Last updated: 2026-03-18

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants