@codeflash-ai codeflash-ai bot commented Nov 1, 2025

⚡️ This pull request contains optimizations for PR #867

If you approve this dependent PR, these changes will be merged into the original PR branch inspect-signature-issue.

This PR will be automatically closed if the original PR is merged.


📄 22% (0.22x) speedup for InjectPerfOnly.find_and_update_line_node in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime: 2.01 milliseconds → 1.64 milliseconds (best of 113 runs)
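The headline percentage appears to be the time saved divided by the optimized runtime. The PR does not state the formula, so the definition below is an assumption; it is a quick arithmetic check, not Codeflash's own code.

```python
# Assumed definition: speedup = (original - optimized) / optimized.
original_ms = 2.01   # runtime before the change, as reported above
optimized_ms = 1.64  # runtime after the change, as reported above

speedup = (original_ms - optimized_ms) / optimized_ms
print(f"{speedup:.2%}")
```

The result is roughly 0.226, close to the reported 22% (0.22x); the small gap is plausibly due to the displayed runtimes being rounded.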

📝 Explanation and details

The optimized code achieves a 22% speedup through two main optimizations that reduce overhead in AST traversal and attribute lookups:

1. Custom AST traversal replaces expensive ast.walk()
The original code uses ast.walk(), a generator that yields every node in the tree, so the calling loop pays generator and isinstance() overhead for each node, including the many that are not calls. The optimized version implements iter_ast_calls(), a manual iterative traversal that uses a single explicit stack and yields only ast.Call nodes. The traversal still touches every node, so it remains O(N), but the per-node overhead in the hot loop drops.
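A minimal sketch of such a traversal, assuming a plain list used as the stack. The actual body of iter_ast_calls() is not shown in this PR description, so everything below other than the function name is illustrative.

```python
import ast
from collections.abc import Iterator


def iter_ast_calls(root: ast.AST) -> Iterator[ast.Call]:
    """Yield every ast.Call under root using an explicit stack (no recursion)."""
    stack = [root]
    while stack:
        node = stack.pop()
        if isinstance(node, ast.Call):
            yield node
        # Children are still traversed; non-Call nodes are simply never yielded.
        stack.extend(ast.iter_child_nodes(node))


tree = ast.parse("foo(bar(baz(1)), x=qux())")
names = sorted(call.func.id for call in iter_ast_calls(tree))
print(names)
```

One design point to watch: a stack gives depth-first order, whereas ast.walk() is breadth-first. Code that stops at the first matching call can pick a different node unless the traversal order is kept compatible.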

2. Reduced attribute lookups in hot paths

  • In node_in_call_position(): Uses getattr() with defaults to cache node attributes (node_lineno, node_end_lineno, etc.) instead of repeated hasattr() + attribute access
  • In find_and_update_line_node(): Hoists frequently-accessed object attributes (fn_obj.qualified_name, self.mode, etc.) to local variables before the loop
  • Pre-creates reusable AST nodes (codeflash_loop_index, codeflash_cur, codeflash_con) instead of recreating them in each iteration
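The attribute-caching idea can be sketched as follows. This is an illustration under assumptions, not the merged implementation: CodePosition here is a hypothetical stand-in with only the two fields the check needs, and the local names simply mirror those mentioned in the list above.

```python
import ast
from dataclasses import dataclass


@dataclass
class CodePosition:  # hypothetical stand-in for the project's position type
    line_no: int
    col_no: int


def node_in_call_position(node: ast.AST, call_positions: list[CodePosition]) -> bool:
    if not isinstance(node, ast.Call):
        return False
    # Read each attribute once; getattr with a default replaces hasattr + access.
    node_lineno = getattr(node, "lineno", None)
    node_col_offset = getattr(node, "col_offset", None)
    node_end_lineno = getattr(node, "end_lineno", None)
    node_end_col_offset = getattr(node, "end_col_offset", None)
    if node_lineno is None or node_col_offset is None or node_end_lineno is None:
        return False
    for pos in call_positions:
        line = pos.line_no
        if not (node_lineno <= line <= node_end_lineno):
            continue
        if line == node_lineno and node_col_offset <= pos.col_no:
            return True
        if (
            line == node_end_lineno
            and node_end_col_offset is not None
            and node_end_col_offset >= pos.col_no
        ):
            return True
        if node_lineno < line < node_end_lineno:
            return True
    return False


call = ast.parse("foo(1)").body[0].value
hit = node_in_call_position(call, [CodePosition(line_no=1, col_no=0)])
miss = node_in_call_position(call, [CodePosition(line_no=2, col_no=0)])
```

With a long call_positions list, the inner loop then works only on local variables, which is where the saving comes from.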

Performance characteristics:

  • Small ASTs (basic function calls): 5-28% faster in the timed tests due to reduced attribute lookups; the one test with no call node at all measured 1.22% slower
  • Large ASTs (deeply nested calls): 18-26% faster due to the cheaper traversal that avoids ast.walk()
  • Large call position lists: 26% faster due to position checking with cached attributes

The optimizations are most effective for complex test instrumentation scenarios with large AST trees or many call positions to check, which is typical in code analysis and transformation workflows.

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 123 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 97.0% |

🌀 Generated Regression Tests and Runtime
import ast
from typing import Any

# imports
import pytest
from codeflash.code_utils.instrument_existing_tests import InjectPerfOnly

# --- Minimal stubs for dependencies used in the function ---

class FunctionToOptimize:
    def __init__(self, function_name, qualified_name, is_async=False, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.qualified_name = qualified_name
        self.is_async = is_async
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name

class ParentStub:
    def __init__(self, type_, name):
        self.type = type_
        self.name = name

class CodePosition:
    def __init__(self, line_no=None, col_no=None, end_col_offset=None):
        self.line_no = line_no
        self.col_no = col_no
        self.end_col_offset = end_col_offset

class TestingMode:
    BEHAVIOR = "BEHAVIOR"
    PERFORMANCE = "PERFORMANCE"

# --- Helper functions for test assertions ---

def get_first_call_node(stmt):
    """Recursively find the first ast.Call node in a statement."""
    for node in ast.walk(stmt):
        if isinstance(node, ast.Call):
            return node
    return None

def get_func_name(call_node):
    """Get function name from ast.Call node."""
    if isinstance(call_node.func, ast.Name):
        return call_node.func.id
    elif isinstance(call_node.func, ast.Attribute):
        return call_node.func.attr
    return None

def is_codeflash_wrap_call(call_node):
    return isinstance(call_node.func, ast.Name) and call_node.func.id == "codeflash_wrap"

def is_assign_to_bound_arguments(assign_node):
    return (
        isinstance(assign_node, ast.Assign)
        and any(isinstance(t, ast.Name) and t.id == "_call__bound__arguments" for t in assign_node.targets)
    )

def is_apply_defaults_expr(expr_node):
    return (
        isinstance(expr_node, ast.Expr)
        and isinstance(expr_node.value, ast.Call)
        and isinstance(expr_node.value.func, ast.Attribute)
        and expr_node.value.func.attr == "apply_defaults"
    )

# --- Test Cases ---

# 1. Basic Test Cases

def make_simple_call_stmt(func_name="foo", args=None, keywords=None, lineno=1, col_offset=0):
    """Create a simple ast.Expr node with a function call."""
    args = args or []
    keywords = keywords or []
    call = ast.Call(func=ast.Name(id=func_name, ctx=ast.Load()), args=args, keywords=keywords)
    call.lineno = lineno
    call.col_offset = col_offset
    call.end_lineno = lineno
    call.end_col_offset = col_offset + 10
    expr = ast.Expr(value=call, lineno=lineno, col_offset=col_offset)
    return expr

def make_method_call_stmt(obj_name="obj", method_name="foo", args=None, keywords=None, lineno=1, col_offset=0):
    """Create a simple ast.Expr node with a method call."""
    args = args or []
    keywords = keywords or []
    attr = ast.Attribute(value=ast.Name(id=obj_name, ctx=ast.Load()), attr=method_name, ctx=ast.Load())
    call = ast.Call(func=attr, args=args, keywords=keywords)
    call.lineno = lineno
    call.col_offset = col_offset
    call.end_lineno = lineno
    call.end_col_offset = col_offset + 10
    expr = ast.Expr(value=call, lineno=lineno, col_offset=col_offset)
    return expr

@pytest.fixture
def default_function_obj():
    return FunctionToOptimize(function_name="foo", qualified_name="foo", is_async=False)

@pytest.fixture
def default_method_function_obj():
    return FunctionToOptimize(function_name="foo", qualified_name="Bar.foo", is_async=False, parents=[ParentStub("ClassDef", "Bar")], top_level_parent_name="Bar")

@pytest.fixture
def call_position():
    return [CodePosition(line_no=1, col_no=0, end_col_offset=10)]

def test_basic_function_call_transformation(default_function_obj, call_position):
    """Test a basic function call is wrapped and signature binding is inserted."""
    node = make_simple_call_stmt()
    inj = InjectPerfOnly(default_function_obj, "module.py", "pytest", call_position, mode=TestingMode.BEHAVIOR)
    codeflash_output = inj.find_and_update_line_node(node, "test_foo", "0"); result = codeflash_output
    call_node = get_first_call_node(result[2])
    # Should have codeflash_loop_index and codeflash_cur/con in args
    arg_names = [a.id for a in call_node.args if isinstance(a, ast.Name)]

def test_basic_method_call_transformation(default_method_function_obj, call_position):
    """Test a method call is wrapped and signature binding is inserted."""
    node = make_method_call_stmt(obj_name="self", method_name="foo")
    inj = InjectPerfOnly(default_method_function_obj, "module.py", "pytest", call_position, mode=TestingMode.BEHAVIOR)
    codeflash_output = inj.find_and_update_line_node(node, "test_bar", "1", test_class_name="Bar"); result = codeflash_output
    call_node = get_first_call_node(result[2])


def test_no_call_node_returns_none(default_function_obj, call_position):
    """Test that if there is no call node, None is returned."""
    node = ast.Pass(lineno=1, col_offset=0)
    inj = InjectPerfOnly(default_function_obj, "module.py", "pytest", call_position)
    codeflash_output = inj.find_and_update_line_node(node, "test_none", "0"); result = codeflash_output # 4.76μs -> 4.82μs (1.22% slower)

def test_call_node_not_in_position_returns_none(default_function_obj):
    """Test that if call node's position does not match, None is returned."""
    node = make_simple_call_stmt()
    # call_position is on a different line
    call_position = [CodePosition(line_no=2, col_no=0, end_col_offset=10)]
    inj = InjectPerfOnly(default_function_obj, "module.py", "pytest", call_position)
    codeflash_output = inj.find_and_update_line_node(node, "test_none", "0"); result = codeflash_output # 8.84μs -> 8.18μs (8.09% faster)

def test_async_function_returns_original(default_function_obj, call_position):
    """Test that if the function is async, the node is returned unchanged."""
    async_func = FunctionToOptimize(function_name="foo", qualified_name="foo", is_async=True)
    node = make_simple_call_stmt()
    inj = InjectPerfOnly(async_func, "module.py", "pytest", call_position)
    codeflash_output = inj.find_and_update_line_node(node, "test_async", "0"); result = codeflash_output # 9.82μs -> 7.63μs (28.6% faster)

def test_multiple_calls_only_first_transformed(default_function_obj, call_position):
    """Test that only the first matching call is transformed."""
    # Create a node with two calls: foo() and bar()
    call1 = ast.Call(func=ast.Name(id="foo", ctx=ast.Load()), args=[], keywords=[])
    call1.lineno = 1
    call1.col_offset = 0
    call1.end_lineno = 1
    call1.end_col_offset = 10
    call2 = ast.Call(func=ast.Name(id="bar", ctx=ast.Load()), args=[], keywords=[])
    call2.lineno = 1
    call2.col_offset = 11
    call2.end_lineno = 1
    call2.end_col_offset = 20
    expr = ast.Expr(
        value=ast.Tuple(elts=[call1, call2], ctx=ast.Load()), lineno=1, col_offset=0
    )
    inj = InjectPerfOnly(default_function_obj, "module.py", "pytest", call_position)
    codeflash_output = inj.find_and_update_line_node(expr, "test_multi", "0"); result = codeflash_output # 24.8μs -> 21.8μs (13.5% faster)
    # Only the first call node (foo) should be wrapped
    call_node = get_first_call_node(result[-1])
    # The second call (bar) should remain unchanged
    tuple_elts = result[-1].value.elts

def test_call_with_args_and_keywords(default_function_obj, call_position):
    """Test that calls with arguments and keywords are handled."""
    args = [ast.Constant(value=1), ast.Constant(value=2)]
    keywords = [ast.keyword(arg="x", value=ast.Constant(value=3))]
    node = make_simple_call_stmt(args=args, keywords=keywords)
    inj = InjectPerfOnly(default_function_obj, "module.py", "pytest", call_position)
    codeflash_output = inj.find_and_update_line_node(node, "test_args", "0"); result = codeflash_output # 22.9μs -> 19.1μs (19.9% faster)
    call_node = get_first_call_node(result[-1])

def test_method_call_with_different_method(default_method_function_obj, call_position):
    """Test that method calls with a different method name are not wrapped."""
    node = make_method_call_stmt(obj_name="self", method_name="notfoo")
    inj = InjectPerfOnly(default_method_function_obj, "module.py", "pytest", call_position)
    codeflash_output = inj.find_and_update_line_node(node, "test_bar", "1", test_class_name="Bar"); result = codeflash_output # 12.2μs -> 11.5μs (5.90% faster)

# 3. Large Scale Test Cases


def test_large_ast_tree_with_nested_calls(default_function_obj):
    """Test a large AST tree with deeply nested calls."""
    # Create a nested call: foo(bar(baz(1)))
    call_baz = ast.Call(func=ast.Name(id="baz", ctx=ast.Load()), args=[ast.Constant(value=1)], keywords=[])
    call_baz.lineno = 1
    call_baz.col_offset = 0
    call_baz.end_lineno = 1
    call_baz.end_col_offset = 5
    call_bar = ast.Call(func=ast.Name(id="bar", ctx=ast.Load()), args=[call_baz], keywords=[])
    call_bar.lineno = 1
    call_bar.col_offset = 6
    call_bar.end_lineno = 1
    call_bar.end_col_offset = 11
    call_foo = ast.Call(func=ast.Name(id="foo", ctx=ast.Load()), args=[call_bar], keywords=[])
    call_foo.lineno = 1
    call_foo.col_offset = 12
    call_foo.end_lineno = 1
    call_foo.end_col_offset = 17
    expr = ast.Expr(value=call_foo, lineno=1, col_offset=0)
    # Only the outermost call matches the position
    call_position = [CodePosition(line_no=1, col_no=12, end_col_offset=17)]
    inj = InjectPerfOnly(default_function_obj, "module.py", "pytest", call_position)
    codeflash_output = inj.find_and_update_line_node(expr, "test_nested", "0"); result = codeflash_output # 23.3μs -> 19.6μs (18.9% faster)
    call_node = get_first_call_node(result[-1])
    # The argument to foo should be bar(baz(1)), not wrapped
    arg = call_node.args[-1]
    if isinstance(arg, ast.Starred):
        # In BEHAVIOR mode, the argument is a Starred node
        pass
    else:
        pass

def test_large_call_position_list(default_function_obj):
    """Test with a large call_positions list, only matching positions wrap calls."""
    node = make_simple_call_stmt(lineno=50)
    # 999 positions that don't match, and 1 that does
    positions = [CodePosition(line_no=i, col_no=0, end_col_offset=10) for i in range(1, 1000)]
    positions.append(CodePosition(line_no=50, col_no=0, end_col_offset=10))
    inj = InjectPerfOnly(default_function_obj, "module.py", "pytest", positions)
    codeflash_output = inj.find_and_update_line_node(node, "test_largepos", "0"); result = codeflash_output # 28.4μs -> 22.5μs (26.2% faster)
    call_node = get_first_call_node(result[-1])
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import ast
from types import SimpleNamespace

# imports
import pytest
from codeflash.code_utils.instrument_existing_tests import InjectPerfOnly


# function to test
def find_and_update_line_node(
    test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
) -> list[ast.stmt] | None:
    """
    Given an AST statement node, finds the first ast.Call node within it that matches a call position,
    and rewrites it (and possibly adds argument binding statements) as described in the original code.
    Returns a list of AST statement nodes (possibly including the original test_node and new nodes),
    or None if no matching call is found.
    """
    # For testability, we'll simulate a minimal context for the function:
    # We'll use a global variable for call_positions, function_object, module_path, mode, etc.
    # (In production, these are attributes of the class, but for testing we inject them globally.)
    global _find_and_update_ctx
    ctx = _find_and_update_ctx

    def get_call_arguments(call_node: ast.Call):
        return SimpleNamespace(args=call_node.args, keywords=call_node.keywords)

    def node_in_call_position(node: ast.AST, call_positions: list):
        if isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset"):
            for pos in call_positions:
                if (
                    pos.line_no is not None
                    and getattr(node, "end_lineno", None) is not None
                    and node.lineno <= pos.line_no <= node.end_lineno
                ):
                    if pos.line_no == node.lineno and node.col_offset <= pos.col_no:
                        return True
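                    if (
                        pos.line_no == node.end_lineno
                        # Compare against the node's end column; checking
                        # pos.end_col_offset against pos.col_no is always true.
                        and getattr(node, "end_col_offset", None) is not None
                        and node.end_col_offset >= pos.col_no
                    ):
                        return True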
                    if node.lineno < pos.line_no < node.end_lineno:
                        return True
        return False

    return_statement = [test_node]
    call_node = None
    for node in ast.walk(test_node):
        if isinstance(node, ast.Call) and node_in_call_position(node, ctx.call_positions):
            call_node = node
            all_args = get_call_arguments(call_node)
            if isinstance(node.func, ast.Name):
                function_name = node.func.id

                if getattr(ctx.function_object, "is_async", False):
                    return [test_node]

                # Create the signature binding statements
                bind_call = ast.Assign(
                    targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
                    value=ast.Call(
                        func=ast.Attribute(
                            value=ast.Call(
                                func=ast.Attribute(
                                    value=ast.Name(id="inspect", ctx=ast.Load()), attr="signature", ctx=ast.Load()
                                ),
                                args=[ast.Name(id=function_name, ctx=ast.Load())],
                                keywords=[],
                            ),
                            attr="bind",
                            ctx=ast.Load(),
                        ),
                        args=all_args.args,
                        keywords=all_args.keywords,
                    ),
                    lineno=test_node.lineno,
                    col_offset=test_node.col_offset,
                )

                apply_defaults = ast.Expr(
                    value=ast.Call(
                        func=ast.Attribute(
                            value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
                            attr="apply_defaults",
                            ctx=ast.Load(),
                        ),
                        args=[],
                        keywords=[],
                    ),
                    lineno=test_node.lineno + 1,
                    col_offset=test_node.col_offset,
                )

                node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
                node.args = [
                    ast.Name(id=function_name, ctx=ast.Load()),
                    ast.Constant(value=ctx.module_path),
                    ast.Constant(value=test_class_name or None),
                    ast.Constant(value=node_name),
                    ast.Constant(value=getattr(ctx.function_object, "qualified_name", "")),
                    ast.Constant(value=index),
                    ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
                    *(
                        [ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())]
                        if ctx.mode == "BEHAVIOR"
                        else []
                    ),
                    *(
                        call_node.args
                        if ctx.mode == "PERFORMANCE"
                        else [
                            ast.Starred(
                                value=ast.Attribute(
                                    value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
                                    attr="args",
                                    ctx=ast.Load(),
                                ),
                                ctx=ast.Load(),
                            )
                        ]
                    ),
                ]
                node.keywords = (
                    [
                        ast.keyword(
                            value=ast.Attribute(
                                value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
                                attr="kwargs",
                                ctx=ast.Load(),
                            )
                        )
                    ]
                    if ctx.mode == "BEHAVIOR"
                    else call_node.keywords
                )

                return_statement = (
                    [bind_call, apply_defaults, test_node] if ctx.mode == "BEHAVIOR" else [test_node]
                )
                break
            if isinstance(node.func, ast.Attribute):
                function_to_test = node.func.attr
                if function_to_test == getattr(ctx.function_object, "function_name", ""):
                    if getattr(ctx.function_object, "is_async", False):
                        return [test_node]

                    function_name = ast.unparse(node.func)

                    bind_call = ast.Assign(
                        targets=[ast.Name(id="_call__bound__arguments", ctx=ast.Store())],
                        value=ast.Call(
                            func=ast.Attribute(
                                value=ast.Call(
                                    func=ast.Attribute(
                                        value=ast.Name(id="inspect", ctx=ast.Load()),
                                        attr="signature",
                                        ctx=ast.Load(),
                                    ),
                                    args=[ast.parse(function_name, mode="eval").body],
                                    keywords=[],
                                ),
                                attr="bind",
                                ctx=ast.Load(),
                            ),
                            args=all_args.args,
                            keywords=all_args.keywords,
                        ),
                        lineno=test_node.lineno,
                        col_offset=test_node.col_offset,
                    )

                    apply_defaults = ast.Expr(
                        value=ast.Call(
                            func=ast.Attribute(
                                value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
                                attr="apply_defaults",
                                ctx=ast.Load(),
                            ),
                            args=[],
                            keywords=[],
                        ),
                        lineno=test_node.lineno + 1,
                        col_offset=test_node.col_offset,
                    )

                    node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
                    node.args = [
                        ast.parse(function_name, mode="eval").body,
                        ast.Constant(value=ctx.module_path),
                        ast.Constant(value=test_class_name or None),
                        ast.Constant(value=node_name),
                        ast.Constant(value=getattr(ctx.function_object, "qualified_name", "")),
                        ast.Constant(value=index),
                        ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
                        *(
                            [
                                ast.Name(id="codeflash_cur", ctx=ast.Load()),
                                ast.Name(id="codeflash_con", ctx=ast.Load()),
                            ]
                            if ctx.mode == "BEHAVIOR"
                            else []
                        ),
                        *(
                            call_node.args
                            if ctx.mode == "PERFORMANCE"
                            else [
                                ast.Starred(
                                    value=ast.Attribute(
                                        value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
                                        attr="args",
                                        ctx=ast.Load(),
                                    ),
                                    ctx=ast.Load(),
                                )
                            ]
                        ),
                    ]
                    node.keywords = (
                        [
                            ast.keyword(
                                value=ast.Attribute(
                                    value=ast.Name(id="_call__bound__arguments", ctx=ast.Load()),
                                    attr="kwargs",
                                    ctx=ast.Load(),
                                )
                            )
                        ]
                        if ctx.mode == "BEHAVIOR"
                        else call_node.keywords
                    )

                    return_statement = (
                        [bind_call, apply_defaults, test_node] if ctx.mode == "BEHAVIOR" else [test_node]
                    )
                    break

    if call_node is None:
        return None
    return return_statement

# Helper classes for simulating context and positions
class DummyFunctionObj:
    def __init__(self, function_name="foo", qualified_name="foo", is_async=False):
        self.function_name = function_name
        self.qualified_name = qualified_name
        self.is_async = is_async
        self.parents = []
        self.top_level_parent_name = None

class DummyCodePosition:
    def __init__(self, line_no, col_no, end_col_offset=None):
        self.line_no = line_no
        self.col_no = col_no
        self.end_col_offset = end_col_offset

# 1. Basic Test Cases

To edit these changes, run git checkout codeflash/optimize-pr867-2025-11-01T00.02.02 and push.


@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 1, 2025
@misrasaurabh1 misrasaurabh1 merged commit 40e82e2 into inspect-signature-issue Nov 1, 2025
14 of 23 checks passed
@codeflash-ai codeflash-ai bot deleted the codeflash/optimize-pr867-2025-11-01T00.02.02 branch November 1, 2025 00:18
