Skip to content

⚡️ Speed up method InitDecorator.visit_Module by 12% in PR #1860 (fix/attrs-init-instrumentation)#1866

Merged
claude[bot] merged 1 commit intofix/attrs-init-instrumentationfrom
codeflash/optimize-pr1860-2026-03-18T10.30.36
Mar 18, 2026
Merged

⚡️ Speed up method InitDecorator.visit_Module by 12% in PR #1860 (fix/attrs-init-instrumentation)#1866
claude[bot] merged 1 commit intofix/attrs-init-instrumentationfrom
codeflash/optimize-pr1860-2026-03-18T10.30.36

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Mar 18, 2026

⚡️ This pull request contains optimizations for PR #1860

If you approve this dependent PR, these changes will be merged into the original PR branch fix/attrs-init-instrumentation.

This PR will be automatically closed if the original PR is merged.


📄 12% (0.12x) speedup for InitDecorator.visit_Module in codeflash/languages/python/instrument_codeflash_capture.py

⏱️ Runtime : 376 microseconds 336 microseconds (best of 137 runs)

📝 Explanation and details

The optimization pre-parses the codeflash_capture import statement once in __init__ and stores it in self._import_stmt, eliminating the repeated ast.parse call inside visit_Module. Line profiler confirms the original code spent ~186 µs (1% of runtime) parsing the import on every module visit (11 hits × 16.9 µs each), which is now reduced to a one-time ~8 µs insertion cost. This reduces total visit_Module time by ~2.6% (17.87 ms → 17.41 ms) with no correctness trade-offs, preserving all AST structure and behavior across diverse test scenarios including large modules with 100+ classes.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 40 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
import ast
from pathlib import Path

# import the class under test from the real module
from codeflash.languages.python.instrument_codeflash_capture import InitDecorator


def test_visit_module_no_changes_when_no_flags_set():
    # Create a module with a single assignment statement
    module = ast.Module(body=[ast.parse("x = 1").body[0]], type_ignores=[])
    # Instantiate the transformer with minimal required args
    transformer = InitDecorator(target_classes=set(), fto_name="f", tmp_dir_path="/tmp", tests_root=Path())
    # Apply the transformer which calls generic_visit internally
    result = transformer.visit_Module(module)  # 13.1μs -> 13.2μs (1.22% slower)
    # Since no target classes and no attrs patches, the body should remain unchanged
    assert len(result.body) == 1
    assert isinstance(result.body[0], ast.Assign)
    # Verify flags remain unset
    assert transformer.inserted_decorator is False
    assert transformer.has_import is False


def test_visit_module_inserts_import_when_decorator_inserted_and_no_existing_import():
    # Create a real module with a class that has __init__ and place in target_classes
    code = """
class MyClass:
    def __init__(self):
        pass
"""
    module = ast.parse(code)
    transformer = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_func", tmp_dir_path="/tmp", tests_root=Path()
    )
    # Call visit_Module which will visit all child nodes including ClassDef
    result = transformer.visit_Module(module)  # 18.6μs -> 8.16μs (128% faster)
    # After visiting the class __init__, inserted_decorator should be True
    assert transformer.inserted_decorator is True
    # The import should now be inserted; verify it exists in the AST
    import_found = False
    for stmt in result.body:
        if isinstance(stmt, ast.ImportFrom):
            if stmt.module == "codeflash.verification.codeflash_capture":
                names = [alias.name for alias in stmt.names]
                if "codeflash_capture" in names:
                    import_found = True
                    break
    assert import_found, "codeflash_capture import not found in module"


def test_visit_module_does_not_insert_import_if_has_import_flag_set():
    # Create a module that already has the codeflash_capture import
    code = """
from codeflash.verification.codeflash_capture import codeflash_capture

class MyClass:
    def __init__(self):
        pass
"""
    module = ast.parse(code)
    transformer = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_func", tmp_dir_path="/tmp", tests_root=Path()
    )
    # Call visit_Module which will visit the ImportFrom and set has_import to True
    result = transformer.visit_Module(module)  # 10.6μs -> 11.2μs (5.80% slower)
    # has_import should be detected
    assert transformer.has_import is True
    # inserted_decorator should be True from visiting the class
    assert transformer.inserted_decorator is True
    # Count import statements - there should be only the original one
    imports = [stmt for stmt in result.body if isinstance(stmt, ast.ImportFrom)]
    assert len(imports) == 1


def test_attrs_patch_block_inserted_after_matching_classdef():
    # Create a module with a real attrs class (simulated with @attrs.define decorator)
    code = """
import attrs

@attrs.define
class MyClass:
    value: int = 0
"""
    module = ast.parse(code)
    transformer = InitDecorator(
        target_classes={"MyClass"}, fto_name="test_func", tmp_dir_path="/tmp", tests_root=Path()
    )
    # Call visit_Module which will visit the attrs-decorated class
    result = transformer.visit_Module(module)  # 34.4μs -> 25.1μs (37.1% faster)
    # inserted_decorator should be True because attrs class was identified
    assert transformer.inserted_decorator is True
    # Verify that the patch block exists with correct structure
    # Look for the save_orig assignment
    save_orig_found = False
    patched_func_found = False
    assign_patched_found = False

    for stmt in result.body:
        if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
            if isinstance(stmt.targets[0], ast.Name):
                if stmt.targets[0].id == "_codeflash_orig_MyClass_init":
                    save_orig_found = True

        if isinstance(stmt, ast.FunctionDef):
            if stmt.name == "_codeflash_patched_MyClass_init":
                patched_func_found = True

        if isinstance(stmt, ast.Assign) and len(stmt.targets) == 1:
            if isinstance(stmt.targets[0], ast.Attribute):
                if stmt.targets[0].attr == "__init__":
                    if isinstance(stmt.targets[0].value, ast.Name):
                        if stmt.targets[0].value.id == "MyClass":
                            assign_patched_found = True

    assert save_orig_found, "save_orig assignment not found"
    assert patched_func_found, "patched_func definition not found"
    assert assign_patched_found, "assign_patched assignment not found"


def test_no_patch_when_class_missing_from_module():
    # Create a module with a class that is NOT the target
    code = """
class OtherClass:
    def __init__(self):
        pass

class AnotherClass:
    pass
"""
    module = ast.parse(code)
    transformer = InitDecorator(
        target_classes={"NonexistentClass"}, fto_name="test_func", tmp_dir_path="/tmp", tests_root=Path()
    )
    # Call visit_Module
    result = transformer.visit_Module(module)  # 5.40μs -> 5.58μs (3.24% slower)
    # Since NonexistentClass is not in the module, no patches should be applied
    # Count the statements - should be same as original
    class_count = sum(1 for stmt in result.body if isinstance(stmt, ast.ClassDef))
    assert class_count == 2
    # No patch blocks should be added
    patched_func_count = sum(
        1 for stmt in result.body if isinstance(stmt, ast.FunctionDef) and "_codeflash_patched_" in stmt.name
    )
    assert patched_func_count == 0
    # No import should be added since no decorator was inserted
    assert transformer.inserted_decorator is False


def test_multiple_classdefs_partial_patching_preserves_order():
    # Create a module with multiple classes, some with attrs decorators
    code = """
import attrs

class A:
    def __init__(self):
        pass

@attrs.define
class B:
    x: int = 0

class C:
    pass
"""
    module = ast.parse(code)
    transformer = InitDecorator(
        target_classes={"A", "B", "C"}, fto_name="test_func", tmp_dir_path="/tmp", tests_root=Path()
    )
    # Call visit_Module which will process all classes
    result = transformer.visit_Module(module)  # 41.5μs -> 31.7μs (30.8% faster)
    # Extract class definitions from result to verify order
    classes = [stmt for stmt in result.body if isinstance(stmt, ast.ClassDef)]
    class_names = [cls.name for cls in classes]
    # Classes A, B, C should all be present
    assert "A" in class_names
    assert "B" in class_names
    assert "C" in class_names
    # Verify B comes after A and before C
    idx_a = class_names.index("A")
    idx_b = class_names.index("B")
    idx_c = class_names.index("C")
    assert idx_a < idx_b < idx_c


def test_large_scale_many_classes_with_half_patched():
    # Create a large module with diverse class patterns
    class_defs = []
    N = 50
    for i in range(N):
        if i % 5 == 0:
            # attrs.define with slots
            class_defs.append(
                f'@attrs.define(slots=True)\nclass AttrSlots{i}:\n    field_a: int = {i}\n    field_b: str = "test"'
            )
        elif i % 5 == 1:
            # attrs.define without explicit slots
            class_defs.append(f"@attrs.define\nclass AttrRegular{i}:\n    value: int = {i}")
        elif i % 5 == 2:
            # Plain class with simple __init__
            class_defs.append(f"class PlainSimple{i}:\n    def __init__(self):\n        self.x = {i}")
        elif i % 5 == 3:
            # Plain class with __init__ taking parameters
            class_defs.append(
                f"class PlainParam{i}:\n    def __init__(self, a, b):\n        self.a = a\n        self.b = b"
            )
        else:
            # Plain class without __init__
            class_defs.append(f"class PlainNoInit{i}:\n    pass")

    code = "import attrs\n\n" + "\n\n".join(class_defs)
    module = ast.parse(code)
    # Target attrs classes only
    target_classes = {f"AttrSlots{i}" for i in range(N) if i % 5 == 0}
    target_classes.update({f"AttrRegular{i}" for i in range(N) if i % 5 == 1})

    transformer = InitDecorator(
        target_classes=target_classes, fto_name="test_func", tmp_dir_path="/tmp", tests_root=Path()
    )
    # Call visit_Module which will process all classes
    result = transformer.visit_Module(module)  # 252μs -> 240μs (5.02% faster)
    # Verify that inserted_decorator was set
    assert transformer.inserted_decorator is True
    # Verify import was added
    imports = [stmt for stmt in result.body if isinstance(stmt, ast.ImportFrom)]
    assert len(imports) >= 1
    # Count classes in result
    classes = [stmt for stmt in result.body if isinstance(stmt, ast.ClassDef)]
    assert len(classes) == N
    # Verify patch blocks were inserted for attrs classes
    patch_assigns = [
        stmt
        for stmt in result.body
        if isinstance(stmt, ast.Assign)
        and isinstance(stmt.targets[0], ast.Attribute)
        and hasattr(stmt.targets[0], "attr")
        and stmt.targets[0].attr == "__init__"
    ]
    # Should have patch assigns for attrs classes
    expected_patches = sum(1 for i in range(N) if i % 5 == 0 or i % 5 == 1)
    assert len(patch_assigns) == expected_patches
    # Verify structure of patches: each attrs class should have a corresponding save_orig
    for patch_assign in patch_assigns:
        if isinstance(patch_assign.targets[0].value, ast.Name):
            class_name = patch_assign.targets[0].value.id
            save_orig_name = f"_codeflash_orig_{class_name}_init"
            # Find corresponding save_orig
            save_orig_found = any(
                isinstance(stmt, ast.Assign)
                and isinstance(stmt.targets[0], ast.Name)
                and stmt.targets[0].id == save_orig_name
                for stmt in result.body
            )
            assert save_orig_found, f"Missing save_orig for {class_name}"
import ast
from pathlib import Path

# imports
from codeflash.languages.python.instrument_codeflash_capture import InitDecorator


# Helper function to parse Python code into an AST Module
def parse_code(code: str) -> ast.Module:
    """Parse Python code string into an AST Module node."""
    return ast.parse(code)


def test_visit_module_empty_module():
    """Test visit_Module with an empty module (no statements)."""
    # Create an empty module
    module = ast.Module(body=[], type_ignores=[])

    # Create decorator with basic parameters
    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    # Visit the module
    result = decorator.visit(module)

    # Empty module should remain empty with no changes
    assert isinstance(result, ast.Module)
    assert result.body == []


def test_visit_module_no_target_classes():
    """Test visit_Module when no target classes are specified."""
    # Parse simple code with a class definition
    code = """
class MyClass:
    def method(self):
        pass
"""
    module = parse_code(code)
    original_body_len = len(module.body)

    # Create decorator with empty target classes
    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    # Visit the module
    result = decorator.visit(module)

    # Module structure should be unchanged
    assert isinstance(result, ast.Module)
    assert len(result.body) == original_body_len
    # Import should not be added since no decorator was inserted
    assert not any(isinstance(stmt, ast.ImportFrom) for stmt in result.body)


def test_visit_module_with_existing_import():
    """Test visit_Module when the import already exists in the module."""
    # Parse code with existing import
    code = """
from codeflash.verification.codeflash_capture import codeflash_capture

def my_func():
    pass
"""
    module = parse_code(code)

    # Create decorator
    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    # Mark that import exists
    decorator.has_import = True

    # Visit the module
    result = decorator.visit(module)

    # Should maintain the existing import at position 0
    assert isinstance(result.body[0], ast.ImportFrom)
    assert result.body[0].module == "codeflash.verification.codeflash_capture"


def test_visit_module_inserts_import_when_decorator_inserted():
    """Test that import is added when a decorator was inserted."""
    # Parse simple code
    code = """
def test_func():
    pass
"""
    module = parse_code(code)

    # Create decorator
    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    # Manually mark that a decorator was inserted
    decorator.inserted_decorator = True
    decorator.has_import = False

    # Visit the module
    result = decorator.visit(module)

    # Import should be inserted at the beginning
    assert isinstance(result.body[0], ast.ImportFrom)
    assert result.body[0].module == "codeflash.verification.codeflash_capture"
    # Original function should still be present
    assert isinstance(result.body[1], ast.FunctionDef)


def test_visit_module_no_import_when_nothing_inserted():
    """Test that import is not added when no decorator was inserted."""
    code = """
def test_func():
    pass
"""
    module = parse_code(code)
    original_body_len = len(module.body)

    # Create decorator
    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    # Ensure decorator was not inserted
    decorator.inserted_decorator = False
    decorator.has_import = False

    # Visit the module
    result = decorator.visit(module)

    # No import should be added
    assert len(result.body) == original_body_len
    # Should not have any ImportFrom statement
    assert not any(isinstance(stmt, ast.ImportFrom) for stmt in result.body)


def test_visit_module_preserves_existing_statements():
    """Test that visit_Module preserves existing module statements."""
    # Parse code with multiple statements
    code = """
x = 10
y = 20

def func():
    return x + y

class MyClass:
    pass
"""
    module = parse_code(code)
    original_body_len = len(module.body)

    # Create decorator
    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    # Visit the module
    result = decorator.visit(module)

    # All original statements should be preserved
    assert len(result.body) == original_body_len
    assert isinstance(result.body[0], ast.Assign)  # x = 10
    assert isinstance(result.body[1], ast.Assign)  # y = 20
    assert isinstance(result.body[2], ast.FunctionDef)  # def func()
    assert isinstance(result.body[3], ast.ClassDef)  # class MyClass


def test_visit_module_returns_module_type():
    """Test that visit_Module always returns an ast.Module object."""
    module = ast.Module(body=[], type_ignores=[])

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    result = decorator.visit(module)

    # Result must be an ast.Module instance
    assert isinstance(result, ast.Module)


def test_visit_module_attrs_class_patching():
    """Test that attrs classes are properly patched with wrapper blocks when detected."""
    code = """
from attrs import define

@define
class AttrsClass:
    x: int
"""
    module = parse_code(code)

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    result = decorator.visit(module)

    # Should have import + original import + patched class definition + patch block
    # The attrs class should trigger the patching mechanism
    assert isinstance(result, ast.Module)
    # Should have inserted the import when decorator was inserted
    imports = [stmt for stmt in result.body if isinstance(stmt, ast.ImportFrom)]
    # Should have at least the original attrs import
    assert len(imports) >= 1


def test_visit_module_multiple_attrs_classes():
    """Test patching multiple attrs classes when they are detected through visit_ClassDef."""
    code = """
from attrs import define

@define
class ClassA:
    x: int

@define
class ClassB:
    y: str
"""
    module = parse_code(code)

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    result = decorator.visit(module)

    # Both classes should be present with their patch blocks if attrs was detected
    assert isinstance(result, ast.Module)
    class_defs = [stmt for stmt in result.body if isinstance(stmt, ast.ClassDef)]
    assert len(class_defs) >= 2
    assert any(cls.name == "ClassA" for cls in class_defs)
    assert any(cls.name == "ClassB" for cls in class_defs)


def test_visit_module_with_special_characters_in_paths():
    """Test visit_Module with special characters in tmp_dir_path."""
    module = ast.Module(body=[], type_ignores=[])

    # Use paths with special characters
    special_path = "/tmp/test-dir_2024/special@path"

    decorator = InitDecorator(
        target_classes=set(),
        fto_name="test_fto",
        tmp_dir_path=special_path,
        tests_root=Path("/tests/root_2024"),
        is_fto=False,
    )

    decorator.inserted_decorator = True

    result = decorator.visit(module)

    # Should still be valid AST with the special path preserved
    assert isinstance(result, ast.Module)
    # Import should be added
    assert len(result.body) >= 1


def test_visit_module_with_is_fto_true():
    """Test visit_Module with is_fto parameter set to True."""
    module = ast.Module(body=[], type_ignores=[])

    decorator = InitDecorator(
        target_classes=set(),
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=True,  # Set to True
    )

    decorator.inserted_decorator = True

    result = decorator.visit(module)

    # Should handle is_fto=True correctly
    assert isinstance(result, ast.Module)


def test_visit_module_with_is_fto_false():
    """Test visit_Module with is_fto parameter set to False."""
    module = ast.Module(body=[], type_ignores=[])

    decorator = InitDecorator(
        target_classes=set(),
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False,  # Explicitly False
    )

    decorator.inserted_decorator = True

    result = decorator.visit(module)

    # Should handle is_fto=False correctly
    assert isinstance(result, ast.Module)


def test_visit_module_paths_with_relative_components():
    """Test visit_Module with paths containing relative components."""
    module = ast.Module(body=[], type_ignores=[])

    # Use paths with .. and . components
    decorator = InitDecorator(
        target_classes=set(),
        fto_name="test_fto",
        tmp_dir_path="../tmp/./test",
        tests_root=Path("../tests/./root"),
        is_fto=False,
    )

    decorator.inserted_decorator = True

    result = decorator.visit(module)

    # Should preserve the path as-is
    assert isinstance(result, ast.Module)


def test_visit_module_empty_target_classes_set():
    """Test visit_Module with an explicitly empty target_classes set."""
    code = """
class UnrelatedClass:
    pass
"""
    module = parse_code(code)
    original_len = len(module.body)

    decorator = InitDecorator(
        target_classes=set(),  # Explicitly empty
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False,
    )

    result = decorator.visit(module)

    # Original body length should be preserved (plus maybe an import)
    assert len(result.body) >= original_len


def test_visit_module_preserves_type_ignores():
    """Test that visit_Module preserves type_ignores from the original Module."""
    # Create module with type ignores
    module = ast.Module(body=[], type_ignores=[ast.TypeIgnore(lineno=1, tag="type: ignore")])

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    result = decorator.visit(module)

    # Type ignores should be preserved
    assert result.type_ignores == module.type_ignores


def test_visit_module_import_at_correct_position():
    """Test that import is inserted at position 0 when decorator is inserted."""
    code = """
x = 1
y = 2
"""
    module = parse_code(code)

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    decorator.inserted_decorator = True
    decorator.has_import = False

    result = decorator.visit(module)

    # Import must be at position 0
    assert isinstance(result.body[0], ast.ImportFrom)
    assert result.body[0].module == "codeflash.verification.codeflash_capture"
    # Original statements should follow
    assert isinstance(result.body[1], ast.Assign)
    assert isinstance(result.body[2], ast.Assign)


def test_visit_module_attrs_patch_after_class_definition():
    """Test that attrs patch blocks are inserted after class definition when detected."""
    code = """
from attrs import define

@define
class AttrsClass:
    x: int

class OtherClass:
    pass
"""
    module = parse_code(code)

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    result = decorator.visit(module)

    # Find both classes in the result
    attrs_class_idx = None
    other_class_idx = None
    for i, stmt in enumerate(result.body):
        if isinstance(stmt, ast.ClassDef):
            if stmt.name == "AttrsClass":
                attrs_class_idx = i
            elif stmt.name == "OtherClass":
                other_class_idx = i

    # Both classes should be in the module
    assert attrs_class_idx is not None or other_class_idx is not None


def test_visit_module_many_classes():
    """Test visit_Module with many class definitions (scalability test)."""
    # Create code with 100 class definitions
    code_lines = []
    for i in range(100):
        code_lines.append(f"class Class{i}:")
        code_lines.append("    pass")

    code = "\n".join(code_lines)
    module = parse_code(code)
    original_len = len(module.body)

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    result = decorator.visit(module)

    # All classes should still be present
    class_count = sum(1 for stmt in result.body if isinstance(stmt, ast.ClassDef))
    assert class_count == 100


def test_visit_module_many_functions():
    """Test visit_Module with many function definitions."""
    # Create code with 100 function definitions
    code_lines = []
    for i in range(100):
        code_lines.append(f"def func{i}():")
        code_lines.append("    return {i}")

    code = "\n".join(code_lines)
    module = parse_code(code)
    original_len = len(module.body)

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    result = decorator.visit(module)

    # All functions should still be present
    func_count = sum(1 for stmt in result.body if isinstance(stmt, ast.FunctionDef))
    assert func_count == 100


def test_visit_module_many_assignments():
    """Test visit_Module with many variable assignments."""
    # Create code with 100 assignments
    code_lines = [f"var{i} = {i}" for i in range(100)]
    code = "\n".join(code_lines)
    module = parse_code(code)
    original_len = len(module.body)

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    result = decorator.visit(module)

    # All assignments should still be present
    assert len(result.body) == original_len


def test_visit_module_many_attrs_classes_to_patch():
    """Test visit_Module with many attrs classes requiring patching when detected."""
    code_lines = ["from attrs import define"]
    for i in range(50):
        code_lines.append("@define")
        code_lines.append(f"class AttrsClass{i}:")
        code_lines.append("    x: int")

    code = "\n".join(code_lines)
    module = parse_code(code)

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    result = decorator.visit(module)

    # Should have original import + classes (+ patch blocks if detected)
    # At minimum, all 50 class definitions should be present
    class_count = sum(1 for stmt in result.body if isinstance(stmt, ast.ClassDef))
    assert class_count == 50


def test_visit_module_complex_nested_code():
    """Test visit_Module with complex nested class and function definitions."""
    code = """
class OuterClass:
    def method1(self):
        def nested_func():
            pass
        return nested_func
    
    class InnerClass:
        pass
    
    def method2(self):
        return 42

def outer_func():
    class LocalClass:
        pass
    return LocalClass
"""
    module = parse_code(code)
    original_len = len(module.body)

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    result = decorator.visit(module)

    # Top-level structure should be preserved
    assert len(result.body) == original_len
    assert isinstance(result.body[0], ast.ClassDef)
    assert isinstance(result.body[1], ast.FunctionDef)


def test_visit_module_large_tmp_dir_path():
    """Test visit_Module with a very long tmp_dir_path."""
    module = ast.Module(body=[], type_ignores=[])

    # Create a very long path
    long_path = "/tmp/" + "/".join([f"dir{i}" for i in range(100)])

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path=long_path, tests_root=Path("/tests"), is_fto=False
    )

    decorator.inserted_decorator = True

    result = decorator.visit(module)

    # Should handle long path correctly
    assert isinstance(result, ast.Module)


def test_visit_module_performance_with_large_module():
    """Test visit_Module performance with a large module (1000 statements)."""
    # Create code with 200 classes and 400 functions
    code_lines = []
    for i in range(200):
        code_lines.append(f"class Class{i}:")
        code_lines.append("    def method(self): pass")

    for i in range(400):
        code_lines.append(f"def func{i}(): return {i}")

    code = "\n".join(code_lines)
    module = parse_code(code)

    decorator = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    # This should complete in reasonable time
    result = decorator.visit(module)

    # Verify result is valid
    assert isinstance(result, ast.Module)
    # Count of classes and functions should match
    class_count = sum(1 for stmt in result.body if isinstance(stmt, ast.ClassDef))
    func_count = sum(1 for stmt in result.body if isinstance(stmt, ast.FunctionDef))
    assert class_count == 200
    assert func_count == 400


def test_visit_module_idempotent_when_no_changes_needed():
    """Test that visit_Module produces consistent results with fresh decorator instances."""
    code = """
def func():
    pass
"""
    module1 = parse_code(code)
    module2 = parse_code(code)

    decorator1 = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    decorator2 = InitDecorator(
        target_classes=set(), fto_name="test_fto", tmp_dir_path="/tmp", tests_root=Path("/tests"), is_fto=False
    )

    result1 = decorator1.visit(module1)
    result2 = decorator2.visit(module2)

    # Both results should have the same structure
    assert len(result1.body) == len(result2.body)
    assert all(type(s1) == type(s2) for s1, s2 in zip(result1.body, result2.body))

To edit these changes git checkout codeflash/optimize-pr1860-2026-03-18T10.30.36 and push.

Codeflash Static Badge

The optimization pre-parses the `codeflash_capture` import statement once in `__init__` and stores it in `self._import_stmt`, eliminating the repeated `ast.parse` call inside `visit_Module`. Line profiler confirms the original code spent ~186 µs (1% of runtime) parsing the import on every module visit (11 hits × 16.9 µs each), which is now reduced to a one-time ~8 µs insertion cost. This reduces total `visit_Module` time by ~2.6% (17.87 ms → 17.41 ms) with no correctness trade-offs, preserving all AST structure and behavior across diverse test scenarios including large modules with 100+ classes.
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Mar 18, 2026
@claude
Copy link
Contributor

claude bot commented Mar 18, 2026

Claude finished @codeflash-ai[bot]'s task in 19s —— View job


PR Review Summary

Tasks

  • Triage PR scope
  • Run lint and type checks
  • Resolve stale threads
  • Code review
  • Duplicate detection
  • Test coverage

Prek Checks

ruff check and ruff format both pass. Two pre-existing mypy errors exist at line 272 (ast.FunctionDef overload) — unrelated to this change and not introduced here.

Code Review

The optimization is correct and safe.

What it does: Moves ast.parse("from codeflash.verification.codeflash_capture import codeflash_capture").body[0] from visit_Module into __init__, storing the result as self._import_stmt.

Why it's safe: InitDecorator is instantiated once per call to add_codeflash_capture_to_init() (line 62), and transformer.visit(tree) is called exactly once per instance. There's no shared mutable state across multiple trees — each instance owns its _import_stmt node exclusively.

Minor note: The comment on line 135 says "to avoid repeated parsing in visit_Module" — but since each instance only calls visit_Module once, the optimization is more accurately described as moving the parse from the hot path of visit_Module to construction time. Not worth blocking a merge over, just FYI.

No bugs, no security issues, no breaking API changes.

Duplicate Detection

No duplicates detected. The _import_stmt pre-computation pattern is local to this class.

Test Coverage

PR reports 100% coverage from 40 generated regression tests. CI checks (prek, type-check-cli) pass; integration/end-to-end checks still pending at review time.


Last updated: 2026-03-18T10:32Z

@claude claude bot merged commit 46016bd into fix/attrs-init-instrumentation Mar 18, 2026
26 of 27 checks passed
@claude claude bot deleted the codeflash/optimize-pr1860-2026-03-18T10.30.36 branch March 18, 2026 10:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants