Skip to content

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Sep 30, 2025

⚡️ This pull request contains optimizations for PR #363

If you approve this dependent PR, these changes will be merged into the original PR branch part-1-windows-fixes.

This PR will be automatically closed if the original PR is merged.


📄 65% (0.65x) speedup for InitDecorator.visit_ClassDef in codeflash/verification/instrument_codeflash_capture.py

⏱️ Runtime : 1.19 milliseconds 718 microseconds (best of 116 runs)

📝 Explanation and details

The optimized code achieves a 65% speedup through strategic precomputation of AST nodes that are repeatedly created during class processing.

Key optimizations:

  1. Precomputed AST components in __init__: Instead of reconstructing identical AST nodes (like ast.Name, ast.arg, ast.Constant) on every visit_ClassDef call, the optimized version creates them once during initialization and reuses them. This eliminates the expensive AST node construction overhead seen in the profiler - lines creating decorator keywords and super() call components dropped from ~2ms total to ~0.6ms.

  2. Optimized decorator presence check: Replaced the any() generator expression with a for/else loop that stops immediately when finding an existing codeflash_capture decorator. This avoids generator allocation overhead and short-circuits the search earlier.

  3. Reduced per-class AST construction: The decorator is now built once per class using precomputed components, rather than reconstructing all keywords and function references from scratch each time.

Performance impact by test type:

  • Basic cases (single class with simple __init__): ~140-220% faster, benefiting from reduced AST node construction
  • Edge cases (classes needing synthetic __init__): ~100-150% faster, particularly benefiting from prebuilt super() call components
  • Large scale (many methods/classes): ~17-40% faster, where the constant-time optimizations compound across many iterations

The optimization is most effective for workloads processing many classes, as the upfront precomputation cost is amortized across multiple visit_ClassDef calls, directly addressing the bottleneck of repetitive AST node creation identified in the profiler.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 1420 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 2 Passed
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
from __future__ import annotations

import ast
import sys
from pathlib import Path

# imports
import pytest  # used for our unit tests
from codeflash.verification.instrument_codeflash_capture import InitDecorator

# Helper functions for the tests

def get_classdef_from_code(code: str) -> ast.ClassDef:
    """Parse code and return the first ClassDef node."""
    tree = ast.parse(code)
    for node in tree.body:
        if isinstance(node, ast.ClassDef):
            return node
    raise ValueError("No class found in code.")

def get_init_func(classdef: ast.ClassDef) -> ast.FunctionDef | None:
    """Return the __init__ function from a ClassDef node, or None if not found."""
    for node in classdef.body:
        if isinstance(node, ast.FunctionDef) and node.name == "__init__":
            return node
    return None

def has_codeflash_capture_decorator(funcdef: ast.FunctionDef) -> bool:
    """Check if a FunctionDef has a codeflash_capture decorator."""
    for deco in funcdef.decorator_list:
        if (
            isinstance(deco, ast.Call)
            and isinstance(deco.func, ast.Name)
            and deco.func.id == "codeflash_capture"
        ):
            return True
    return False

def get_codeflash_capture_decorator(funcdef: ast.FunctionDef):
    """Return the codeflash_capture decorator call node, or None."""
    for deco in funcdef.decorator_list:
        if (
            isinstance(deco, ast.Call)
            and isinstance(deco.func, ast.Name)
            and deco.func.id == "codeflash_capture"
        ):
            return deco
    return None

def get_decorator_keywords(deco: ast.Call) -> dict:
    """Return a dict of keyword arguments from a decorator call."""
    return {kw.arg: kw.value.value if isinstance(kw.value, ast.Constant) else kw.value.s for kw in deco.keywords}

# --- UNIT TESTS ---

# Shared test parameters
TARGET_CLASSES = {"MyClass", "EdgeClass", "BigClass"}
FTO_NAME = "fto_test"
TMP_DIR_PATH = "/tmp/codeflash"
TESTS_ROOT = Path("/project/tests")
IS_FTO = True

# --------- 1. BASIC TEST CASES ---------

def test_adds_decorator_to_existing_init():
    """Decorator is added to __init__ if not present, and not to other methods."""
    code = """
class MyClass:
    def __init__(self, x):
        self.x = x
    def foo(self): pass
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 9.08μs -> 3.68μs (147% faster)

    # __init__ should have codeflash_capture decorator
    init_func = get_init_func(new_classdef)
    deco = get_codeflash_capture_decorator(init_func)
    # Check decorator keyword arguments
    kws = get_decorator_keywords(deco)

def test_does_not_modify_non_target_class():
    """Classes not in target_classes are not modified."""
    code = """
class NotTarget:
    def __init__(self, y): self.y = y
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    orig_dump = ast.dump(classdef)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 601ns -> 612ns (1.80% slower)

def test_adds_decorator_to_init_with_existing_decorators():
    """Decorator is added before existing decorators."""
    code = """
class MyClass:
    @staticmethod
    def __init__(self): pass
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 8.83μs -> 3.63μs (143% faster)
    init_func = get_init_func(new_classdef)

def test_does_not_duplicate_decorator():
    """If codeflash_capture decorator is already present, do not add another."""
    code = """
class MyClass:
    @codeflash_capture(function_name="MyClass.__init__", tmp_dir_path="/tmp", tests_root="/root", is_fto=True)
    def __init__(self, a): pass
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 9.45μs -> 3.76μs (151% faster)
    init_func = get_init_func(new_classdef)
    # Only one codeflash_capture decorator present
    count = sum(
        isinstance(d, ast.Call) and getattr(d.func, "id", None) == "codeflash_capture"
        for d in init_func.decorator_list
    )

def test_adds_init_if_missing():
    """If __init__ is missing, it is created with the correct decorator and super call."""
    code = """
class MyClass:
    def foo(self): pass
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 14.2μs -> 6.76μs (110% faster)
    # __init__ should now exist
    init_func = get_init_func(new_classdef)
    # Should have correct arguments: self, *args, **kwargs
    argnames = [a.arg for a in init_func.args.args]
    # Should call super().__init__(*args, **kwargs)
    body = init_func.body

# --------- 2. EDGE TEST CASES ---------

def test_class_with_no_body():
    """Class with no body gets __init__ inserted."""
    code = "class EdgeClass:\n    pass"
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 13.9μs -> 7.07μs (96.6% faster)

def test_class_with_multiple_inits():
    """Class with multiple __init__ methods (should only decorate the first valid one)."""
    code = """
class MyClass:
    def __init__(self, x): pass
    def __init__(self, y): pass
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 8.88μs -> 4.12μs (116% faster)
    # Only the first __init__ should be decorated
    inits = [n for n in new_classdef.body if isinstance(n, ast.FunctionDef) and n.name == "__init__"]
    if len(inits) > 1:
        pass

def test_init_with_no_self():
    """__init__ without 'self' as first argument should not be decorated."""
    code = """
class MyClass:
    def __init__(x): pass
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 14.2μs -> 7.10μs (100% faster)
    init_func = get_init_func(new_classdef)

def test_class_with_only_classmethods():
    """If only classmethods are present, __init__ should be created."""
    code = """
class EdgeClass:
    @classmethod
    def foo(cls): pass
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 13.8μs -> 6.66μs (107% faster)
    # __init__ should exist and be decorated
    init_func = get_init_func(new_classdef)

def test_class_with_inherited_init():
    """If __init__ is inherited (not defined), __init__ should be created."""
    code = """
class EdgeClass(BaseClass):
    def foo(self): pass
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 13.6μs -> 6.30μs (117% faster)
    # __init__ should be created and decorated
    init_func = get_init_func(new_classdef)

def test_class_with_decorated_init_and_other_decorators():
    """If __init__ has multiple decorators, codeflash_capture is inserted first."""
    code = """
class MyClass:
    @other_decorator
    def __init__(self): pass
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 8.27μs -> 3.50μs (136% faster)
    init_func = get_init_func(new_classdef)

def test_class_with_init_and_various_arg_styles():
    """__init__ with self, *args, **kwargs is decorated."""
    code = """
class MyClass:
    def __init__(self, *args, **kwargs): pass
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 7.92μs -> 3.32μs (139% faster)
    init_func = get_init_func(new_classdef)

def test_class_with_init_and_posonlyargs():
    """__init__ with positional-only args is decorated."""
    # Python 3.8+ only
    if sys.version_info < (3, 8):
        pytest.skip("posonlyargs only in Python 3.8+")
    code = """
class MyClass:
    def __init__(self, a, /, b): pass
"""
    tree = ast.parse(code)
    classdef = next(n for n in tree.body if isinstance(n, ast.ClassDef))
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 14.3μs -> 7.18μs (99.3% faster)
    init_func = get_init_func(new_classdef)

def test_class_with_init_and_kwonlyargs():
    """__init__ with keyword-only args is decorated."""
    code = """
class MyClass:
    def __init__(self, *, a, b): pass
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 7.86μs -> 3.31μs (138% faster)
    init_func = get_init_func(new_classdef)

def test_class_with_init_and_annotations():
    """__init__ with type annotations is decorated."""
    code = """
class MyClass:
    def __init__(self: "MyClass", x: int) -> None: pass
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 7.87μs -> 3.37μs (134% faster)
    init_func = get_init_func(new_classdef)

# --------- 3. LARGE SCALE TEST CASES ---------

def test_many_classes_only_target_modified():
    """Only target classes are modified among many classes."""
    # Build code with 100 classes, only 3 are targets
    code = "\n".join(
        f"class Class{i}:\n    def __init__(self): pass"
        for i in range(100)
    )
    # Add targets
    code += """
class MyClass:
    def __init__(self): pass
class EdgeClass:
    pass
class BigClass:
    pass
"""
    tree = ast.parse(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    # Visit all classes
    new_bodies = []
    for node in tree.body:
        if isinstance(node, ast.ClassDef):
            new_bodies.append(decorator.visit_ClassDef(node))
        else:
            new_bodies.append(node)
    # Check non-targets are not modified
    for i in range(100):
        classdef = new_bodies[i]
        init_func = get_init_func(classdef)
    # Targets are modified
    for name in ["MyClass", "EdgeClass", "BigClass"]:
        classdef = next(c for c in new_bodies if c.name == name)
        init_func = get_init_func(classdef)

def test_large_class_with_many_methods():
    """A class with many methods and no __init__ gets __init__ inserted and decorated."""
    methods = "\n".join(f"    def method{i}(self): pass" for i in range(500))
    code = f"""
class BigClass:
{methods}
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 58.9μs -> 50.3μs (17.1% faster)
    # __init__ should be present and decorated
    init_func = get_init_func(new_classdef)
    # Should still have all other methods
    method_names = {n.name for n in new_classdef.body if isinstance(n, ast.FunctionDef)}

def test_large_class_with_existing_init_and_many_methods():
    """A class with many methods and an existing __init__ gets decorator added only to __init__."""
    methods = "\n".join(f"    def method{i}(self): pass" for i in range(500))
    code = f"""
class BigClass:
    def __init__(self): pass
{methods}
"""
    classdef = get_classdef_from_code(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    codeflash_output = decorator.visit_ClassDef(classdef); new_classdef = codeflash_output # 50.7μs -> 42.9μs (18.3% faster)
    init_func = get_init_func(new_classdef)
    # All other methods should not be decorated
    for n in new_classdef.body:
        if isinstance(n, ast.FunctionDef) and n.name != "__init__":
            pass

def test_performance_many_classes_and_methods():
    """Performance: process 100 classes each with 10 methods."""
    code = "\n".join(
        f"class Class{i}:\n" +
        "\n".join(f"    def method{j}(self): pass" for j in range(10))
        for i in range(100)
    )
    # Add a target class at the end
    code += """
class MyClass:
    def foo(self): pass
"""
    tree = ast.parse(code)
    decorator = InitDecorator(TARGET_CLASSES, FTO_NAME, TMP_DIR_PATH, TESTS_ROOT, IS_FTO)
    # Visit all classes
    new_bodies = []
    for node in tree.body:
        if isinstance(node, ast.ClassDef):
            new_bodies.append(decorator.visit_ClassDef(node))
        else:
            new_bodies.append(node)
    # Only MyClass should have __init__ inserted and decorated
    myclass = next(c for c in new_bodies if c.name == "MyClass")
    init_func = get_init_func(myclass)
    # All other classes should have no codeflash_capture decorator
    for c in new_bodies:
        if isinstance(c, ast.ClassDef) and c.name != "MyClass":
            for n in c.body:
                if isinstance(n, ast.FunctionDef):
                    pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from __future__ import annotations

import ast
import sys
from pathlib import Path

# imports
import pytest  # used for our unit tests
from codeflash.verification.instrument_codeflash_capture import InitDecorator

# Helper functions for tests

def get_classdef_from_code(code: str, class_name: str = None) -> ast.ClassDef:
    """Parse code and return the ast.ClassDef node for the given class name (or first class if None)."""
    tree = ast.parse(code)
    for node in tree.body:
        if isinstance(node, ast.ClassDef):
            if class_name is None or node.name == class_name:
                return node
    raise ValueError("No class found in code")

def get_init_funcdef(classdef: ast.ClassDef) -> ast.FunctionDef | None:
    """Return the __init__ method ast.FunctionDef from a classdef, or None."""
    for node in classdef.body:
        if isinstance(node, ast.FunctionDef) and node.name == "__init__":
            return node
    return None

def has_codeflash_capture_decorator(funcdef: ast.FunctionDef) -> bool:
    """Check if the function has a codeflash_capture decorator."""
    for d in funcdef.decorator_list:
        if (
            isinstance(d, ast.Call)
            and isinstance(d.func, ast.Name)
            and d.func.id == "codeflash_capture"
        ):
            return True
    return False

def get_codeflash_capture_decorator(funcdef: ast.FunctionDef):
    """Return the codeflash_capture decorator ast.Call node, or None."""
    for d in funcdef.decorator_list:
        if (
            isinstance(d, ast.Call)
            and isinstance(d.func, ast.Name)
            and d.func.id == "codeflash_capture"
        ):
            return d
    return None

def get_decorator_keyword_arg(decorator: ast.Call, argname: str):
    """Get the value of a keyword argument from a decorator ast.Call."""
    for kw in decorator.keywords:
        if kw.arg == argname:
            return kw.value
    return None

# ----------------------
# Basic Test Cases
# ----------------------

def test_adds_decorator_to_existing_init():
    """Test that the decorator is added to a class with an existing __init__."""
    code = """
class MyClass:
    def __init__(self, x):
        self.x = x
    def foo(self): pass
"""
    classdef = get_classdef_from_code(code, "MyClass")
    dec = InitDecorator(
        target_classes={"MyClass"},
        fto_name="fto",
        tmp_dir_path="/tmp/dir",
        tests_root=Path("/tests/root"),
        is_fto=True,
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 12.2μs -> 3.92μs (213% faster)
    init_func = get_init_funcdef(new_classdef)
    # Check decorator arguments
    decorator = get_codeflash_capture_decorator(init_func)

def test_does_not_modify_other_classes():
    """Test that classes not in target_classes are not modified."""
    code = """
class NotTarget:
    def __init__(self): pass
"""
    classdef = get_classdef_from_code(code, "NotTarget")
    dec = InitDecorator(
        target_classes={"SomeOtherClass"},
        fto_name="fto",
        tmp_dir_path="/tmp/dir",
        tests_root=Path("/tests/root"),
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 641ns -> 671ns (4.47% slower)
    init_func = get_init_funcdef(new_classdef)

def test_adds_init_if_missing():
    """Test that __init__ is created if missing, with correct decorator and super call."""
    code = """
class NewClass:
    def foo(self): pass
"""
    classdef = get_classdef_from_code(code, "NewClass")
    dec = InitDecorator(
        target_classes={"NewClass"},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
        is_fto=False,
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 17.3μs -> 7.03μs (146% faster)
    init_func = get_init_funcdef(new_classdef)
    # The function should have arguments: self, *args, **kwargs
    argnames = [a.arg for a in init_func.args.args]
    # The body should have a super().__init__(*args, **kwargs) call
    call = init_func.body[0]
    # Check decorator args
    decorator = get_codeflash_capture_decorator(init_func)

def test_does_not_duplicate_decorator():
    """Test that the decorator is not duplicated if already present."""
    code = """
class MyClass:
    @codeflash_capture(function_name="MyClass.__init__", tmp_dir_path="/tmp/dir", tests_root="/tests/root", is_fto=True)
    def __init__(self, x):
        self.x = x
"""
    classdef = get_classdef_from_code(code, "MyClass")
    dec = InitDecorator(
        target_classes={"MyClass"},
        fto_name="fto",
        tmp_dir_path="/tmp/dir",
        tests_root=Path("/tests/root"),
        is_fto=True,
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 13.0μs -> 4.03μs (223% faster)
    init_func = get_init_funcdef(new_classdef)
    # Should still have only one codeflash_capture decorator
    codeflash_count = sum(
        1
        for d in init_func.decorator_list
        if isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == "codeflash_capture"
    )

def test_multiple_methods_only_init_decorated():
    """Test that only __init__ gets the decorator, not other methods."""
    code = """
class MyClass:
    def __init__(self, x): self.x = x
    def foo(self): pass
    def bar(self): pass
"""
    classdef = get_classdef_from_code(code, "MyClass")
    dec = InitDecorator(
        target_classes={"MyClass"},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
        is_fto=False,
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 10.9μs -> 3.83μs (184% faster)
    for node in new_classdef.body:
        if isinstance(node, ast.FunctionDef):
            if node.name == "__init__":
                pass
            else:
                pass

# ----------------------
# Edge Test Cases
# ----------------------

def test_init_with_different_self_name():
    """Test that __init__ with first arg not 'self' is ignored (should create new __init__)."""
    code = """
class WeirdInit:
    def __init__(notself, x): pass
"""
    classdef = get_classdef_from_code(code, "WeirdInit")
    dec = InitDecorator(
        target_classes={"WeirdInit"},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 16.9μs -> 7.21μs (135% faster)
    # Should create a new __init__ with correct 'self'
    # There will be two __init__s, one with notself, one with self
    inits = [n for n in new_classdef.body if isinstance(n, ast.FunctionDef) and n.name == "__init__"]
    found = False
    for init_func in inits:
        if init_func.args.args and init_func.args.args[0].arg == "self":
            found = True

def test_class_with_no_body():
    """Test class with empty body (should create __init__)."""
    code = "class Empty: pass"
    classdef = get_classdef_from_code(code, "Empty")
    dec = InitDecorator(
        target_classes={"Empty"},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 16.1μs -> 7.02μs (130% faster)
    init_func = get_init_funcdef(new_classdef)

def test_class_with_multiple_decorators_on_init():
    """Test that codeflash_capture is inserted at the start of the decorator list."""
    code = """
class DecoratedInit:
    @other_decorator
    def __init__(self): pass
"""
    classdef = get_classdef_from_code(code, "DecoratedInit")
    dec = InitDecorator(
        target_classes={"DecoratedInit"},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 10.5μs -> 3.61μs (191% faster)
    init_func = get_init_funcdef(new_classdef)
    # Should be first in the list
    first = init_func.decorator_list[0]

def test_class_with_init_no_args():
    """Test __init__ with no arguments (should not be decorated, new __init__ should be created)."""
    code = """
class NoArgInit:
    def __init__(): pass
"""
    classdef = get_classdef_from_code(code, "NoArgInit")
    dec = InitDecorator(
        target_classes={"NoArgInit"},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 16.1μs -> 6.79μs (137% faster)
    # Should have two __init__s: one original, one new with self
    inits = [n for n in new_classdef.body if isinstance(n, ast.FunctionDef) and n.name == "__init__"]
    found = False
    for init_func in inits:
        if init_func.args.args and init_func.args.args[0].arg == "self":
            found = True

def test_target_class_with_nested_class():
    """Test that only the outer class is decorated, not the nested class."""
    code = """
class Outer:
    def __init__(self): pass
    class Inner:
        def __init__(self): pass
"""
    classdef = get_classdef_from_code(code, "Outer")
    dec = InitDecorator(
        target_classes={"Outer"},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 10.3μs -> 3.66μs (182% faster)
    # Outer __init__ should be decorated
    outer_init = get_init_funcdef(new_classdef)
    # Inner __init__ should not be decorated
    for node in new_classdef.body:
        if isinstance(node, ast.ClassDef) and node.name == "Inner":
            inner_init = get_init_funcdef(node)

def test_target_class_with_no_methods():
    """Test a class with no methods at all (should add __init__)."""
    code = """
class NoMethods:
    pass
"""
    classdef = get_classdef_from_code(code, "NoMethods")
    dec = InitDecorator(
        target_classes={"NoMethods"},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 15.8μs -> 7.03μs (125% faster)
    init_func = get_init_funcdef(new_classdef)

# ----------------------
# Large Scale Test Cases
# ----------------------

def test_large_number_of_classes_only_target_modified():
    """Test with many classes, only target class is modified."""
    code = "\n".join(
        f"class Cls{i}:\n    def foo(self): pass"
        for i in range(100)
    )
    # Add a target class in the middle
    code += "\nclass Target:\n    def __init__(self): pass\n"
    tree = ast.parse(code)
    dec = InitDecorator(
        target_classes={"Target"},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
    )
    # Transform all classdefs
    new_bodies = []
    for node in tree.body:
        if isinstance(node, ast.ClassDef):
            new_bodies.append(dec.visit_ClassDef(node))
        else:
            new_bodies.append(node)
    # Only Target should have codeflash_capture on __init__
    for node in new_bodies:
        if isinstance(node, ast.ClassDef):
            if node.name == "Target":
                init_func = get_init_funcdef(node)
            else:
                # Other classes have no __init__, so no decorator
                init_func = get_init_funcdef(node)
                if init_func:
                    pass

def test_large_class_with_many_methods():
    """Test a class with many methods, only __init__ is decorated."""
    code = "class BigClass:\n"
    code += "    def __init__(self): pass\n"
    for i in range(200):
        code += f"    def method{i}(self): pass\n"
    classdef = get_classdef_from_code(code, "BigClass")
    dec = InitDecorator(
        target_classes={"BigClass"},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 29.6μs -> 21.5μs (37.8% faster)
    for node in new_classdef.body:
        if isinstance(node, ast.FunctionDef):
            if node.name == "__init__":
                pass
            else:
                pass

def test_large_scale_adds_init_to_many_classes():
    """Test adding __init__ to many classes that lack it."""
    code = "\n".join(
        f"class Cls{i}:\n    def foo(self): pass"
        for i in range(50)
    )
    tree = ast.parse(code)
    dec = InitDecorator(
        target_classes={f"Cls{i}" for i in range(50)},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
    )
    # Transform all classdefs
    new_bodies = []
    for node in tree.body:
        if isinstance(node, ast.ClassDef):
            new_bodies.append(dec.visit_ClassDef(node))
        else:
            new_bodies.append(node)
    # All target classes should now have __init__ with decorator
    for node in new_bodies:
        if isinstance(node, ast.ClassDef):
            init_func = get_init_funcdef(node)

def test_large_class_with_nested_classes():
    """Test a large class with many nested classes, only the outer is decorated."""
    code = "class Outer:\n"
    code += "    def __init__(self): pass\n"
    for i in range(20):
        code += f"    class Inner{i}:\n        def __init__(self): pass\n"
    classdef = get_classdef_from_code(code, "Outer")
    dec = InitDecorator(
        target_classes={"Outer"},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
    )
    codeflash_output = dec.visit_ClassDef(classdef); new_classdef = codeflash_output # 12.6μs -> 4.77μs (163% faster)
    # Outer __init__ should be decorated
    outer_init = get_init_funcdef(new_classdef)
    # None of the inner classes should have decorated __init__
    for node in new_classdef.body:
        if isinstance(node, ast.ClassDef):
            inner_init = get_init_funcdef(node)

def test_performance_with_maximum_classes():
    """Test performance and correctness with 1000 classes, only last is target."""
    code = "\n".join(
        f"class Cls{i}:\n    def foo(self): pass"
        for i in range(999)
    )
    code += "\nclass Target:\n    def __init__(self): pass\n"
    tree = ast.parse(code)
    dec = InitDecorator(
        target_classes={"Target"},
        fto_name="fto",
        tmp_dir_path="TMP",
        tests_root=Path("ROOT"),
    )
    # Transform all classdefs
    new_bodies = []
    for node in tree.body:
        if isinstance(node, ast.ClassDef):
            new_bodies.append(dec.visit_ClassDef(node))
        else:
            new_bodies.append(node)
    # Only Target should have codeflash_capture on __init__
    for node in new_bodies:
        if isinstance(node, ast.ClassDef):
            if node.name == "Target":
                init_func = get_init_funcdef(node)
            else:
                init_func = get_init_funcdef(node)
                if init_func:
                    pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from ast import ClassDef
from codeflash.verification.instrument_codeflash_capture import InitDecorator
from pathlib import Path
import pytest

def test_InitDecorator_visit_ClassDef():
    with pytest.raises(AttributeError, match="'ClassDef'\\ object\\ has\\ no\\ attribute\\ 'name'"):
        InitDecorator.visit_ClassDef(InitDecorator({''}, '', '', Path(), is_fto=0), ClassDef())
🔎 Concolic Coverage Tests and Runtime
Test File::Test Function Original ⏱️ Optimized ⏱️ Speedup
codeflash_concolic_nh__l8ip/tmpjzz9bzb_/test_concolic_coverage.py::test_InitDecorator_visit_ClassDef 3.13μs 2.58μs 21.3%✅

To edit these changes git checkout codeflash/optimize-pr363-2025-09-30T01.44.19 and push.

Codeflash

The optimized code achieves a **65% speedup** through strategic precomputation of AST nodes that are repeatedly created during class processing.

**Key optimizations:**

1. **Precomputed AST components in `__init__`**: Instead of reconstructing identical AST nodes (like `ast.Name`, `ast.arg`, `ast.Constant`) on every `visit_ClassDef` call, the optimized version creates them once during initialization and reuses them. This eliminates the expensive AST node construction overhead seen in the profiler - lines creating decorator keywords and super() call components dropped from ~2ms total to ~0.6ms.

2. **Optimized decorator presence check**: Replaced the `any()` generator expression with a `for/else` loop that stops immediately when finding an existing `codeflash_capture` decorator. This avoids generator allocation overhead and short-circuits the search earlier.

3. **Reduced per-class AST construction**: The decorator is now built once per class using precomputed components, rather than reconstructing all keywords and function references from scratch each time.

**Performance impact by test type:**
- **Basic cases** (single class with simple `__init__`): ~140-220% faster, benefiting from reduced AST node construction
- **Edge cases** (classes needing synthetic `__init__`): ~100-150% faster, particularly benefiting from prebuilt super() call components  
- **Large scale** (many methods/classes): ~17-40% faster, where the constant-time optimizations compound across many iterations

The optimization is most effective for workloads processing many classes, as the upfront precomputation cost is amortized across multiple `visit_ClassDef` calls, directly addressing the bottleneck of repetitive AST node creation identified in the profiler.
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Sep 30, 2025
@KRRT7 KRRT7 merged commit 91870c0 into part-1-windows-fixes Sep 30, 2025
20 of 23 checks passed
@codeflash-ai codeflash-ai bot deleted the codeflash/optimize-pr363-2025-09-30T01.44.19 branch September 30, 2025 02:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant