Skip to content

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Sep 27, 2025

⚡️ This pull request contains optimizations for PR #769

If you approve this dependent PR, these changes will be merged into the original PR branch clean-async-branch.

This PR will be automatically closed if the original PR is merged.


📄 12% (0.12x) speedup for AsyncCallInstrumenter._instrument_statement in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime : 17.5 milliseconds 15.7 milliseconds (best of 39 runs)

📝 Explanation and details

The optimized code achieves an 11% speedup by replacing the expensive ast.walk() with a custom stack-based traversal that supports early termination.

Key optimizations:

  1. Stack-based AST traversal with early exit: Instead of ast.walk() which must visit every node, the optimized version uses a manual stack that immediately returns True when finding a matching Await node, avoiding unnecessary traversal of remaining subtrees.

  2. Function name caching: Pre-stores self._function_name = function.function_name in __init__ to eliminate repeated attribute lookups in _is_target_call().

  3. Local variable optimization: Extracts func = call_node.func to reduce repeated attribute access.

Performance impact by test type:

  • Small/simple statements (basic tests): 27-106% faster due to reduced traversal overhead
  • Complex nested expressions: 14% improvement as early exit helps when matches are found
  • Large-scale scenarios: 6-22% improvement, with better gains when fewer matches occur (early termination is more effective)

The optimization is most effective when matches are found early in the AST traversal, as it can skip examining the remaining nodes entirely. Line profiling shows the stack-based approach reduces the expensive ast.walk() overhead from 29% to 21.5% of total time in _instrument_statement.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 2025 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 71.4%
🌀 Generated Regression Tests and Runtime
from __future__ import annotations

import ast
from typing import Any

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import \
    AsyncCallInstrumenter
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestingMode


class DummyFunctionToOptimize:
    def __init__(self, function_name: str, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name

class DummyCodePosition:
    def __init__(self, line_no, col_no, end_col_offset=None):
        self.line_no = line_no
        self.col_no = col_no
        self.end_col_offset = end_col_offset

class DummyTestingMode:
    BEHAVIOR = "BEHAVIOR"

# Helper to parse code and get ast.stmt
def get_first_stmt_from_code(code: str) -> ast.stmt:
    tree = ast.parse(code)
    # Return the first statement in the module
    return tree.body[0]

# Helper to set lineno/col_offset on AST nodes for tests
def set_ast_positions(node: ast.AST, lineno: int, col_offset: int, end_lineno: int = None, end_col_offset: int = None):
    node.lineno = lineno
    node.col_offset = col_offset
    node.end_lineno = end_lineno if end_lineno is not None else lineno
    node.end_col_offset = end_col_offset if end_col_offset is not None else col_offset

# Basic Test Cases

def test_basic_match_single_await_call():
    """
    Basic: Awaiting a call to the target async function at the specified position.
    Should return (stmt, True).
    """
    code = "await foo()"
    stmt = get_first_stmt_from_code(code)
    # Set positions for the call node
    call_node = stmt.value
    set_ast_positions(call_node, lineno=1, col_offset=6, end_lineno=1, end_col_offset=11)
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=11)
    # Setup instrumenter
    function = DummyFunctionToOptimize("foo")
    pos = [DummyCodePosition(line_no=1, col_no=6, end_col_offset=11)]
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo") # 9.48μs -> 5.00μs (89.6% faster)

def test_basic_no_match_different_function():
    """
    Basic: Awaiting a call to a different async function (not the target).
    Should return (stmt, False).
    """
    code = "await bar()"
    stmt = get_first_stmt_from_code(code)
    call_node = stmt.value
    set_ast_positions(call_node, lineno=1, col_offset=6, end_lineno=1, end_col_offset=11)
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=11)
    function = DummyFunctionToOptimize("foo")
    pos = [DummyCodePosition(line_no=1, col_no=6, end_col_offset=11)]
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "bar") # 9.19μs -> 6.84μs (34.3% faster)

def test_basic_no_match_not_await():
    """
    Basic: Not an await statement, just a call.
    Should return (stmt, False).
    """
    code = "foo()"
    stmt = get_first_stmt_from_code(code)
    call_node = stmt.value
    set_ast_positions(call_node, lineno=1, col_offset=0, end_lineno=1, end_col_offset=5)
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=5)
    function = DummyFunctionToOptimize("foo")
    pos = [DummyCodePosition(line_no=1, col_no=0, end_col_offset=5)]
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo") # 7.83μs -> 5.08μs (54.0% faster)

def test_basic_match_attribute_call():
    """
    Basic: Awaiting a call to an attribute (obj.foo()) where foo is the target.
    Should return (stmt, True).
    """
    code = "await obj.foo()"
    stmt = get_first_stmt_from_code(code)
    call_node = stmt.value
    set_ast_positions(call_node, lineno=1, col_offset=6, end_lineno=1, end_col_offset=15)
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=15)
    function = DummyFunctionToOptimize("foo")
    pos = [DummyCodePosition(line_no=1, col_no=10, end_col_offset=15)]
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo") # 8.67μs -> 4.51μs (92.2% faster)

# Edge Test Cases

def test_edge_no_call_positions():
    """
    Edge: Awaiting target function, but call_positions is empty.
    Should return (stmt, False).
    """
    code = "await foo()"
    stmt = get_first_stmt_from_code(code)
    call_node = stmt.value
    set_ast_positions(call_node, lineno=1, col_offset=6, end_lineno=1, end_col_offset=11)
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=11)
    function = DummyFunctionToOptimize("foo")
    pos = []
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo") # 9.68μs -> 6.89μs (40.4% faster)

def test_edge_call_outside_position():
    """
    Edge: Awaiting target function, but call is outside specified position.
    Should return (stmt, False).
    """
    code = "await foo()"
    stmt = get_first_stmt_from_code(code)
    call_node = stmt.value
    set_ast_positions(call_node, lineno=1, col_offset=6, end_lineno=1, end_col_offset=11)
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=11)
    function = DummyFunctionToOptimize("foo")
    # Position is on a different line
    pos = [DummyCodePosition(line_no=2, col_no=6, end_col_offset=11)]
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo") # 10.3μs -> 7.27μs (41.3% faster)

def test_edge_multiple_calls_only_one_matches():
    """
    Edge: Statement with multiple awaits, only one matches the target and position.
    Should return (stmt, True).
    """
    code = """
await foo()
await bar()
"""
    tree = ast.parse(code)
    stmt1 = tree.body[0]
    stmt2 = tree.body[1]
    call_node1 = stmt1.value
    call_node2 = stmt2.value
    set_ast_positions(call_node1, lineno=2, col_offset=6, end_lineno=2, end_col_offset=11)
    set_ast_positions(stmt1, lineno=2, col_offset=0, end_lineno=2, end_col_offset=11)
    set_ast_positions(call_node2, lineno=3, col_offset=6, end_lineno=3, end_col_offset=11)
    set_ast_positions(stmt2, lineno=3, col_offset=0, end_lineno=3, end_col_offset=11)
    function = DummyFunctionToOptimize("foo")
    pos = [DummyCodePosition(line_no=2, col_no=6, end_col_offset=11)]
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    # Only stmt1 should match
    new_stmt, did_instrument = instrumenter._instrument_statement(stmt1, "foo") # 8.29μs -> 4.02μs (106% faster)
    # stmt2 should not match
    new_stmt2, did_instrument2 = instrumenter._instrument_statement(stmt2, "bar") # 6.26μs -> 4.90μs (27.8% faster)

def test_edge_missing_lineno_col_offset():
    """
    Edge: Awaiting target function, but call node lacks lineno/col_offset.
    Should return (stmt, False).
    """
    code = "await foo()"
    stmt = get_first_stmt_from_code(code)
    call_node = stmt.value
    # Intentionally do not set lineno/col_offset
    function = DummyFunctionToOptimize("foo")
    pos = [DummyCodePosition(line_no=1, col_no=6, end_col_offset=11)]
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo") # 8.13μs -> 4.09μs (99.0% faster)

def test_edge_nested_await_in_expression():
    """
    Edge: Awaiting target function inside a complex expression.
    Should return (stmt, True) if call position matches.
    """
    code = "result = await foo() + await bar()"
    stmt = get_first_stmt_from_code(code)
    # Find both await nodes inside the expression
    # result = BinOp(left=Await(Call(foo)), op=Add(), right=Await(Call(bar)))
    left_await = stmt.value.left
    right_await = stmt.value.right
    set_ast_positions(left_await.value, lineno=1, col_offset=15, end_lineno=1, end_col_offset=20)
    set_ast_positions(left_await, lineno=1, col_offset=8, end_lineno=1, end_col_offset=20)
    set_ast_positions(right_await.value, lineno=1, col_offset=23, end_lineno=1, end_col_offset=28)
    set_ast_positions(right_await, lineno=1, col_offset=21, end_lineno=1, end_col_offset=28)
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=28)
    function = DummyFunctionToOptimize("foo")
    pos = [DummyCodePosition(line_no=1, col_no=15, end_col_offset=20)]
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo") # 11.7μs -> 10.2μs (14.3% faster)

def test_edge_attribute_not_matching():
    """
    Edge: Awaiting obj.bar() but looking for foo.
    Should return (stmt, False).
    """
    code = "await obj.bar()"
    stmt = get_first_stmt_from_code(code)
    call_node = stmt.value
    set_ast_positions(call_node, lineno=1, col_offset=6, end_lineno=1, end_col_offset=15)
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=15)
    function = DummyFunctionToOptimize("foo")
    pos = [DummyCodePosition(line_no=1, col_no=10, end_col_offset=15)]
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo") # 10.4μs -> 8.18μs (27.6% faster)

# Large Scale Test Cases

def test_large_scale_many_statements_one_match():
    """
    Large scale: Many statements, only one matches.
    """
    code_lines = []
    for i in range(1, 501):
        if i == 250:
            code_lines.append("await foo()")
        else:
            code_lines.append("await bar()")
    code = "\n".join(code_lines)
    tree = ast.parse(code)
    stmts = tree.body
    # Set positions for all calls
    for idx, stmt in enumerate(stmts, start=1):
        call_node = stmt.value
        set_ast_positions(call_node, lineno=idx, col_offset=6, end_lineno=idx, end_col_offset=11)
        set_ast_positions(stmt, lineno=idx, col_offset=0, end_lineno=idx, end_col_offset=11)
    function = DummyFunctionToOptimize("foo")
    pos = [DummyCodePosition(line_no=250, col_no=6, end_col_offset=11)]
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    # Only statement 250 should match
    for idx, stmt in enumerate(stmts, start=1):
        new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo") # 2.08ms -> 1.70ms (22.2% faster)
        if idx == 250:
            pass
        else:
            pass

def test_large_scale_all_match():
    """
    Large scale: All statements match target.
    """
    code_lines = ["await foo()" for _ in range(500)]
    code = "\n".join(code_lines)
    tree = ast.parse(code)
    stmts = tree.body
    for idx, stmt in enumerate(stmts, start=1):
        call_node = stmt.value
        set_ast_positions(call_node, lineno=idx, col_offset=6, end_lineno=idx, end_col_offset=11)
        set_ast_positions(stmt, lineno=idx, col_offset=0, end_lineno=idx, end_col_offset=11)
    function = DummyFunctionToOptimize("foo")
    pos = [DummyCodePosition(line_no=i, col_no=6, end_col_offset=11) for i in range(1, 501)]
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    # All should match
    for idx, stmt in enumerate(stmts, start=1):
        new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo") # 11.1ms -> 10.5ms (6.49% faster)

def test_large_scale_no_matches():
    """
    Large scale: None of the statements match target.
    """
    code_lines = ["await bar()" for _ in range(500)]
    code = "\n".join(code_lines)
    tree = ast.parse(code)
    stmts = tree.body
    for idx, stmt in enumerate(stmts, start=1):
        call_node = stmt.value
        set_ast_positions(call_node, lineno=idx, col_offset=6, end_lineno=idx, end_col_offset=11)
        set_ast_positions(stmt, lineno=idx, col_offset=0, end_lineno=idx, end_col_offset=11)
    function = DummyFunctionToOptimize("foo")
    pos = [DummyCodePosition(line_no=i, col_no=6, end_col_offset=11) for i in range(1, 501)]
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    # None should match
    for idx, stmt in enumerate(stmts, start=1):
        new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo") # 2.08ms -> 1.71ms (22.3% faster)

def test_large_scale_complex_nested_awaits():
    """
    Large scale: Many statements with nested awaits, only some match.
    """
    code_lines = []
    for i in range(1, 501):
        if i % 100 == 0:
            code_lines.append("result = await foo() + await bar()")
        else:
            code_lines.append("await bar()")
    code = "\n".join(code_lines)
    tree = ast.parse(code)
    stmts = tree.body
    function = DummyFunctionToOptimize("foo")
    # Positions for foo calls at lines 100, 200, 300, 400, 500
    pos = [DummyCodePosition(line_no=i, col_no=15, end_col_offset=20) for i in range(100, 501, 100)]
    # Set positions
    for idx, stmt in enumerate(stmts, start=1):
        if idx % 100 == 0:
            # result = await foo() + await bar()
            left_await = stmt.value.left
            set_ast_positions(left_await.value, lineno=idx, col_offset=15, end_lineno=idx, end_col_offset=20)
            set_ast_positions(left_await, lineno=idx, col_offset=8, end_lineno=idx, end_col_offset=20)
            set_ast_positions(stmt, lineno=idx, col_offset=0, end_lineno=idx, end_col_offset=28)
        else:
            call_node = stmt.value
            set_ast_positions(call_node, lineno=idx, col_offset=6, end_lineno=idx, end_col_offset=11)
            set_ast_positions(stmt, lineno=idx, col_offset=0, end_lineno=idx, end_col_offset=11)
    instrumenter = AsyncCallInstrumenter(function, "", "", pos, DummyTestingMode.BEHAVIOR)
    # Only every 100th statement should match
    for idx, stmt in enumerate(stmts, start=1):
        if idx % 100 == 0:
            new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo")
        else:
            new_stmt, did_instrument = instrumenter._instrument_statement(stmt, "foo")
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from __future__ import annotations

import ast

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import \
    AsyncCallInstrumenter
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestingMode


# Helper functions/classes for tests
class DummyParent:
    def __init__(self, type_):
        self.type = type_
        self.top_level_parent_name = "DummyClass"

def make_function_to_optimize(name="foo", parents=None):
    if parents is None:
        parents = []
    return FunctionToOptimize(
        function_name=name,
        parents=parents,
        top_level_parent_name="DummyClass"
    )

def make_code_position(line_no, col_no, end_col_offset=None):
    return CodePosition(line_no=line_no, col_no=col_no, end_col_offset=end_col_offset)

def parse_stmt(source):
    # Parse a single statement, return the ast.stmt node
    mod = ast.parse(source)
    return mod.body[0]

def set_call_location(node, lineno, col_offset, end_lineno=None, end_col_offset=None):
    # Set location attributes for ast.Call node
    node.lineno = lineno
    node.col_offset = col_offset
    node.end_lineno = end_lineno if end_lineno is not None else lineno
    node.end_col_offset = end_col_offset if end_col_offset is not None else col_offset + 1

def set_await_location(await_node, lineno, col_offset, end_lineno=None, end_col_offset=None):
    await_node.lineno = lineno
    await_node.col_offset = col_offset
    await_node.end_lineno = end_lineno if end_lineno is not None else lineno
    await_node.end_col_offset = end_col_offset if end_col_offset is not None else col_offset + 1

# ------------------- UNIT TESTS -------------------

# ----------- 1. Basic Test Cases -----------

To edit these changes git checkout codeflash/optimize-pr769-2025-09-27T00.38.40 and push.

Codeflash

The optimized code achieves an 11% speedup by replacing the expensive `ast.walk()` with a custom stack-based traversal that supports **early termination**. 

**Key optimizations:**

1. **Stack-based AST traversal with early exit**: Instead of `ast.walk()` which must visit every node, the optimized version uses a manual stack that immediately returns `True` when finding a matching `Await` node, avoiding unnecessary traversal of remaining subtrees.

2. **Function name caching**: Pre-stores `self._function_name = function.function_name` in `__init__` to eliminate repeated attribute lookups in `_is_target_call()`.

3. **Local variable optimization**: Extracts `func = call_node.func` to reduce repeated attribute access.

**Performance impact by test type:**
- **Small/simple statements** (basic tests): 27-106% faster due to reduced traversal overhead
- **Complex nested expressions**: 14% improvement as early exit helps when matches are found
- **Large-scale scenarios**: 6-22% improvement, with better gains when fewer matches occur (early termination is more effective)

The optimization is most effective when matches are found early in the AST traversal, as it can skip examining the remaining nodes entirely. Line profiling shows the stack-based approach reduces the expensive `ast.walk()` overhead from 29% to 21.5% of total time in `_instrument_statement`.
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Sep 27, 2025
@codeflash-ai codeflash-ai bot closed this Sep 29, 2025
@codeflash-ai
Copy link
Contributor Author

codeflash-ai bot commented Sep 29, 2025

This PR has been automatically closed because the original PR #769 by KRRT7 was closed.

@codeflash-ai codeflash-ai bot deleted the codeflash/optimize-pr769-2025-09-27T00.38.40 branch September 29, 2025 21:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants