check if a function is async and add to FTO #678
PR Code Suggestions ✨ No code suggestions found for the PR.
- apply testgen-async
- fix bug when iterating over star imports
- fix cst * import errors
- Update codeflash_wrap_decorator.py
- Update codeflash_wrap_decorator.py
The optimization achieves a 25% speedup by **eliminating redundant AST node creation** inside the loop.

**Key change:** The `timeout_decorator` AST node is now created once before the loop instead of being recreated for every test method that needs it. In the original code, this AST structure was built 3,411 times during profiling, consuming significant time in object allocation and initialization.

**Why this works:** The `timeout_decorator` node is never modified after it is created, so the same instance can be safely appended to multiple method decorator lists. (Python `ast` nodes are ordinary mutable objects, so this sharing is safe only as long as later passes leave the node unchanged.) This eliminates:
- Repeated `ast.Call()` constructor calls
- Redundant `ast.Name()` and `ast.Constant()` object creation
- Multiple attribute assignments for the same decorator structure

**Performance characteristics:** The optimization is most effective for large test classes with many test methods (showing 24-33% improvements in tests with 500+ methods), while having minimal impact on classes with few or no test methods. This makes it particularly valuable for comprehensive test suites where classes commonly contain dozens of test methods.

The line profiler shows the AST node creation operations dropped from ~3,400 hits to just ~25 hits, directly correlating with the observed speedup.
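The hoisting described above can be sketched as follows. This is a minimal illustration, not the code from `AsyncCallInstrumenter.visit_ClassDef`: the decorator name `timeout`, the value `15`, the helper `add_timeout_decorators`, and the sample class are all assumptions made for the example.

```python
import ast

# Hypothetical test class used only for this illustration.
SOURCE = """
class TestThing:
    def test_a(self): pass
    def test_b(self): pass
    def helper(self): pass
"""


def add_timeout_decorators(class_node: ast.ClassDef, seconds: int = 15) -> int:
    """Append one shared decorator node to every test_* method; return the count."""
    # Built once, before the loop (assumed shape: `timeout(<seconds>)`).
    timeout_decorator = ast.Call(
        func=ast.Name(id="timeout", ctx=ast.Load()),
        args=[ast.Constant(value=seconds)],
        keywords=[],
    )
    count = 0
    for item in class_node.body:
        if isinstance(item, ast.FunctionDef) and item.name.startswith("test_"):
            # The same instance is appended each time; nothing mutates it later.
            item.decorator_list.append(timeout_decorator)
            count += 1
    return count


tree = ast.parse(SOURCE)
decorated = add_timeout_decorators(tree.body[0])
ast.fix_missing_locations(tree)
print(decorated)
print(ast.unparse(tree))
```

Sharing one node keeps allocation constant regardless of how many methods are decorated. If a later pass needed to edit the decorator of a single method, a copy (for example via `copy.deepcopy`) would be required first.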
…25-09-22T19.41.32 ⚡️ Speed up method `AsyncCallInstrumenter.visit_ClassDef` by 26% in PR #739 (`get-throughput-from-output`)
add End to end test for async optimization
Get throughput from output for async functions
⚡️ Codeflash found optimizations for this PR 📄 11% (0.11x) speedup for
⚡️ Codeflash found optimizations for this PR 📄 11% (0.11x) speedup for
⚡️ Codeflash found optimizations for this PR 📄 13% (0.13x) speedup for
        for node in ast.walk(stmt):
            if (
                isinstance(node, ast.Await)
                and isinstance(node.value, ast.Call)
                and self._is_target_call(node.value)
                and self._call_in_positions(node.value)
            ):
                # Check if this call is in one of our target positions
⚡️ Codeflash found 13% (0.13x) speedup for `AsyncCallInstrumenter._instrument_statement` in `codeflash/code_utils/instrument_existing_tests.py`
⏱️ Runtime: 7.34 milliseconds → 6.51 milliseconds (best of 169 runs)
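The 13% figure appears consistent with a speedup computed as old runtime divided by new runtime, minus one. The sketch below checks that reading against the reported runtimes; the helper names are made up for this example, and the formula is an inference from the numbers rather than something the report states.

```python
def speedup_pct(old: float, new: float) -> float:
    """Percent by which the new runtime is faster: (old / new - 1) * 100."""
    return (old / new - 1.0) * 100.0


def time_reduction_pct(old: float, new: float) -> float:
    """Percent of the original runtime that was saved: (1 - new / old) * 100."""
    return (1.0 - new / old) * 100.0


# Reported runtimes in milliseconds.
old_ms, new_ms = 7.34, 6.51
print(round(speedup_pct(old_ms, new_ms)))         # matches the reported 13%
print(round(time_reduction_pct(old_ms, new_ms)))  # a smaller number: time saved
```

The two definitions give different percentages for the same pair of runtimes, so it helps to know which one a report uses before comparing numbers.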
📝 Explanation and details
The optimization replaces `ast.walk()` with a custom iterative traversal that specifically targets `ast.Await` nodes containing `ast.Call` nodes.
Key optimization: Instead of running the loop body for every AST node (9,373 iterations in the original), the optimized version uses a stack-based generator that yields only `ast.Await` nodes with `ast.Call` values, reducing loop iterations to 2,740, roughly a 71% reduction.
How it works: The new `_await_call_nodes()` function uses an explicit stack to traverse the AST and yields only nodes that match the pattern `await some_call()`. The traversal still visits every node, but the `isinstance(node, ast.Await)` and `isinstance(node.value, ast.Call)` checks now run inside the generator, so the main loop no longer evaluates them for every node in the tree.
Performance impact: The line profiler shows the main loop time dropped from 40.8 ms to 38.6 ms (5% improvement), with overall function time improving from 59.2 ms to 51.4 ms (13% speedup). The optimization is particularly effective for test cases with:
- Large ASTs with few await calls (22-46% faster on basic cases)
- Multiple nested statements where most nodes aren't await calls
- Complex expressions where await calls are deeply embedded
This targeted traversal is most beneficial when the ratio of total AST nodes to await-call patterns is high, which is typical in real codebases.
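A minimal, self-contained sketch of the traversal described above, checked against an `ast.walk()` filter. The free-function name `await_call_nodes` and the sample source are illustrative assumptions, not code taken from the PR.

```python
import ast
from typing import Iterator

# Hypothetical test body: three awaited calls and one awaited non-call.
SOURCE = """
async def test_it():
    x = 1 + 2
    result = await foo(await bar(x))
    await some_future
    return [await obj.foo(), x]
"""


def await_call_nodes(root: ast.AST) -> Iterator[ast.Await]:
    """Yield every `await <call>` node under root, using an explicit stack."""
    stack = [root]
    while stack:
        cur = stack.pop()
        if isinstance(cur, ast.Await) and isinstance(cur.value, ast.Call):
            yield cur
        # Every node is still expanded; only matching nodes reach the caller.
        stack.extend(ast.iter_child_nodes(cur))


tree = ast.parse(SOURCE)
all_nodes = list(ast.walk(tree))
via_walk = [
    n for n in all_nodes
    if isinstance(n, ast.Await) and isinstance(n.value, ast.Call)
]
via_stack = list(await_call_nodes(tree))

print(len(all_nodes), "nodes in total")
print(len(via_stack), "await-call nodes yielded")
```

Both approaches find the same set of nodes. The difference is that the caller's loop body runs once per yielded node rather than once per node in the tree.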
✅ Correctness verification report:
| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 1564 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 92.9% |
🌀 Generated Regression Tests and Runtime
from __future__ import annotations

import ast
from types import SimpleNamespace

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import \
    AsyncCallInstrumenter
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodePosition, TestingMode


# Minimal stand-ins; these definitions shadow the names imported above.
class FunctionToOptimize:
    def __init__(self, function_name, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name


class CodePosition:
    def __init__(self, line_no, col_no, end_col_offset=None):
        self.line_no = line_no
        self.col_no = col_no
        self.end_col_offset = end_col_offset


class TestingMode:
    BEHAVIOR = "behavior"
    COVERAGE = "coverage"
# ---- TEST FIXTURE HELPERS ----


def make_instrumenter(
    function_name="foo",
    call_positions=None,
    parents=None,
    top_level_parent_name=None,
    mode=TestingMode.BEHAVIOR,
):
    func = FunctionToOptimize(
        function_name=function_name,
        parents=parents,
        top_level_parent_name=top_level_parent_name,
    )
    return AsyncCallInstrumenter(
        function=func,
        module_path="dummy.py",
        test_framework="pytest",
        call_positions=call_positions or [],
        mode=mode,
    )


def parse_stmt(src):
    """Parse a single statement from code string, return ast.stmt node."""
    mod = ast.parse(src)
    # Return the first statement node
    return mod.body[0]


def set_ast_positions(node, lineno=1, col_offset=0, end_lineno=None, end_col_offset=None):
    """Set lineno/col_offset/end_lineno/end_col_offset recursively for all nodes."""
    for sub in ast.walk(node):
        if hasattr(sub, "lineno"):
            sub.lineno = lineno
        if hasattr(sub, "col_offset"):
            sub.col_offset = col_offset
        if hasattr(sub, "end_lineno") and end_lineno is not None:
            sub.end_lineno = end_lineno
        if hasattr(sub, "end_col_offset") and end_col_offset is not None:
            sub.end_col_offset = end_col_offset
    return node
# ---- BASIC TEST CASES ----


def test_no_await_returns_false():
    """No await in statement: should return (stmt, False)."""
    stmt = parse_stmt("x = 1")
    instr = make_instrumenter()
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_")  # 8.68μs -> 6.75μs (28.5% faster)


def test_await_non_target_function_returns_false():
    """Await of a function that is not the target: should return False."""
    stmt = parse_stmt("await bar()")
    # Set positions so that node_in_call_position can work
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=10)
    pos = [CodePosition(line_no=1, col_no=7, end_col_offset=10)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_")  # 6.71μs -> 5.79μs (15.9% faster)


def test_await_target_function_not_in_position_returns_false():
    """Await of target function, but not in call_positions: should return False."""
    stmt = parse_stmt("await foo()")
    set_ast_positions(stmt, lineno=3, col_offset=0, end_lineno=3, end_col_offset=10)
    pos = [CodePosition(line_no=1, col_no=7, end_col_offset=10)]  # position does not match
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_")  # 7.74μs -> 6.76μs (14.4% faster)


def test_await_target_function_in_position_returns_true():
    """Await of target function, in call_positions: should return True."""
    stmt = parse_stmt("await foo()")
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=10)
    pos = [CodePosition(line_no=1, col_no=7, end_col_offset=10)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_")  # 6.17μs -> 5.07μs (21.7% faster)


def test_await_target_method_in_position_returns_true():
    """Await of target function as method attribute, in call_positions: should return True."""
    stmt = parse_stmt("await obj.foo()")
    set_ast_positions(stmt, lineno=2, col_offset=0, end_lineno=2, end_col_offset=14)
    pos = [CodePosition(line_no=2, col_no=8, end_col_offset=14)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_")  # 6.13μs -> 4.94μs (24.2% faster)
# ---- EDGE TEST CASES ----


def test_multiple_awaits_only_one_matches():
    """Multiple awaits, only one matches target and position."""
    src = """
await foo()
await bar()
await foo()
"""
    mod = ast.parse(src)
    # Set positions for each await
    set_ast_positions(mod.body[0], lineno=1, col_offset=0, end_lineno=1, end_col_offset=10)
    set_ast_positions(mod.body[1], lineno=2, col_offset=0, end_lineno=2, end_col_offset=10)
    set_ast_positions(mod.body[2], lineno=3, col_offset=0, end_lineno=3, end_col_offset=10)
    pos = [CodePosition(line_no=3, col_no=7, end_col_offset=10)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    # Should return True for the third statement only
    _, did_instrument_1 = instr._instrument_statement(mod.body[0], "_")  # 7.06μs -> 6.77μs (4.28% faster)
    _, did_instrument_2 = instr._instrument_statement(mod.body[1], "_")  # 4.94μs -> 4.12μs (20.0% faster)
    _, did_instrument_3 = instr._instrument_statement(mod.body[2], "_")  # 4.48μs -> 3.25μs (38.0% faster)


def test_await_target_function_multiple_positions():
    """Await of target function, several call_positions, one matches."""
    stmt = parse_stmt("await foo()")
    set_ast_positions(stmt, lineno=5, col_offset=0, end_lineno=5, end_col_offset=10)
    pos = [
        CodePosition(line_no=4, col_no=7, end_col_offset=10),
        CodePosition(line_no=5, col_no=7, end_col_offset=10),
    ]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_")  # 6.14μs -> 5.04μs (21.9% faster)


def test_await_target_function_missing_ast_positions():
    """Await of target function, but ast node missing lineno/col_offset: should return False."""
    stmt = parse_stmt("await foo()")
    # Remove lineno/col_offset from the ast.Call node
    for node in ast.walk(stmt):
        if isinstance(node, ast.Call):
            if hasattr(node, "lineno"):
                delattr(node, "lineno")
            if hasattr(node, "col_offset"):
                delattr(node, "col_offset")
    pos = [CodePosition(line_no=1, col_no=7, end_col_offset=10)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_")  # 6.57μs -> 5.99μs (9.70% faster)


def test_await_target_function_nested_in_expr():
    """Await of target function nested inside an expression (should still match if position matches)."""
    stmt = parse_stmt("x = await foo() + 1")
    # Set positions for the foo() call
    for node in ast.walk(stmt):
        if isinstance(node, ast.Call):
            node.lineno = 1
            node.col_offset = 9
            node.end_lineno = 1
            node.end_col_offset = 15
    pos = [CodePosition(line_no=1, col_no=10, end_col_offset=15)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_")  # 8.77μs -> 7.91μs (10.8% faster)


def test_await_target_function_attribute_chain():
    """Await of target function as deeply nested attribute (should match if attr name matches)."""
    stmt = parse_stmt("await obj.sub.foo()")
    for node in ast.walk(stmt):
        if isinstance(node, ast.Call):
            node.lineno = 1
            node.col_offset = 6
            node.end_lineno = 1
            node.end_col_offset = 18
    pos = [CodePosition(line_no=1, col_no=10, end_col_offset=18)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_")  # 6.06μs -> 4.80μs (26.3% faster)


def test_await_target_function_col_offset_range():
    """Await of target function, position matches only if col_offset is within range."""
    stmt = parse_stmt("await foo()")
    set_ast_positions(stmt, lineno=1, col_offset=0, end_lineno=1, end_col_offset=20)
    pos = [CodePosition(line_no=1, col_no=15, end_col_offset=20)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    # Should match because col_offset <= pos.col_no <= end_col_offset
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_")  # 5.78μs -> 4.73μs (22.3% faster)


def test_await_target_function_line_range():
    """Await of target function spanning multiple lines, position falls inside."""
    src = "await foo(\n 1,\n 2\n)"
    stmt = parse_stmt(src)
    # Set the call node to span lines 1-4
    for node in ast.walk(stmt):
        if isinstance(node, ast.Call):
            node.lineno = 1
            node.col_offset = 6
            node.end_lineno = 4
            node.end_col_offset = 1
    pos = [CodePosition(line_no=3, col_no=4, end_col_offset=1)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_")  # 6.02μs -> 4.91μs (22.6% faster)
# ---- LARGE SCALE TEST CASES ----


def test_large_number_of_call_positions():
    """Test with a large number of call_positions (up to 1000), only one matches."""
    stmt = parse_stmt("await foo()")
    set_ast_positions(stmt, lineno=100, col_offset=0, end_lineno=100, end_col_offset=10)
    # 999 positions (lines 1-999), plus one explicit entry for line 100
    pos = [CodePosition(line_no=i, col_no=0, end_col_offset=10) for i in range(1, 1000)]
    pos.append(CodePosition(line_no=100, col_no=0, end_col_offset=10))
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_")  # 16.7μs -> 14.9μs (12.3% faster)


def test_large_ast_statement_with_one_matching_await():
    """Test a large ast statement (with many subnodes), only one await matches."""
    # Build a large statement with many assignments, only one has await foo()
    stmts = ["x{} = {}".format(i, i) for i in range(500)]
    stmts.insert(250, "await foo()")
    src = "\n".join(stmts)
    mod = ast.parse(src)
    # Set positions for the 251st statement (index 250)
    set_ast_positions(mod.body[250], lineno=251, col_offset=0, end_lineno=251, end_col_offset=10)
    pos = [CodePosition(line_no=251, col_no=7, end_col_offset=10)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    new_stmt, did_instrument = instr._instrument_statement(mod.body[250], "_")  # 6.95μs -> 5.67μs (22.6% faster)
    # All others should be False
    for i in range(500):
        if i == 250:
            continue
        _, did_instrument_other = instr._instrument_statement(mod.body[i], "_")  # 1.85ms -> 1.57ms (17.9% faster)


def test_large_number_of_awaits_only_last_matches():
    """Test 1000 awaits, only the last matches call_positions."""
    src = "\n".join(["await foo()" for _ in range(999)] + ["await foo()"])
    mod = ast.parse(src)
    # Set positions for each await
    for i, stmt in enumerate(mod.body):
        set_ast_positions(stmt, lineno=i+1, col_offset=0, end_lineno=i+1, end_col_offset=10)
    pos = [CodePosition(line_no=1000, col_no=7, end_col_offset=10)]
    instr = make_instrumenter(function_name="foo", call_positions=pos)
    for i in range(999):
        _, did_instrument = instr._instrument_statement(mod.body[i], "_")  # 4.63ms -> 3.97ms (16.6% faster)
    _, did_instrument_last = instr._instrument_statement(mod.body[999], "_")  # 4.58μs -> 3.39μs (35.2% faster)

# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from __future__ import annotations

import ast

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import \
    AsyncCallInstrumenter


# Minimal stubs for required classes
class CodePosition:
    def __init__(self, line_no, col_no, end_col_offset=None):
        self.line_no = line_no
        self.col_no = col_no
        self.end_col_offset = end_col_offset


class FunctionToOptimize:
    def __init__(self, function_name, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name


# unit tests

# Helper to create an Await node with a Call to a function, with position attributes
def make_await_call_stmt(func_name, lineno=1, col_offset=0, end_lineno=None, end_col_offset=None):
    call = ast.Call(
        func=ast.Name(id=func_name, ctx=ast.Load()),
        args=[],
        keywords=[],
    )
    call.lineno = lineno
    call.col_offset = col_offset
    call.end_lineno = end_lineno if end_lineno is not None else lineno
    call.end_col_offset = end_col_offset if end_col_offset is not None else col_offset + 10
    await_node = ast.Await(value=call)
    await_node.lineno = lineno
    await_node.col_offset = col_offset
    await_node.end_lineno = call.end_lineno
    await_node.end_col_offset = call.end_col_offset
    expr = ast.Expr(value=await_node)
    expr.lineno = lineno
    expr.col_offset = col_offset
    expr.end_lineno = call.end_lineno
    expr.end_col_offset = call.end_col_offset
    return expr


# Helper to create a non-await call to the function
def make_non_await_call_stmt(func_name, lineno=1, col_offset=0, end_lineno=None, end_col_offset=None):
    call = ast.Call(
        func=ast.Name(id=func_name, ctx=ast.Load()),
        args=[],
        keywords=[],
    )
    call.lineno = lineno
    call.col_offset = col_offset
    call.end_lineno = end_lineno if end_lineno is not None else lineno
    call.end_col_offset = end_col_offset if end_col_offset is not None else col_offset + 10
    expr = ast.Expr(value=call)
    expr.lineno = lineno
    expr.col_offset = col_offset
    expr.end_lineno = call.end_lineno
    expr.end_col_offset = call.end_col_offset
    return expr


# Helper to create an Await node with a Call to a different function
def make_await_call_stmt_other(func_name, other_func_name, lineno=1, col_offset=0, end_lineno=None, end_col_offset=None):
    call = ast.Call(
        func=ast.Name(id=other_func_name, ctx=ast.Load()),
        args=[],
        keywords=[],
    )
    call.lineno = lineno
    call.col_offset = col_offset
    call.end_lineno = end_lineno if end_lineno is not None else lineno
    call.end_col_offset = end_col_offset if end_col_offset is not None else col_offset + 10
    await_node = ast.Await(value=call)
    await_node.lineno = lineno
    await_node.col_offset = col_offset
    await_node.end_lineno = call.end_lineno
    await_node.end_col_offset = call.end_col_offset
    expr = ast.Expr(value=await_node)
    expr.lineno = lineno
    expr.col_offset = col_offset
    expr.end_lineno = call.end_lineno
    expr.end_col_offset = call.end_col_offset
    return expr
# 1. Basic Test Cases


def test_instrument_await_target_call_in_position():
    """Should instrument when awaiting the target function at the correct position."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=10, col_offset=5, end_lineno=10, end_col_offset=15)
    call_positions = [CodePosition(line_no=10, col_no=5, end_col_offset=15)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node")  # 8.68μs -> 5.94μs (46.0% faster)


def test_no_instrument_when_not_await():
    """Should not instrument if the call is not awaited."""
    func_name = "foo"
    stmt = make_non_await_call_stmt(func_name, lineno=10, col_offset=5)
    call_positions = [CodePosition(line_no=10, col_no=5)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node")  # 7.83μs -> 5.40μs (45.1% faster)


def test_no_instrument_when_call_to_other_function():
    """Should not instrument if the awaited call is to a different function."""
    func_name = "foo"
    stmt = make_await_call_stmt_other(func_name, "bar", lineno=10, col_offset=5)
    call_positions = [CodePosition(line_no=10, col_no=5)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node")  # 8.85μs -> 6.37μs (38.8% faster)


def test_no_instrument_when_position_does_not_match():
    """Should not instrument if the awaited call is not at a target position."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=10, col_offset=5)
    call_positions = [CodePosition(line_no=11, col_no=5)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node")  # 9.88μs -> 7.61μs (29.7% faster)


def test_instrument_with_attribute_call():
    """Should instrument if the awaited call is to an attribute with the correct function name."""
    func_name = "foo"
    # Build ast.Await(value=ast.Call(func=ast.Attribute(...)))
    call = ast.Call(
        func=ast.Attribute(
            value=ast.Name(id="obj", ctx=ast.Load()),
            attr=func_name,
            ctx=ast.Load()
        ),
        args=[],
        keywords=[]
    )
    call.lineno = 20
    call.col_offset = 2
    call.end_lineno = 20
    call.end_col_offset = 12
    await_node = ast.Await(value=call)
    await_node.lineno = 20
    await_node.col_offset = 2
    await_node.end_lineno = 20
    await_node.end_col_offset = 12
    expr = ast.Expr(value=await_node)
    expr.lineno = 20
    expr.col_offset = 2
    expr.end_lineno = 20
    expr.end_col_offset = 12
    call_positions = [CodePosition(line_no=20, col_no=2, end_col_offset=12)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(expr, "_node")  # 8.01μs -> 5.48μs (46.0% faster)
# 2. Edge Test Cases


def test_instrument_multiple_awaits_only_one_matches():
    """Should instrument if only one of several awaits matches the target function and position."""
    func_name = "foo"
    stmt1 = make_await_call_stmt(func_name, lineno=5, col_offset=0)
    stmt2 = make_await_call_stmt_other(func_name, "bar", lineno=6, col_offset=0)
    stmt3 = make_await_call_stmt(func_name, lineno=10, col_offset=5)
    # Compose a block with multiple statements
    block = ast.If(
        test=ast.Constant(value=True),
        body=[stmt1, stmt2, stmt3],
        orelse=[]
    )
    block.lineno = 5
    block.col_offset = 0
    block.end_lineno = 10
    block.end_col_offset = 15
    call_positions = [CodePosition(line_no=10, col_no=5)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(block, "_node")  # 14.2μs -> 7.05μs (102% faster)


def test_instrument_with_missing_lineno_col_offset():
    """Should not instrument if the call node is missing lineno/col_offset attributes."""
    func_name = "foo"
    call = ast.Call(
        func=ast.Name(id=func_name, ctx=ast.Load()),
        args=[],
        keywords=[]
    )
    # Do not set lineno/col_offset attributes
    await_node = ast.Await(value=call)
    expr = ast.Expr(value=await_node)
    call_positions = [CodePosition(line_no=1, col_no=0)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(expr, "_node")  # 8.82μs -> 6.54μs (34.8% faster)


def test_instrument_with_nested_calls():
    """Should instrument if the awaited call is nested inside another statement (e.g., in a list)."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=15, col_offset=3)
    list_stmt = ast.List(elts=[stmt], ctx=ast.Load())
    expr = ast.Expr(value=list_stmt)
    expr.lineno = 15
    expr.col_offset = 3
    expr.end_lineno = 15
    expr.end_col_offset = 20
    call_positions = [CodePosition(line_no=15, col_no=3)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(expr, "_node")  # 10.3μs -> 8.26μs (24.3% faster)


def test_instrument_with_multiple_call_positions():
    """Should instrument if the awaited call matches any of several call positions."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=30, col_offset=7)
    call_positions = [
        CodePosition(line_no=10, col_no=1),
        CodePosition(line_no=30, col_no=7),
        CodePosition(line_no=40, col_no=0)
    ]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node")  # 7.88μs -> 5.06μs (55.7% faster)


def test_instrument_with_end_col_offset_variation():
    """Should instrument when call position matches at the end_col_offset."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=50, col_offset=2, end_col_offset=20)
    call_positions = [CodePosition(line_no=50, col_no=18, end_col_offset=20)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node")  # 7.36μs -> 4.82μs (52.8% faster)


def test_instrument_with_no_call_positions():
    """Should not instrument if call_positions is empty."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=10, col_offset=5)
    call_positions = []
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node")  # 9.21μs -> 6.98μs (31.8% faster)


def test_instrument_with_non_matching_attribute():
    """Should not instrument if attribute call does not match function name."""
    func_name = "foo"
    call = ast.Call(
        func=ast.Attribute(
            value=ast.Name(id="obj", ctx=ast.Load()),
            attr="bar",
            ctx=ast.Load()
        ),
        args=[],
        keywords=[]
    )
    call.lineno = 60
    call.col_offset = 4
    call.end_lineno = 60
    call.end_col_offset = 14
    await_node = ast.Await(value=call)
    await_node.lineno = 60
    await_node.col_offset = 4
    await_node.end_lineno = 60
    await_node.end_col_offset = 14
    expr = ast.Expr(value=await_node)
    expr.lineno = 60
    expr.col_offset = 4
    expr.end_lineno = 60
    expr.end_col_offset = 14
    call_positions = [CodePosition(line_no=60, col_no=4, end_col_offset=14)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(expr, "_node")  # 10.0μs -> 8.54μs (17.3% faster)
# 3. Large Scale Test Cases


def test_large_number_of_call_positions():
    """Should instrument if the awaited call matches among many call positions."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=500, col_offset=8)
    # 499 positions that don't match, 1 that does
    call_positions = [CodePosition(line_no=i, col_no=0) for i in range(1, 500)]
    call_positions.append(CodePosition(line_no=500, col_no=8))
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node")  # 49.1μs -> 46.5μs (5.56% faster)


def test_large_ast_with_one_matching_await():
    """Should instrument when only one awaited call in a large AST matches the target and position."""
    func_name = "foo"
    stmts = []
    for i in range(1, 100):
        # Most are to other functions
        stmts.append(make_await_call_stmt_other(func_name, f"bar{i}", lineno=i, col_offset=0))
    # Insert a matching awaited call at position 50
    matching_stmt = make_await_call_stmt(func_name, lineno=50, col_offset=4)
    stmts[49] = matching_stmt  # Replace the 50th statement
    block = ast.Module(body=stmts, type_ignores=[])
    block.lineno = 1
    block.col_offset = 0
    block.end_lineno = 99
    block.end_col_offset = 10
    call_positions = [CodePosition(line_no=50, col_no=4)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(block, "_node")  # 101μs -> 166μs (38.9% slower)


def test_large_ast_with_no_matching_await():
    """Should not instrument when no awaited call in a large AST matches the target and position."""
    func_name = "foo"
    stmts = []
    for i in range(1, 100):
        stmts.append(make_await_call_stmt_other(func_name, f"bar{i}", lineno=i, col_offset=0))
    block = ast.Module(body=stmts, type_ignores=[])
    block.lineno = 1
    block.col_offset = 0
    block.end_lineno = 99
    block.end_col_offset = 10
    call_positions = [CodePosition(line_no=150, col_no=0)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(block, "_node")  # 317μs -> 317μs (0.009% slower)


def test_large_ast_with_multiple_matching_awaits():
    """Should instrument when multiple awaited calls in a large AST match the target and positions."""
    func_name = "foo"
    stmts = []
    match_lines = [10, 20, 30, 40, 50]
    for i in range(1, 101):
        if i in match_lines:
            stmts.append(make_await_call_stmt(func_name, lineno=i, col_offset=2))
        else:
            stmts.append(make_await_call_stmt_other(func_name, f"bar{i}", lineno=i, col_offset=0))
    block = ast.Module(body=stmts, type_ignores=[])
    block.lineno = 1
    block.col_offset = 0
    block.end_lineno = 100
    block.end_col_offset = 10
    call_positions = [CodePosition(line_no=i, col_no=2) for i in match_lines]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(block, "_node")  # 72.0μs -> 169μs (57.4% slower)


def test_large_call_positions_all_non_matching():
    """Should not instrument if there are many call positions but none match."""
    func_name = "foo"
    stmt = make_await_call_stmt(func_name, lineno=1000, col_offset=0)
    call_positions = [CodePosition(line_no=i, col_no=1) for i in range(1, 1000)]
    function = FunctionToOptimize(func_name)
    instr = AsyncCallInstrumenter(function, "mod.py", "pytest", call_positions)
    new_stmt, did_instrument = instr._instrument_statement(stmt, "_node")  # 93.3μs -> 90.8μs (2.68% faster)

# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally: `git merge codeflash/optimize-pr678-2025-09-26T20.19.54`
Suggested changes

Original:

        for node in ast.walk(stmt):
            if (
                isinstance(node, ast.Await)
                and isinstance(node.value, ast.Call)
                and self._is_target_call(node.value)
                and self._call_in_positions(node.value)
            ):
                # Check if this call is in one of our target positions

Suggested:

        def _await_call_nodes(node):
            stack = [node]
            while stack:
                cur = stack.pop()
                if isinstance(cur, ast.Await) and isinstance(cur.value, ast.Call):
                    yield cur
                for child in ast.iter_child_nodes(cur):
                    stack.append(child)

        for await_node in _await_call_nodes(stmt):
            call_node = await_node.value
            if self._is_target_call(call_node) and self._call_in_positions(call_node):
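One detail worth checking before adopting a stack-based replacement is visit order. The `ast.walk()` documentation specifies no order, but CPython's current implementation is breadth-first, while popping from a stack visits the last child first. The sketch below compares the two on a small sample; the function names, the `source_order` flag, and the sample source are assumptions for illustration. Whether order affects `_instrument_statement` is not shown in this thread.

```python
import ast
from typing import Iterator, List

# Hypothetical test body with a nested await and a sibling await.
SOURCE = """
async def t():
    await first(await inner())
    await second()
"""


def is_await_call(node: ast.AST) -> bool:
    return isinstance(node, ast.Await) and isinstance(node.value, ast.Call)


def await_calls_stack(root: ast.AST, source_order: bool = False) -> Iterator[ast.Await]:
    """Stack-based traversal; optionally push children reversed for source order."""
    stack = [root]
    while stack:
        cur = stack.pop()
        if is_await_call(cur):
            yield cur
        children = list(ast.iter_child_nodes(cur))
        stack.extend(reversed(children) if source_order else children)


def called_names(nodes) -> List[str]:
    # Every call in SOURCE uses a plain name, so `.func.id` is available.
    return [n.value.func.id for n in nodes]


tree = ast.parse(SOURCE)
walk_order = called_names(n for n in ast.walk(tree) if is_await_call(n))
stack_order = called_names(await_calls_stack(tree))
source_order = called_names(await_calls_stack(tree, source_order=True))

print("ast.walk:      ", walk_order)
print("plain stack:   ", stack_order)
print("reversed push: ", source_order)
```

On this sample the plain stack reports `second` before `first`, whereas pushing children in reverse restores source order. If the calling loop stops at the first match or numbers calls as it goes, the reversed push is the safer drop-in.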
40c4108 to 7bbb1e7
Closing in favor of #769 - cleaner commits.
Pull request was closed
PR Type
Enhancement, Bug fix, Tests
Description
- Add support for async function detection
- Remove `has_any_async_functions` usage
- Fix coverage utils empty database return
- Update tests for async discovery and code validation
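As a rough illustration of what async function detection can look like with the standard `ast` module, the sketch below records an `is_async` flag per discovered function. The `DiscoveredFunction` dataclass and `discover_functions` helper are hypothetical names for this example and are not the PR's `FunctionToOptimize` implementation.

```python
import ast
from dataclasses import dataclass
from typing import List


@dataclass
class DiscoveredFunction:
    """Hypothetical record; the real project type may differ."""
    name: str
    is_async: bool


def discover_functions(source: str) -> List[DiscoveredFunction]:
    """Collect every sync and async function definition in the source."""
    found = []
    for node in ast.walk(ast.parse(source)):
        # `async def` parses to AsyncFunctionDef, a separate node class.
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            found.append(
                DiscoveredFunction(node.name, isinstance(node, ast.AsyncFunctionDef))
            )
    return found


SAMPLE = """
def sync_fn():
    return 1

async def fetch():
    return 2

class Client:
    async def get(self):
        return 3
"""

functions = discover_functions(SAMPLE)
flags = {f.name: f.is_async for f in functions}
print(flags)
```

Checking for `ast.FunctionDef` alone silently skips `async def` definitions, which is why both node classes appear in the `isinstance` check.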
File Walkthrough
4 files
- Remove unused async detection function
- Add AsyncFunctionDef support to lookup functions
- Enable async function discovery and flagging
- Remove async restriction and support async AST

1 file
- Fix empty DB check order and return `CoverageData`

3 files
- Add comprehensive async discovery tests
- Resolve symlink paths in code context tests
- Replace async detection tests with code validation