Conversation

@KRRT7
Contributor

@KRRT7 KRRT7 commented Aug 22, 2025

PR Type

Enhancement, Bug fix, Tests


Description

  • Add support for async function detection

  • Remove has_any_async_functions usage

  • Fix coverage utils empty database return

  • Update tests for async discovery and code validation
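As a rough illustration of the first bullet, the sketch below shows one way async definitions can be discovered and flagged with the standard `ast` module. `DiscoveredFunction` and `FunctionCollector` are hypothetical names used only for this example; they are not the PR's `FunctionToOptimize` or its visitor.

```python
from __future__ import annotations

import ast
from dataclasses import dataclass


@dataclass
class DiscoveredFunction:
    """Hypothetical stand-in for the PR's FunctionToOptimize."""

    name: str
    is_async: bool


class FunctionCollector(ast.NodeVisitor):
    """Collect sync and async function definitions from a module."""

    def __init__(self) -> None:
        self.found: list[DiscoveredFunction] = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self.found.append(DiscoveredFunction(node.name, is_async=False))
        self.generic_visit(node)  # keep descending into nested definitions

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
        # `async def` is parsed as a separate node type, so it needs its own hook.
        self.found.append(DiscoveredFunction(node.name, is_async=True))
        self.generic_visit(node)


def discover(source: str) -> list[DiscoveredFunction]:
    collector = FunctionCollector()
    collector.visit(ast.parse(source))
    return collector.found
```

Calling `discover()` on a module that contains both `def` and `async def` definitions returns one record per function, with `is_async` set only for the coroutine definitions.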


Diagram Walkthrough

flowchart LR
  A["Code parsing"] -- "detect `AsyncFunctionDef`" --> B["Function discovery"]
  B -- "flag `is_async`" --> C["FunctionToOptimize"]
  C -- "optimize without blocking on async" --> D["FunctionOptimizer"]
  D -- "generate coverage" --> E["CoverageUtils"]

File Walkthrough

Relevant files
Enhancement
4 files
code_utils.py
Remove unused async detection function                                     
+0/-8     
static_analysis.py
Add AsyncFunctionDef support to lookup functions                 
+9/-3     
functions_to_optimize.py
Enable async function discovery and flagging                         
+30/-2   
function_optimizer.py
Remove async restriction and support async AST                     
+1/-7     
Bug fix
1 file
coverage_utils.py
Fix empty DB check order and return `CoverageData`             
+3/-3     
Tests
3 files
test_async_function_discovery.py
Add comprehensive async discovery tests                                   
+286/-0 
test_code_context_extractor.py
Resolve symlink paths in code context tests                           
+9/-4     
test_code_utils.py
Replace async detection tests with code validation             
+34/-18 

@github-actions

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
🔒 No security concerns identified
⚡ Recommended focus areas for review

Coverage Empty DB

Ensure the fallback to create_empty uses the correct CoverageData.create_empty method and that CoverageData is properly imported. Confirm the exists/stat check avoids errors when the database file is missing.

if not database_path.exists() or not database_path.stat().st_size:
    logger.debug(f"Coverage database {database_path} is empty or does not exist")
    sentry_sdk.capture_message(f"Coverage database {database_path} is empty or does not exist")
    return CoverageData.create_empty(source_code_path, function_name, code_context)
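The ordering of the two conditions matters: `or` short-circuits, so `stat()` is only reached when the file exists. A minimal sketch of that guard in isolation, where `is_missing_or_empty` is a hypothetical helper name and not part of the PR:

```python
from pathlib import Path


def is_missing_or_empty(path: Path) -> bool:
    """Return True when the file is absent or has zero bytes.

    `exists()` is evaluated first; because `or` short-circuits, `stat()`
    is never called on a missing path, where it would raise
    FileNotFoundError.
    """
    return not path.exists() or not path.stat().st_size
```

Swapping the two operands would call `stat()` before the existence check and raise on a missing database file, which is the failure mode the reviewer guide asks to rule out.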
Async metadata consistency

The AsyncFunctionDef visitor appends functions without setting starting_line and ending_line, and uses a different parent list than the sync FunctionDef visitor. Ensure metadata consistency for async functions.

def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
    # Check if the async function has a return statement and add it to the list
    if function_has_return_statement(node) and not function_is_a_property(node):
        self.functions.append(
            FunctionToOptimize(
                function_name=node.name, file_path=self.file_path, parents=self.ast_path[:], is_async=True
            )
        )
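One way to keep sync and async metadata consistent is to route both node types through a single handler that reads the line span from the node itself. The sketch below assumes Python 3.8+ (for `end_lineno`); `FunctionRecord` and `SpanCollector` are hypothetical names for illustration, and the real `FunctionToOptimize` carries more fields.

```python
from __future__ import annotations

import ast
from dataclasses import dataclass


@dataclass
class FunctionRecord:
    """Hypothetical record; not the PR's FunctionToOptimize."""

    function_name: str
    starting_line: int
    ending_line: int | None
    is_async: bool


class SpanCollector(ast.NodeVisitor):
    def __init__(self) -> None:
        self.records: list[FunctionRecord] = []

    def _record(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> None:
        # Same code path for both node types, so the metadata cannot drift apart.
        self.records.append(
            FunctionRecord(
                function_name=node.name,
                starting_line=node.lineno,
                ending_line=node.end_lineno,
                is_async=isinstance(node, ast.AsyncFunctionDef),
            )
        )
        self.generic_visit(node)

    # NodeVisitor dispatches on the method name, so both hooks share one handler.
    visit_FunctionDef = _record
    visit_AsyncFunctionDef = _record
```

Sharing the handler is a design choice for this sketch only; the PR may instead set the fields separately in each visitor.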
AST return type

get_first_top_level_function_or_method_ast now returns both FunctionDef and AsyncFunctionDef. Verify downstream callers and type hints handle the async case correctly.

) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
    if not parents:
        result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
        if result is not None:
            return result
        return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, node)
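For comparison, the same top-level lookup can be written as a single pass that accepts either node type. This is a hypothetical simplification, not the PR's helper: `find_top_level_function` is an assumed name, and unlike the two-step lookup above it returns whichever matching definition appears first in the source rather than preferring the sync one.

```python
from __future__ import annotations

import ast


def find_top_level_function(
    name: str, module: ast.Module
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
    """Return the first top-level sync or async function called `name`."""
    for node in module.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == name:
            return node
    return None
```

Callers that previously assumed an `ast.FunctionDef` result need an `isinstance` check (or a widened type hint) before relying on sync-only behaviour.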

@github-actions

PR Code Suggestions ✨

No code suggestions found for the PR.

KRRT7 and others added 18 commits September 22, 2025 12:21
The optimization achieves a 25% speedup by **eliminating redundant AST node creation** inside the loop. 

**Key change:** The `timeout_decorator` AST node is now created once before the loop instead of being recreated for every test method that needs it. In the original code, this AST structure was built 3,411 times during profiling, consuming significant time in object allocation and initialization.

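**Why this works:** The `timeout_decorator` node is not modified after it is built, so the same instance can be safely appended to multiple method decorator lists. (Python AST nodes are mutable, so this sharing is only safe as long as nothing mutates the shared node afterwards.) This eliminates: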
- Repeated `ast.Call()` constructor calls
- Redundant `ast.Name()` and `ast.Constant()` object creation
- Multiple attribute assignments for the same decorator structure

**Performance characteristics:** The optimization is most effective for large test classes with many test methods (showing 24-33% improvements in tests with 500+ methods), while having minimal impact on classes with few or no test methods. This makes it particularly valuable for comprehensive test suites where classes commonly contain dozens of test methods.

The line profiler shows the AST node creation operations dropped from ~3,400 hits to just ~25 hits, directly correlating with the observed speedup.
…25-09-22T19.41.32

⚡️ Speed up method `AsyncCallInstrumenter.visit_ClassDef` by 26% in PR #739 (`get-throughput-from-output`)
add End to end test for async optimization
Get throughput from output  for async functions
@codeflash-ai
Contributor

codeflash-ai bot commented Sep 26, 2025

⚡️ Codeflash found optimizations for this PR

📄 11% (0.11x) speedup for CommentMapper.visit_FunctionDef in codeflash/code_utils/edit_generated_tests.py

⏱️ Runtime: 2.62 milliseconds → 2.36 milliseconds (best of 295 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch standalone-fto-async).
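The reported percentages appear consistent with a relative speedup computed as original runtime divided by optimized runtime, minus one. That formula is an assumption inferred from the numbers in these comments, not taken from Codeflash documentation; a quick check with a hypothetical helper:

```python
def speedup_percent(original: float, optimized: float) -> float:
    """Relative speedup in percent, assuming (original / optimized - 1) * 100."""
    return (original / optimized - 1.0) * 100.0


# Runtimes quoted in the comment above, in milliseconds.
reported = speedup_percent(2.62, 2.36)
```

Because the displayed runtimes are rounded, a value recomputed this way can differ from the reported percentage by about a point.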

@codeflash-ai
Contributor

codeflash-ai bot commented Sep 26, 2025

⚡️ Codeflash found optimizations for this PR

📄 11% (0.11x) speedup for CommentMapper.visit_AsyncFunctionDef in codeflash/code_utils/edit_generated_tests.py

⏱️ Runtime: 5.28 milliseconds → 4.76 milliseconds (best of 117 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch standalone-fto-async).

@codeflash-ai
Contributor

codeflash-ai bot commented Sep 26, 2025

⚡️ Codeflash found optimizations for this PR

📄 13% (0.13x) speedup for AsyncCallInstrumenter._process_test_function in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime: 2.35 milliseconds → 2.09 milliseconds (best of 15 runs)

A dependent PR with the suggested changes has been created. Please review:

If you approve, it will be merged into this PR (branch standalone-fto-async).

Comment on lines +400 to +407
for node in ast.walk(stmt):
    if (
        isinstance(node, ast.Await)
        and isinstance(node.value, ast.Call)
        and self._is_target_call(node.value)
        and self._call_in_positions(node.value)
    ):
        # Check if this call is in one of our target positions
Contributor

⚡️Codeflash found 13% (0.13x) speedup for AsyncCallInstrumenter._instrument_statement in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime: 7.34 milliseconds → 6.51 milliseconds (best of 169 runs)

📝 Explanation and details

The optimization replaces ast.walk() with a custom iterative traversal that specifically targets ast.Await nodes containing ast.Call nodes.

Key optimization: Instead of iterating over every AST node (9,373 loop iterations in the original), the optimized version uses a stack-based generator that yields only ast.Await nodes whose value is an ast.Call, reducing loop iterations to 2,740, roughly a 71% reduction.

How it works: The new _await_call_nodes() function uses an explicit stack to traverse the AST, only yielding nodes that match the pattern await some_call(). This eliminates the need to check isinstance(node, ast.Await) and isinstance(node.value, ast.Call) for every single node in the tree.

Performance impact: The line profiler shows the main loop time dropped from 40.8ms to 38.6ms (5% improvement), with overall function time improving from 59.2ms to 51.4ms (13% speedup). The optimization is particularly effective for test cases with:

  • Large ASTs with few await calls (22-46% faster on basic cases)
  • Multiple nested statements where most nodes aren't await calls
  • Complex expressions where await calls are deeply embedded

This targeted traversal approach is especially beneficial when the ratio of total AST nodes to await-call patterns is high, which is typical in real codebases.
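The traversal described above might look roughly like the sketch below. It is a reconstruction from the explanation, not the code from the dependent PR: the name `await_call_nodes` and the exact structure are assumptions, and the visiting order (depth-first) differs from the breadth-first order of `ast.walk`.

```python
import ast
from typing import Iterator


def await_call_nodes(root: ast.AST) -> Iterator[ast.Await]:
    """Yield every `await <call>(...)` node under `root` using an explicit stack."""
    stack = [root]
    while stack:
        node = stack.pop()
        # Only the `await some_call()` pattern is handed back to the caller.
        if isinstance(node, ast.Await) and isinstance(node.value, ast.Call):
            yield node
        # Children are still pushed, so nested awaits are found as well.
        stack.extend(ast.iter_child_nodes(node))
```

The caller's loop then runs once per matching await instead of once per node, and only the target and position checks remain in its body.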

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 1564 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 92.9% |
🌀 Generated Regression Tests and Runtime
from __future__ import annotations

import ast
from types import SimpleNamespace

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import \
    AsyncCallInstrumenter
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestingMode


# Minimal stubs for required classes; these shadow the names imported above
class FunctionToOptimize:
    def __init__(self, function_name, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name

class CodePosition:
    def __init__(self, line_no, col_no, end_col_offset=None):
        self.line_no = line_no
        self.col_no = col_no
        self.end_col_offset = end_col_offset

class TestingMode:
    BEHAVIOR = "behavior"
    COVERAGE = "coverage"

# ---- TEST FIXTURE HELPERS ----

def make_instrumenter(
    function_name="foo",
    call_positions=None,
    parents=None,
    top_level_parent_name=None,
    mode=TestingMode.BEHAVIOR,
):
    func = FunctionToOptimize(
        function_name=function_name,
        parents=parents,
        top_level_parent_name=top_level_parent_name,
    )
    return AsyncCallInstrumenter(
        function=func,
        module_path="dummy.py",
        test_framework="pytest",
        call_positions=call_positions or [],
        mode=mode,
    )

def parse_stmt(src):
    """Parse a single statement from code string, return ast.stmt node."""
    mod = ast.parse(src)
    # Return the first statement node
    return mod.body[0]

def set_ast_positions(node, lineno=1, col_offset=0, end_lineno=None, end_col_offset=None):
    """Set lineno/col_offset/end_lineno/end_col_offset recursively for all nodes."""
    for sub in ast.walk(node):
        if hasattr(sub, "lineno"):
            sub.lineno = lineno
        if hasattr(sub, "col_offset"):
            sub.col_offset = col_offset
        if hasattr(sub, "end_lineno") and end_lineno is not None:
            sub.end_lineno = end_lineno
        if hasattr(sub, "end_col_offset") and end_col_offset is not None:
            sub.end_col_offset = end_col_offset
    return node

# ---- BASIC TEST CASES ----

def test_no_await_returns_false():
    """No await in statement: should return (stmt, False)."""
    stmt = parse_stmt("x = 1")
    instr = make_instrumenter()
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_") # 8.68μs -> 6.75μs (28.5% faster)

def test_await_non_target_function_returns_false():
    """Await of a function that is not the target: should return False."""
    stmt = parse_stmt("await bar()")
    # Set positions so that node_in_call_position can work
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=10)
    pos = [CodePosition(line_no=1, col_no=7, end_col_offset=10)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_") # 6.71μs -> 5.79μs (15.9% faster)

def test_await_target_function_not_in_position_returns_false():
    """Await of target function, but not in call_positions: should return False."""
    stmt = parse_stmt("await foo()")
    set_ast_positions(stmt, lineno=3, col_offset=0, end_lineno=3, end_col_offset=10)
    pos = [CodePosition(line_no=1, col_no=7, end_col_offset=10)]  # position does not match
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_") # 7.74μs -> 6.76μs (14.4% faster)

def test_await_target_function_in_position_returns_true():
    """Await of target function, in call_positions: should return True."""
    stmt = parse_stmt("await foo()")
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=10)
    pos = [CodePosition(line_no=1, col_no=7, end_col_offset=10)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_") # 6.17μs -> 5.07μs (21.7% faster)

def test_await_target_method_in_position_returns_true():
    """Await of target function as method attribute, in call_positions: should return True."""
    stmt = parse_stmt("await obj.foo()")
    set_ast_positions(stmt, lineno=2, col_offset=0, end_lineno=2, end_col_offset=14)
    pos = [CodePosition(line_no=2, col_no=8, end_col_offset=14)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_") # 6.13μs -> 4.94μs (24.2% faster)

# ---- EDGE TEST CASES ----

def test_multiple_awaits_only_one_matches():
    """Multiple awaits, only one matches target and position."""
    src = """
await foo()
await bar()
await foo()
"""
    mod = ast.parse(src)
    # Set positions for each await
    set_ast_positions(mod.body[0], lineno=1, col_offset=0, end_lineno=1, end_col_offset=10)
    set_ast_positions(mod.body[1], lineno=2, col_offset=0, end_lineno=2, end_col_offset=10)
    set_ast_positions(mod.body[2], lineno=3, col_offset=0, end_lineno=3, end_col_offset=10)
    pos = [CodePosition(line_no=3, col_no=7, end_col_offset=10)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    # Should return True for the third statement only
    _, did_instrument_1 = instr._instrument_statement(mod.body[0], "_") # 7.06μs -> 6.77μs (4.28% faster)
    _, did_instrument_2 = instr._instrument_statement(mod.body[1], "_") # 4.94μs -> 4.12μs (20.0% faster)
    _, did_instrument_3 = instr._instrument_statement(mod.body[2], "_") # 4.48μs -> 3.25μs (38.0% faster)

def test_await_target_function_multiple_positions():
    """Await of target function, several call_positions, one matches."""
    stmt = parse_stmt("await foo()")
    set_ast_positions(stmt, lineno=5, col_offset=0, end_lineno=5, end_col_offset=10)
    pos = [
        CodePosition(line_no=4, col_no=7, end_col_offset=10),
        CodePosition(line_no=5, col_no=7, end_col_offset=10),
    ]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_") # 6.14μs -> 5.04μs (21.9% faster)

def test_await_target_function_missing_ast_positions():
    """Await of target function, but ast node missing lineno/col_offset: should return False."""
    stmt = parse_stmt("await foo()")
    # Remove lineno/col_offset from the ast.Call node
    for node in ast.walk(stmt):
        if isinstance(node, ast.Call):
            if hasattr(node, "lineno"):
                delattr(node, "lineno")
            if hasattr(node, "col_offset"):
                delattr(node, "col_offset")
    pos = [CodePosition(line_no=1, col_no=7, end_col_offset=10)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_") # 6.57μs -> 5.99μs (9.70% faster)

def test_await_target_function_nested_in_expr():
    """Await of target function nested inside an expression (should still match if position matches)."""
    stmt = parse_stmt("x = await foo() + 1")
    # Set positions for the foo() call
    for node in ast.walk(stmt):
        if isinstance(node, ast.Call):
            node.lineno = 1
            node.col_offset = 9
            node.end_lineno = 1
            node.end_col_offset = 15
    pos = [CodePosition(line_no=1, col_no=10, end_col_offset=15)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_") # 8.77μs -> 7.91μs (10.8% faster)

def test_await_target_function_attribute_chain():
    """Await of target function as deeply nested attribute (should match if attr name matches)."""
    stmt = parse_stmt("await obj.sub.foo()")
    for node in ast.walk(stmt):
        if isinstance(node, ast.Call):
            node.lineno = 1
            node.col_offset = 6
            node.end_lineno = 1
            node.end_col_offset = 18
    pos = [CodePosition(line_no=1, col_no=10, end_col_offset=18)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_") # 6.06μs -> 4.80μs (26.3% faster)

def test_await_target_function_col_offset_range():
    """Await of target function, position matches only if col_offset is within range."""
    stmt = parse_stmt("await foo()")
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=20)
    pos = [CodePosition(line_no=1, col_no=15, end_col_offset=20)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    # Should match because col_offset <= pos.col_no <= end_col_offset
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_") # 5.78μs -> 4.73μs (22.3% faster)

def test_await_target_function_line_range():
    """Await of target function spanning multiple lines, position falls inside."""
    src = "await foo(\n    1,\n    2\n)"
    stmt = parse_stmt(src)
    # Set the call node to span lines 1-4
    for node in ast.walk(stmt):
        if isinstance(node, ast.Call):
            node.lineno = 1
            node.col_offset = 6
            node.end_lineno = 4
            node.end_col_offset = 1
    pos = [CodePosition(line_no=3, col_no=4, end_col_offset=1)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_") # 6.02μs -> 4.91μs (22.6% faster)

# ---- LARGE SCALE TEST CASES ----

def test_large_number_of_call_positions():
    """Test with a large number of call_positions (up to 1000), only one matches."""
    stmt = parse_stmt("await foo()")
    set_ast_positions(stmt, lineno=100, col_offset=0, end_lineno=100, end_col_offset=10)
    # Positions for lines 1-999 (line 100 is already among them), plus an explicit entry for line 100
    pos = [CodePosition(line_no=i, col_no=0, end_col_offset=10) for i in range(1, 1000)]
    pos.append(CodePosition(line_no=100, col_no=0, end_col_offset=10))
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_") # 16.7μs -> 14.9μs (12.3% faster)

def test_large_ast_statement_with_one_matching_await():
    """Test a large ast statement (with many subnodes), only one await matches."""
    # Build a large statement with many assignments, only one has await foo()
    stmts = ["x{} = {}".format(i, i) for i in range(500)]
    stmts.insert(250, "await foo()")
    src = "\n".join(stmts)
    mod = ast.parse(src)
    # Set positions for the 251st statement (index 250)
    set_ast_positions(mod.body[250], lineno=251, col_offset=0, end_lineno=251, end_col_offset=10)
    pos = [CodePosition(line_no=251, col_no=7, end_col_offset=10)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(mod.body[250], "_") # 6.95μs -> 5.67μs (22.6% faster)
    # All others should be False
    for i in range(500):
        if i == 250:
            continue
        _, did_instrument_other = instr._instrument_statement(mod.body[i], "_") # 1.85ms -> 1.57ms (17.9% faster)

def test_large_number_of_awaits_only_last_matches():
    """Test 1000 awaits, only the last matches call_positions."""
    src = "\n".join(["await foo()" for _ in range(999)] + ["await foo()"])
    mod = ast.parse(src)
    # Set positions for each await
    for i, stmt in enumerate(mod.body):
        set_ast_positions(stmt, lineno=i+1, col_offset=0, end_lineno=i+1, end_col_offset=10)
    pos = [CodePosition(line_no=1000, col_no=7, end_col_offset=10)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    for i in range(999):
        _, did_instrument = instr._instrument_statement(mod.body[i], "_") # 4.63ms -> 3.97ms (16.6% faster)
    _, did_instrument_last = instr._instrument_statement(mod.body[999], "_") # 4.58μs -> 3.39μs (35.2% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from __future__ import annotations

import ast

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import \
    AsyncCallInstrumenter


# Minimal stubs for required classes
class CodePosition:
    def __init__(self, line_no, col_no, end_col_offset=None):
        self.line_no = line_no
        self.col_no = col_no
        self.end_col_offset = end_col_offset

class FunctionToOptimize:
    def __init__(self, function_name, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name

# unit tests

# Helper to create an Await node with a Call to a function, with position attributes
def make_await_call_stmt(func_name, lineno=1, col_offset=0, end_lineno=None, end_col_offset=None):
    call = ast.Call(
        func=ast.Name(id=func_name, ctx=ast.Load()),
        args=[],
        keywords=[],
    )
    call.lineno = lineno
    call.col_offset = col_offset
    call.end_lineno = end_lineno if end_lineno is not None else lineno
    call.end_col_offset = end_col_offset if end_col_offset is not None else col_offset + 10
    await_node = ast.Await(value=call)
    await_node.lineno = lineno
    await_node.col_offset = col_offset
    await_node.end_lineno = call.end_lineno
    await_node.end_col_offset = call.end_col_offset
    expr = ast.Expr(value=await_node)
    expr.lineno = lineno
    expr.col_offset = col_offset
    expr.end_lineno = call.end_lineno
    expr.end_col_offset = call.end_col_offset
    return expr

# Helper to create a non-await call to the function
def make_non_await_call_stmt(func_name, lineno=1, col_offset=0, end_lineno=None, end_col_offset=None):
    call = ast.Call(
        func=ast.Name(id=func_name, ctx=ast.Load()),
        args=[],
        keywords=[],
    )
    call.lineno = lineno
    call.col_offset = col_offset
    call.end_lineno = end_lineno if end_lineno is not None else lineno
    call.end_col_offset = end_col_offset if end_col_offset is not None else col_offset + 10
    expr = ast.Expr(value=call)
    expr.lineno = lineno
    expr.col_offset = col_offset
    expr.end_lineno = call.end_lineno
    expr.end_col_offset = call.end_col_offset
    return expr

# Helper to create an Await node with a Call to a different function
def make_await_call_stmt_other(func_name, other_func_name, lineno=1, col_offset=0, end_lineno=None, end_col_offset=None):
    call = ast.Call(
        func=ast.Name(id=other_func_name, ctx=ast.Load()),
        args=[],
        keywords=[],
    )
    call.lineno = lineno
    call.col_offset = col_offset
    call.end_lineno = end_lineno if end_lineno is not None else lineno
    call.end_col_offset = end_col_offset if end_col_offset is not None else col_offset + 10
    await_node = ast.Await(value=call)
    await_node.lineno = lineno
    await_node.col_offset = col_offset
    await_node.end_lineno = call.end_lineno
    await_node.end_col_offset = call.end_col_offset
    expr = ast.Expr(value=await_node)
    expr.lineno = lineno
    expr.col_offset = col_offset
    expr.end_lineno = call.end_lineno
    expr.end_col_offset = call.end_col_offset
    return expr

# 1. Basic Test Cases

def test_instrument_await_target_call_in_position():
    """Should instrument when awaiting the target function at the correct position."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=10, col_offset=5, end_lineno=10, end_col_offset=15)
    call_positions = [CodePosition(line_no=10, col_no=5, end_col_offset=15)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node") # 8.68μs -> 5.94μs (46.0% faster)

def test_no_instrument_when_not_await():
    """Should not instrument if the call is not awaited."""
    func_name = "foo"
    stmt = make_non_await_call_stmt(func_name, lineno=10, col_offset=5)
    call_positions = [CodePosition(line_no=10, col_no=5)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node") # 7.83μs -> 5.40μs (45.1% faster)

def test_no_instrument_when_call_to_other_function():
    """Should not instrument if the awaited call is to a different function."""
    func_name = "foo"
    stmt = make_await_call_stmt_other(func_name, "bar", lineno=10, col_offset=5)
    call_positions = [CodePosition(line_no=10, col_no=5)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node") # 8.85μs -> 6.37μs (38.8% faster)

def test_no_instrument_when_position_does_not_match():
    """Should not instrument if the awaited call is not at a target position."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=10, col_offset=5)
    call_positions = [CodePosition(line_no=11, col_no=5)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node") # 9.88μs -> 7.61μs (29.7% faster)

def test_instrument_with_attribute_call():
    """Should instrument if the awaited call is to an attribute with the correct function name."""
    func_name = "foo"
    # Build ast.Await(value=ast.Call(func=ast.Attribute(...)))
    call = ast.Call(
        func=ast.Attribute(
            value=ast.Name(id="obj", ctx=ast.Load()),
            attr=func_name,
            ctx=ast.Load()
        ),
        args=[],
        keywords=[]
    )
    call.lineno = 20
    call.col_offset = 2
    call.end_lineno = 20
    call.end_col_offset = 12
    await_node = ast.Await(value=call)
    await_node.lineno = 20
    await_node.col_offset = 2
    await_node.end_lineno = 20
    await_node.end_col_offset = 12
    expr = ast.Expr(value=await_node)
    expr.lineno = 20
    expr.col_offset = 2
    expr.end_lineno = 20
    expr.end_col_offset = 12

    call_positions = [CodePosition(line_no=20, col_no=2, end_col_offset=12)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(expr, "_node") # 8.01μs -> 5.48μs (46.0% faster)

# 2. Edge Test Cases

def test_instrument_multiple_awaits_only_one_matches():
    """Should instrument if only one of several awaits matches the target function and position."""
    func_name = "foo"
    stmt1 = make_await_call_stmt(func_name, lineno=5, col_offset=0)
    stmt2 = make_await_call_stmt_other(func_name, "bar", lineno=6, col_offset=0)
    stmt3 = make_await_call_stmt(func_name, lineno=10, col_offset=5)
    # Compose a block with multiple statements
    block = ast.If(
        test=ast.Constant(value=True),
        body=[stmt1, stmt2, stmt3],
        orelse=[]
    )
    block.lineno = 5
    block.col_offset = 0
    block.end_lineno = 10
    block.end_col_offset = 15

    call_positions = [CodePosition(line_no=10, col_no=5)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(block, "_node") # 14.2μs -> 7.05μs (102% faster)

def test_instrument_with_missing_lineno_col_offset():
    """Should not instrument if the call node is missing lineno/col_offset attributes."""
    func_name = "foo"
    call = ast.Call(
        func=ast.Name(id=func_name, ctx=ast.Load()),
        args=[],
        keywords=[]
    )
    # Do not set lineno/col_offset attributes
    await_node = ast.Await(value=call)
    expr = ast.Expr(value=await_node)
    call_positions = [CodePosition(line_no=1, col_no=0)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(expr, "_node") # 8.82μs -> 6.54μs (34.8% faster)

def test_instrument_with_nested_calls():
    """Should instrument if the awaited call is nested inside another statement (e.g., in a list)."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=15, col_offset=3)
    list_stmt = ast.List(elts=[stmt], ctx=ast.Load())
    expr = ast.Expr(value=list_stmt)
    expr.lineno = 15
    expr.col_offset = 3
    expr.end_lineno = 15
    expr.end_col_offset = 20
    call_positions = [CodePosition(line_no=15, col_no=3)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(expr, "_node") # 10.3μs -> 8.26μs (24.3% faster)

def test_instrument_with_multiple_call_positions():
    """Should instrument if the awaited call matches any of several call positions."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=30, col_offset=7)
    call_positions = [
        CodePosition(line_no=10, col_no=1),
        CodePosition(line_no=30, col_no=7),
        CodePosition(line_no=40, col_no=0)
    ]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node") # 7.88μs -> 5.06μs (55.7% faster)

def test_instrument_with_end_col_offset_variation():
    """Should instrument when call position matches at the end_col_offset."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=50, col_offset=2, end_col_offset=20)
    call_positions = [CodePosition(line_no=50, col_no=18, end_col_offset=20)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node") # 7.36μs -> 4.82μs (52.8% faster)

def test_instrument_with_no_call_positions():
    """Should not instrument if call_positions is empty."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=10, col_offset=5)
    call_positions = []
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node") # 9.21μs -> 6.98μs (31.8% faster)

def test_instrument_with_non_matching_attribute():
    """Should not instrument if attribute call does not match function name."""
    func_name = "foo"
    call = ast.Call(
        func=ast.Attribute(
            value=ast.Name(id="obj", ctx=ast.Load()),
            attr="bar",
            ctx=ast.Load()
        ),
        args=[],
        keywords=[]
    )
    call.lineno = 60
    call.col_offset = 4
    call.end_lineno = 60
    call.end_col_offset = 14
    await_node = ast.Await(value=call)
    await_node.lineno = 60
    await_node.col_offset = 4
    await_node.end_lineno = 60
    await_node.end_col_offset = 14
    expr = ast.Expr(value=await_node)
    expr.lineno = 60
    expr.col_offset = 4
    expr.end_lineno = 60
    expr.end_col_offset = 14

    call_positions = [CodePosition(line_no=60, col_no=4, end_col_offset=14)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(expr, "_node") # 10.0μs -> 8.54μs (17.3% faster)

# 3. Large Scale Test Cases

def test_large_number_of_call_positions():
    """Should instrument if the awaited call matches among many call positions."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=500, col_offset=8)
    # 499 positions that don't match, 1 that does
    call_positions = [CodePosition(line_no=i, col_no=0) for i in range(1, 500)]
    call_positions.append(CodePosition(line_no=500, col_no=8))
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node") # 49.1μs -> 46.5μs (5.56% faster)

def test_large_ast_with_one_matching_await():
    """Should instrument when only one awaited call in a large AST matches the target and position."""
    func_name = "foo"
    stmts = []
    for i in range(1, 100):
        # Most are to other functions
        stmts.append(make_await_call_stmt_other(func_name, f"bar{i}", lineno=i, col_offset=0))
    # Insert a matching awaited call at position 50
    matching_stmt = make_await_call_stmt(func_name, lineno=50, col_offset=4)
    stmts[49] = matching_stmt  # Replace the 50th statement
    block = ast.Module(body=stmts, type_ignores=[])
    block.lineno = 1
    block.col_offset = 0
    block.end_lineno = 99
    block.end_col_offset = 10
    call_positions = [CodePosition(line_no=50, col_no=4)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(block, "_node") # 101μs -> 166μs (38.9% slower)

def test_large_ast_with_no_matching_await():
    """Should not instrument when no awaited call in a large AST matches the target and position."""
    func_name = "foo"
    stmts = []
    for i in range(1, 100):
        stmts.append(make_await_call_stmt_other(func_name, f"bar{i}", lineno=i, col_offset=0))
    block = ast.Module(body=stmts, type_ignores=[])
    block.lineno = 1
    block.col_offset = 0
    block.end_lineno = 99
    block.end_col_offset = 10
    call_positions = [CodePosition(line_no=150, col_no=0)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(block, "_node") # 317μs -> 317μs (0.009% slower)

def test_large_ast_with_multiple_matching_awaits():
    """Should instrument when multiple awaited calls in a large AST match the target and positions."""
    func_name = "foo"
    stmts = []
    match_lines = [10, 20, 30, 40, 50]
    for i in range(1, 101):
        if i in match_lines:
            stmts.append(make_await_call_stmt(func_name, lineno=i, col_offset=2))
        else:
            stmts.append(make_await_call_stmt_other(func_name, f"bar{i}", lineno=i, col_offset=0))
    block = ast.Module(body=stmts, type_ignores=[])
    block.lineno = 1
    block.col_offset = 0
    block.end_lineno = 100
    block.end_col_offset = 10
    call_positions = [CodePosition(line_no=i, col_no=2) for i in match_lines]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(block, "_node") # 72.0μs -> 169μs (57.4% slower)

def test_large_call_positions_all_non_matching():
    """Should not instrument if there are many call positions but none match."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=1000, col_offset=0)
    call_positions = [CodePosition(line_no=i, col_no=1) for i in range(1, 1000)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node") # 93.3μs -> 90.8μs (2.68% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally, run `git merge codeflash/optimize-pr678-2025-09-26T20.19.54`

Suggested change
- for node in ast.walk(stmt):
-     if (
-         isinstance(node, ast.Await)
-         and isinstance(node.value, ast.Call)
-         and self._is_target_call(node.value)
-         and self._call_in_positions(node.value)
-     ):
-         # Check if this call is in one of our target positions
+ def _await_call_nodes(node):
+     stack = [node]
+     while stack:
+         cur = stack.pop()
+         if isinstance(cur, ast.Await) and isinstance(cur.value, ast.Call):
+             yield cur
+         for child in ast.iter_child_nodes(cur):
+             stack.append(child)
+ for await_node in _await_call_nodes(stmt):
+     call_node = await_node.value
+     if self._is_target_call(call_node) and self._call_in_positions(call_node):
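
The traversal in the suggestion can be tried on its own, outside `AsyncCallInstrumenter`. The sketch below is only an illustration under stated assumptions: `iter_await_calls` and `call_name` are hypothetical names that do not appear in the PR, and the `_is_target_call` / `_call_in_positions` checks are left out because their implementations are not shown here. It walks a parsed module with an explicit stack and yields only `ast.Await` nodes that wrap an `ast.Call`.

```python
import ast


def iter_await_calls(root: ast.AST):
    """Yield each ast.Await node whose value is an ast.Call (hypothetical helper)."""
    stack = [root]
    while stack:
        current = stack.pop()
        if isinstance(current, ast.Await) and isinstance(current.value, ast.Call):
            yield current
        # Keep descending: an awaited call may contain further awaits in its arguments.
        stack.extend(ast.iter_child_nodes(current))


def call_name(call: ast.Call):
    """Return the plain or attribute name of the called object, or None (hypothetical helper)."""
    func = call.func
    if isinstance(func, ast.Name):
        return func.id
    if isinstance(func, ast.Attribute):
        return func.attr
    return None


source = """
async def main():
    await foo(1)
    x = await obj.bar(await foo(2))
    baz()
"""
tree = ast.parse(source)

# Collect (line, column, name) for every awaited call; sort because the
# stack-based walk is depth-first and does not follow source order.
found = sorted(
    (node.value.lineno, node.value.col_offset, call_name(node.value))
    for node in iter_await_calls(tree)
)
```

Note that the stack-based walk visits nodes depth-first, whereas `ast.walk` is breadth-first, so the two yield the same set of nodes in a different order. Code that relies on the order in which matches are found would need to account for that.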

@KRRT7 KRRT7 force-pushed the standalone-fto-async branch from 40c4108 to 7bbb1e7 Compare September 26, 2025 20:26
@KRRT7
Contributor Author

KRRT7 commented Sep 27, 2025

closing in favor of #769 - cleaner commits

@KRRT7 KRRT7 closed this Sep 27, 2025
auto-merge was automatically disabled September 27, 2025 00:16

Pull request was closed


Labels

Review effort 4/5 workflow-modified This PR modifies GitHub Actions workflows

2 participants