-
Couldn't load subscription status.
- Fork 22
support async functions in codeflash #617
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1abf574
8654ed8
bd59fe7
4112eff
557d1ab
479da54
c07d9d1
d948ad7
4132759
f1de751
7ec754c
42fb4f0
b957d3c
b02a5bb
eabfd2b
0042d08
f61c139
99b4300
3bafc8d
a150e43
f2c0740
66eac2a
e292c5e
a5182c6
c4e3e00
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import asyncio | ||
|
|
||
|
|
||
| async def async_add(a, b): | ||
| """Simple async function that adds two numbers.""" | ||
| await asyncio.sleep(0.001) # Simulate some async work | ||
| print(f"codeflash stdout: Adding {a} + {b}") | ||
| result = a + b | ||
| print(f"result: {result}") | ||
| return result |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| [tool.codeflash] | ||
| # All paths are relative to this pyproject.toml's directory. | ||
| module-root = "." | ||
| tests-root = "tests" | ||
| test-framework = "pytest" | ||
| ignore-paths = [] | ||
| formatter-cmds = ["black $file"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| from time import sleep | ||
|
|
||
| async def tasked(): | ||
| sleep(0.002) | ||
| return "Tasked" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| import pytest | ||
| import asyncio | ||
| from shocker import tasked | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_tasked_basic(): | ||
| result = await tasked() | ||
| assert result == "Tasked" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_tasked_gather(): | ||
| results = await asyncio.gather(*(tasked() for _ in range(5))) | ||
| assert results == ["Tasked"] * 5 | ||
|
|
||
|
|
||
| def test_tasked_many_parallel_invocations(): | ||
| async def run_many(): | ||
| tasks = [tasked() for _ in range(1000)] | ||
| results = await asyncio.gather(*tasks) | ||
| return results | ||
|
|
||
| results = asyncio.run(run_many()) | ||
| assert len(results) == 1000, "Should return 1000 results" | ||
| assert all(r == "Tasked" for r in results), "All results should be 'Tasked'" |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -48,6 +48,39 @@ def is_argument_name(name: str, arguments_node: ast.arguments) -> bool: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class AsyncIOGatherRemover(ast.NodeTransformer): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _contains_asyncio_gather(self, node: ast.AST) -> bool: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Check if a node contains asyncio.gather calls.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for child_node in ast.walk(node): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| isinstance(child_node, ast.Call) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and isinstance(child_node.func, ast.Attribute) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and isinstance(child_node.func.value, ast.Name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and child_node.func.value.id == "asyncio" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and child_node.func.attr == "gather" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| isinstance(child_node, ast.Call) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and isinstance(child_node.func, ast.Name) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and child_node.func.id == "gather" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+54
to
+70
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ⚡️Codeflash found 36% (0.36x) speedup for
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 62 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
from __future__ import annotations
import ast
# imports
import pytest # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import AsyncIOGatherRemover
# unit tests
# Helper function to parse source and extract the first FunctionDef node
def get_first_functiondef(source: str) -> ast.FunctionDef:
"""Parse Python source and return the first FunctionDef node."""
tree = ast.parse(source)
for node in ast.iter_child_nodes(tree):
if isinstance(node, ast.FunctionDef):
return node
raise ValueError("No FunctionDef found in source")
# -----------------
# 1. Basic Test Cases
# -----------------
def test_non_test_function_no_gather_retains_function():
"""Function not starting with 'test_' and no gather should be kept."""
src = "def foo():\n x = 1"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 16.6μs -> 17.2μs (3.50% slower)
def test_test_function_no_gather_retains_function():
"""Function starting with 'test_' but no gather should be kept."""
src = "def test_example():\n y = 2"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 26.9μs -> 25.1μs (7.26% faster)
def test_test_function_with_asyncio_gather_removes_function():
"""Function starting with 'test_' and containing asyncio.gather should be removed."""
src = "def test_async():\n await asyncio.gather(a, b)"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 13.0μs -> 6.49μs (100% faster)
def test_test_function_with_gather_direct_call_removes_function():
"""Function starting with 'test_' and containing direct gather() call should be removed."""
src = "def test_direct():\n await gather(a, b)"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 12.4μs -> 6.14μs (101% faster)
# -----------------
# 2. Edge Test Cases
# -----------------
def test_non_test_function_with_asyncio_gather_kept():
"""Non-test function with asyncio.gather should not be removed."""
src = "def foo():\n await asyncio.gather(a, b)"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 18.2μs -> 18.5μs (1.52% slower)
def test_test_function_with_gather_in_string_kept():
"""'gather' in a string literal should not trigger removal."""
src = "def test_str():\n x = 'asyncio.gather(a, b)'"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 26.1μs -> 24.5μs (6.41% faster)
def test_test_function_with_gather_as_variable_kept():
"""'gather' used as a variable, not a call, should not trigger removal."""
src = "def test_var():\n gather = 5\n x = gather + 2"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 41.0μs -> 39.5μs (3.75% faster)
def test_test_function_with_gather_call_on_other_object_kept():
"""gather called as method on another object should not trigger removal."""
src = "def test_other():\n obj.gather(a, b)"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 31.3μs -> 29.8μs (5.11% faster)
def test_test_function_with_gather_in_comment_kept():
"""gather in a comment should not trigger removal."""
src = "def test_comment():\n # asyncio.gather(a, b)\n pass"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 15.9μs -> 14.2μs (11.6% faster)
def test_test_function_with_nested_gather_call_removes_function():
"""gather call nested in an inner function should trigger removal."""
src = (
"def test_nested():\n"
" def inner():\n"
" asyncio.gather(a, b)\n"
" inner()"
)
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 17.0μs -> 10.8μs (58.1% faster)
def test_test_function_with_gather_in_lambda_removes_function():
"""gather call inside a lambda triggers removal."""
src = "def test_lambda():\n f = lambda: asyncio.gather(a, b)\n f()"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 18.6μs -> 10.6μs (75.0% faster)
def test_test_function_with_gather_as_keyword_arg_kept():
"""'gather' used as a keyword argument should not trigger removal."""
src = "def test_kwarg():\n foo(gather=5)"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 30.0μs -> 28.8μs (4.07% faster)
def test_test_function_with_gather_imported_as_alias_removes_function():
"""Direct gather() call, even if imported as alias, triggers removal."""
src = "def test_alias():\n await gather(a, b)"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 12.4μs -> 6.11μs (102% faster)
def test_test_function_with_gather_in_try_except_removes_function():
"""gather call inside try/except triggers removal."""
src = (
"def test_try():\n"
" try:\n"
" asyncio.gather(a, b)\n"
" except Exception:\n"
" pass"
)
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 14.3μs -> 10.6μs (35.1% faster)
def test_test_function_with_gather_in_listcomp_removes_function():
"""gather call inside a list comprehension triggers removal."""
src = "def test_listcomp():\n [asyncio.gather(a, b) for a in range(2)]"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 12.8μs -> 13.1μs (2.30% slower)
def test_test_function_with_gather_in_decorator_kept():
"""gather used in a decorator name should not trigger removal."""
src = "@gather\ndef test_decorated():\n pass"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 19.9μs -> 18.4μs (7.99% faster)
def test_test_function_with_gather_in_attribute_but_not_asyncio_kept():
"""gather called as attribute of something other than 'asyncio' should not trigger removal."""
src = "def test_attr():\n foo.gather(a, b)"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 31.4μs -> 29.7μs (5.57% faster)
def test_test_function_with_gather_in_multiple_places_removes_function():
"""Multiple gather calls in a test function should trigger removal."""
src = (
"def test_multi():\n"
" asyncio.gather(a, b)\n"
" gather(a, b)\n"
" x = 1"
)
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 13.6μs -> 9.06μs (50.2% faster)
def test_test_function_with_gather_as_class_method_removes_function():
"""gather call inside a test function as a class method triggers removal."""
src = (
"def test_class():\n"
" class C:\n"
" def m(self):\n"
" asyncio.gather(a, b)\n"
" C().m()"
)
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 20.8μs -> 14.8μs (40.6% faster)
# -----------------
# 3. Large Scale Test Cases
# -----------------
def test_large_function_with_many_statements_and_one_gather_removes_function():
"""Large test function with many statements and one gather call should be removed."""
body = "\n ".join([f"x{i} = {i}" for i in range(500)])
src = f"def test_large():\n {body}\n asyncio.gather(a, b)\n y = 3"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 1.24ms -> 37.4μs (3206% faster)
def test_large_function_with_many_statements_no_gather_kept():
"""Large test function with many statements, no gather call, should be kept."""
body = "\n ".join([f"x{i} = {i}" for i in range(500)])
src = f"def test_large_nogather():\n {body}\n y = 3"
node = get_first_functiondef(src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_FunctionDef(node); result = codeflash_output # 3.60ms -> 3.46ms (3.88% faster)
def test_many_functions_some_with_gather():
"""Apply remover to a module with many functions, some test_ with gather, some not."""
src = (
"def test_a():\n asyncio.gather(a, b)\n"
"def test_b():\n pass\n"
"def foo():\n asyncio.gather(a, b)\n"
"def test_c():\n gather(a, b)\n"
"def bar():\n pass\n"
"def test_d():\n pass\n"
)
tree = ast.parse(src)
remover = AsyncIOGatherRemover()
new_tree = remover.visit(tree)
# Collect function names in the transformed tree
kept_names = [n.name for n in new_tree.body if isinstance(n, ast.FunctionDef)]
def test_large_number_of_functions_performance():
"""Performance: 1000 functions, only a few test_ with gather should be removed."""
src_lines = []
# 995 non-test functions
for i in range(995):
src_lines.append(f"def func_{i}():\n pass")
# 5 test functions with gather
for i in range(5):
src_lines.append(f"def test_remove_{i}():\n asyncio.gather(a, b)")
src = "\n".join(src_lines)
tree = ast.parse(src)
remover = AsyncIOGatherRemover()
new_tree = remover.visit(tree)
# Only the 995 non-test functions should remain
kept_names = [n.name for n in new_tree.body if isinstance(n, ast.FunctionDef)]
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from __future__ import annotations
import ast
# imports
import pytest # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import AsyncIOGatherRemover
# unit tests
def parse_and_get_funcdef_nodes(source_code):
"""
Helper to parse source code and return all ast.FunctionDef nodes.
"""
tree = ast.parse(source_code)
return [node for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)]
def run_remover_on_funcdef(node):
"""
Helper to run AsyncIOGatherRemover on a single FunctionDef node.
"""
remover = AsyncIOGatherRemover()
return remover.visit_FunctionDef(node)
# 1. Basic Test Cases
def test_non_test_function_with_asyncio_gather_is_preserved():
# Function does not start with 'test_' but contains asyncio.gather
code = """
def foo():
x = asyncio.gather(a(), b())
return x
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_test_function_without_asyncio_gather_is_preserved():
# Function starts with 'test_' but does NOT contain asyncio.gather
code = """
def test_bar():
x = await something()
return x
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_test_function_with_asyncio_gather_is_removed():
# Function starts with 'test_' and contains asyncio.gather
code = """
def test_baz():
await asyncio.gather(a(), b())
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_test_function_with_gather_imported_directly_is_removed():
# Function starts with 'test_' and uses 'gather' directly (imported)
code = """
def test_direct_gather():
await gather(a(), b())
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_non_test_function_with_gather_directly_is_preserved():
# Function does not start with 'test_' and uses 'gather' directly
code = """
def foo_gather():
await gather(a(), b())
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
# 2. Edge Test Cases
def test_test_function_with_asyncio_gather_in_nested_function_is_removed():
# 'test_' function with gather in a nested function
code = """
def test_nested():
def inner():
return asyncio.gather(a(), b())
return inner()
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_test_function_with_gather_as_variable_not_removed():
# 'test_' function where 'gather' is a variable, not a function call
code = """
def test_var_gather():
gather = 42
return gather
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_test_function_with_similar_name_not_removed():
# 'test_' function with 'gather' in a variable name, not a call
code = """
def test_similar():
gatherer = "something"
return gatherer
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_test_function_with_asyncio_gather_as_attribute_of_other_object_not_removed():
# 'test_' function with gather as attribute of something other than asyncio
code = """
def test_other_gather():
obj = SomeClass()
await obj.gather(a(), b())
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_test_function_with_asyncio_gather_in_comment_not_removed():
# 'test_' function with 'asyncio.gather' only in a comment
code = """
def test_comment():
# await asyncio.gather(a(), b())
return 1
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_test_function_with_asyncio_gather_in_string_not_removed():
# 'test_' function with 'asyncio.gather' only in a string
code = '''
def test_string():
s = "asyncio.gather(a(), b())"
return s
'''
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_test_function_with_gather_called_on_different_object_not_removed():
# 'test_' function with gather called on a different object (not asyncio)
code = """
def test_other_obj_gather():
await something.gather(a(), b())
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_function_with_name_test_but_not_at_start_is_preserved():
# Function with 'test' in the name but not at the start
code = """
def not_a_test_function():
await asyncio.gather(a(), b())
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_test_function_with_multiple_asyncio_gather_calls_is_removed():
# 'test_' function with multiple asyncio.gather calls
code = """
def test_multi_gather():
await asyncio.gather(a(), b())
await asyncio.gather(c(), d())
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_test_function_with_gather_as_keyword_argument_not_removed():
# 'test_' function where 'gather' is a keyword argument, not a call
code = """
def test_gather_kwarg():
foo(gather=asyncio.gather)
"""
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
# 3. Large Scale Test Cases
def test_large_function_body_with_gather_only_once():
# Large function body, gather appears only once
body_lines = [" x = %d" % i for i in range(400)]
body_lines.append(" await asyncio.gather(a(), b())")
body_lines.extend([" y = %d" % i for i in range(400, 800)])
code = "def test_large():\n" + "\n".join(body_lines)
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_large_function_body_no_gather():
# Large function body, but no gather
body_lines = [" x = %d" % i for i in range(900)]
code = "def test_large_no_gather():\n" + "\n".join(body_lines)
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)
def test_large_number_of_nested_functions():
# Many nested functions, only outermost is 'test_' with gather in deepest
code = "def test_outer():\n"
indent = " "
for i in range(50):
code += f"{indent}def inner_{i}():\n"
indent += " "
code += f"{indent}return asyncio.gather(a(), b())\n"
func_node = parse_and_get_funcdef_nodes(code)[0]
result = run_remover_on_funcdef(func_node)To test or edit this optimization locally git merge codeflash/optimize-pr617-2025-09-02T17.39.22
Click to see suggested changes
| for child_node in ast.walk(node): | |
| if ( | |
| isinstance(child_node, ast.Call) | |
| and isinstance(child_node.func, ast.Attribute) | |
| and isinstance(child_node.func.value, ast.Name) | |
| and child_node.func.value.id == "asyncio" | |
| and child_node.func.attr == "gather" | |
| ): | |
| return True | |
| if ( | |
| isinstance(child_node, ast.Call) | |
| and isinstance(child_node.func, ast.Name) | |
| and child_node.func.id == "gather" | |
| ): | |
| return True | |
| stack = [node] | |
| while stack: | |
| child_node = stack.pop() | |
| if isinstance(child_node, ast.Call): | |
| func = child_node.func | |
| if ( | |
| isinstance(func, ast.Attribute) | |
| and isinstance(func.value, ast.Name) | |
| and func.value.id == "asyncio" | |
| and func.attr == "gather" | |
| ): | |
| return True | |
| if isinstance(func, ast.Name) and func.id == "gather": | |
| return True | |
| stack.extend(ast.iter_child_nodes(child_node)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
⚡️Codeflash found 25% (0.25x) speedup for AsyncIOGatherRemover.visit_AsyncFunctionDef in codeflash/code_utils/instrument_existing_tests.py
⏱️ Runtime : 781 microseconds → 627 microseconds (best of 71 runs)
📝 Explanation and details
The optimization replaces ast.walk() with a custom iterative depth-first search using a stack, which provides early termination when asyncio.gather calls are found.
Key optimizations:
-
Short-circuit traversal: The original code uses
ast.walk(node)which always traverses the entire AST subtree, even after finding a match. The optimized version uses an explicit stack and immediately returnsTrueupon finding the firstgathercall, avoiding unnecessary node visits. -
Iterative over recursive: Replaces the generator-based
ast.walk()with a simplewhileloop and stack operations (stack.pop(),stack.extend()), reducing function call overhead. -
Streamlined control flow: The nested
ifconditions are restructured to checkast.Callonce, then branch to handle the two gather patterns (attribute vs direct call), eliminating redundantisinstancechecks.
Performance impact based on test results:
- Most effective on test functions that contain gather calls (103-136% speedup) because early termination kicks in immediately
- Moderate gains on functions without gather calls (6-10% speedup) due to reduced traversal overhead
- Minimal impact on non-test functions (2-4% speedup) since the name check happens first
The 24% overall speedup comes from avoiding the continuation of AST traversal after matches are found, which is particularly beneficial when gather calls appear early in the function body or in nested structures.
✅ Correctness verification report:
| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | ✅ 74 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
from __future__ import annotations
import ast
# imports
import pytest # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import AsyncIOGatherRemover
# unit tests
def make_async_func(name, body_src):
"""Helper to create an AsyncFunctionDef node from code string."""
src = f"async def {name}():\n"
# Indent body
body_src = "\n".join(" " + line if line.strip() else line for line in body_src.splitlines())
src += body_src
mod = ast.parse(src)
# Return the AsyncFunctionDef node
for node in mod.body:
if isinstance(node, ast.AsyncFunctionDef):
return node
raise ValueError("No AsyncFunctionDef found")
# ------------------ Basic Test Cases ------------------
def test_basic_removes_test_with_asyncio_gather():
# Should remove test_ functions containing asyncio.gather
node = make_async_func("test_example", "await asyncio.gather(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 15.0μs -> 7.40μs (103% faster)
def test_basic_keeps_non_test_with_asyncio_gather():
# Should NOT remove non-test_ functions even if they have asyncio.gather
node = make_async_func("run_example", "await asyncio.gather(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 19.2μs -> 18.8μs (2.46% faster)
def test_basic_keeps_test_without_asyncio_gather():
# Should NOT remove test_ functions if they don't have asyncio.gather
node = make_async_func("test_no_gather", "await something_else(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 30.9μs -> 28.3μs (9.23% faster)
def test_basic_keeps_non_test_without_asyncio_gather():
# Should NOT remove non-test_ functions without asyncio.gather
node = make_async_func("run_no_gather", "await something_else(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 16.7μs -> 16.3μs (2.77% faster)
def test_basic_removes_test_with_gather_imported():
# Should remove test_ functions using imported gather
node = make_async_func("test_imported_gather", "await gather(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 13.9μs -> 6.40μs (117% faster)
# ------------------ Edge Test Cases ------------------
def test_edge_test_name_case_sensitive():
# Should only remove if name starts with lowercase 'test_'
node = make_async_func("Test_example", "await asyncio.gather(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 18.5μs -> 17.9μs (3.31% faster)
def test_edge_gather_as_variable():
# Should NOT remove if 'gather' is used as a variable, not a function
node = make_async_func("test_var_gather", "gather = 5\nawait something_else(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 42.6μs -> 39.2μs (8.74% faster)
def test_edge_gather_in_comment():
# Should NOT remove if 'asyncio.gather' appears only in a comment
node = make_async_func("test_comment_gather", "# await asyncio.gather(a, b)\nawait something_else(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 30.4μs -> 27.6μs (10.2% faster)
def test_edge_gather_in_string():
# Should NOT remove if 'asyncio.gather' appears only in a string
node = make_async_func("test_string_gather", "x = 'await asyncio.gather(a, b)'\nawait something_else(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 42.0μs -> 38.4μs (9.34% faster)
def test_edge_nested_gather_call():
# Should remove if asyncio.gather is nested deep in the function
node = make_async_func("test_nested_gather", "if True:\n await asyncio.gather(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 16.0μs -> 7.49μs (114% faster)
def test_edge_multiple_gather_calls():
# Should remove if multiple gather calls are present
node = make_async_func("test_multi_gather", "await asyncio.gather(a, b)\nawait gather(c, d)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 14.7μs -> 6.22μs (136% faster)
def test_edge_gather_as_method_of_other_object():
# Should NOT remove if gather is called as method of something other than asyncio
node = make_async_func("test_other_gather", "await foo.gather(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 33.4μs -> 30.6μs (8.92% faster)
def test_edge_gather_in_lambda():
# Should remove if asyncio.gather is called inside a lambda
node = make_async_func("test_lambda_gather", "x = lambda: asyncio.gather(a, b)\nawait something_else(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 18.8μs -> 13.6μs (38.7% faster)
def test_edge_gather_in_inner_function():
# Should remove if asyncio.gather is called in an inner function
src = (
"def inner():\n"
" return asyncio.gather(a, b)\n"
"await something_else(a, b)"
)
node = make_async_func("test_inner_func_gather", src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 17.6μs -> 13.8μs (26.9% faster)
def test_edge_gather_in_try_except():
# Should remove if asyncio.gather is called inside try/except
src = (
"try:\n"
" await asyncio.gather(a, b)\n"
"except Exception:\n"
" pass"
)
node = make_async_func("test_try_gather", src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 17.4μs -> 10.9μs (58.8% faster)
def test_edge_gather_in_async_for():
# Should remove if asyncio.gather is called inside async for
src = (
"async for x in y:\n"
" await asyncio.gather(a, b)"
)
node = make_async_func("test_async_for_gather", src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 17.6μs -> 7.91μs (123% faster)
def test_edge_gather_in_async_with():
# Should remove if asyncio.gather is called inside async with
src = (
"async with foo:\n"
" await asyncio.gather(a, b)"
)
node = make_async_func("test_async_with_gather", src)
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 17.1μs -> 7.59μs (125% faster)
def test_edge_gather_with_kwargs():
# Should remove if asyncio.gather is called with keyword arguments
node = make_async_func("test_gather_kwargs", "await asyncio.gather(a, b, return_exceptions=True)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 13.3μs -> 6.08μs (119% faster)
def test_edge_gather_with_star_args():
# Should remove if asyncio.gather is called with *args
node = make_async_func("test_gather_star", "await asyncio.gather(*tasks)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 13.1μs -> 6.16μs (113% faster)
def test_edge_gather_with_alias():
# Should NOT remove if asyncio is imported as something else and used
node = make_async_func("test_alias_gather", "import asyncio as aio\nawait aio.gather(a, b)")
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 37.9μs -> 35.3μs (7.62% faster)
# ------------------ Large Scale Test Cases ------------------
#------------------------------------------------
from __future__ import annotations
import ast
# imports
import pytest # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import AsyncIOGatherRemover
# unit tests
# Helper to parse code and get AsyncFunctionDef node(s)
def get_async_func_nodes(code: str):
"""Parse code and return all AsyncFunctionDef nodes."""
tree = ast.parse(code)
return [n for n in ast.walk(tree) if isinstance(n, ast.AsyncFunctionDef)]
# Basic Test Cases
def test_basic_removal_with_asyncio_gather():
# Should remove test function with asyncio.gather
code = """
import asyncio
async def test_func():
await asyncio.gather(foo(), bar())
return 1
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 12.5μs -> 7.37μs (69.6% faster)
def test_basic_removal_with_gather_direct():
# Should remove test function with direct gather call
code = """
async def test_func():
await gather(foo(), bar())
return 2
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 12.3μs -> 7.04μs (75.1% faster)
def test_basic_no_removal_non_test_func():
# Should NOT remove non-test function even with asyncio.gather
code = """
import asyncio
async def not_a_test():
await asyncio.gather(foo(), bar())
return 3
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 26.0μs -> 25.9μs (0.545% faster)
def test_basic_no_removal_test_func_no_gather():
# Should NOT remove test function without asyncio.gather
code = """
async def test_func():
await foo()
return 4
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 29.1μs -> 27.3μs (6.45% faster)
# Edge Test Cases
def test_edge_removal_gather_in_nested_function():
# Should remove if gather is in a nested function inside test_
code = """
async def test_func():
async def inner():
await asyncio.gather(foo())
await inner()
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 16.1μs -> 10.5μs (52.5% faster)
def test_edge_removal_gather_in_lambda():
# Should remove if gather is called inside a lambda in test_
code = """
async def test_func():
x = lambda: asyncio.gather(foo())
await foo()
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 15.4μs -> 9.83μs (56.6% faster)
def test_edge_removal_gather_in_comprehension():
# Should remove if gather is used in a list comprehension
code = """
async def test_func():
results = [asyncio.gather(foo()) for _ in range(2)]
await foo()
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 14.1μs -> 15.4μs (8.54% slower)
def test_edge_removal_gather_in_decorator():
# Should remove if gather is used in a decorator
code = """
def dec(func):
async def wrapper(*a, **k):
await asyncio.gather(foo())
return await func(*a, **k)
return wrapper
@dec
async def test_func():
await foo()
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 25.4μs -> 23.2μs (9.41% faster)
def test_edge_removal_gather_as_variable_name():
# Should NOT remove if 'gather' is used as variable name
code = """
async def test_func():
gather = 42
await foo()
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 33.4μs -> 31.4μs (6.31% faster)
def test_edge_removal_gather_as_attribute_not_asyncio():
# Should NOT remove if gather is called as attribute of something else
code = """
async def test_func():
await mylib.gather(foo())
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 29.7μs -> 28.6μs (3.85% faster)
def test_edge_removal_gather_in_comment_only():
# Should NOT remove if gather is only mentioned in a comment
code = """
async def test_func():
# await asyncio.gather(foo())
await foo()
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 21.5μs -> 20.2μs (6.41% faster)
def test_edge_removal_gather_in_docstring_only():
# Should NOT remove if gather is only mentioned in a docstring
code = '''
async def test_func():
"""This function uses asyncio.gather in the future."""
await foo()
'''
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 28.2μs -> 26.6μs (5.91% faster)
def test_edge_removal_test_func_name_variants():
# Should remove if name starts with test_ exactly
code = """
async def test_foo():
await asyncio.gather(foo())
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 10.6μs -> 5.32μs (99.3% faster)
# Should NOT remove if name does not start with test_
code2 = """
async def testfoo():
await asyncio.gather(foo())
"""
node2 = get_async_func_nodes(code2)[0]
remover2 = AsyncIOGatherRemover()
codeflash_output = remover2.visit_AsyncFunctionDef(node2); result2 = codeflash_output # 16.3μs -> 15.6μs (4.17% faster)
def test_edge_removal_gather_called_with_no_args():
# Should remove even if gather is called with no arguments
code = """
async def test_func():
await asyncio.gather()
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 11.1μs -> 5.79μs (91.7% faster)
def test_edge_removal_gather_called_as_keyword():
# Should remove if gather is called with keywords only
code = """
async def test_func():
await asyncio.gather(a=foo())
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 10.7μs -> 5.55μs (93.1% faster)
# Large Scale Test Cases
def test_large_scale_nested_gather_calls():
# Should remove deeply nested gather calls in test_ functions
code = """
async def test_func():
async def a():
async def b():
await asyncio.gather(foo())
await b()
await a()
"""
node = get_async_func_nodes(code)[0]
remover = AsyncIOGatherRemover()
codeflash_output = remover.visit_AsyncFunctionDef(node); result = codeflash_output # 22.8μs -> 15.6μs (46.3% faster)To test or edit this optimization locally git merge codeflash/optimize-pr617-2025-09-02T17.47.10
Click to see suggested changes
| for child_node in ast.walk(node): | |
| if ( | |
| isinstance(child_node, ast.Call) | |
| and isinstance(child_node.func, ast.Attribute) | |
| and isinstance(child_node.func.value, ast.Name) | |
| and child_node.func.value.id == "asyncio" | |
| and child_node.func.attr == "gather" | |
| ): | |
| return True | |
| if ( | |
| isinstance(child_node, ast.Call) | |
| and isinstance(child_node.func, ast.Name) | |
| and child_node.func.id == "gather" | |
| ): | |
| return True | |
| stack = [node] | |
| while stack: | |
| current = stack.pop() | |
| if isinstance(current, ast.Call): | |
| if isinstance(current.func, ast.Attribute): | |
| if ( | |
| isinstance(current.func.value, ast.Name) | |
| and current.func.value.id == "asyncio" | |
| and current.func.attr == "gather" | |
| ): | |
| return True | |
| elif isinstance(current.func, ast.Name) and current.func.id == "gather": | |
| return True | |
| stack.extend(ast.iter_child_nodes(current)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
⚡️Codeflash found 71% (0.71x) speedup for
AsyncIOGatherRemover._contains_asyncio_gatherincodeflash/code_utils/instrument_existing_tests.py⏱️ Runtime :
23.2 milliseconds→13.6 milliseconds(best of73runs)📝 Explanation and details
The optimization replaces
ast.walk()with a manual stack-based traversal that enables early termination. Here's why this delivers a 70% speedup:Key optimization: Early return prevents unnecessary traversal
ast.walk()always visits every node in the AST, even after finding a matchasyncio.gatherorgatheris foundPerformance benefits by test case type:
ast.walk()Technical improvements:
ast.iter_child_nodes()with explicit stack instead of generator-basedast.walk()isinstancecalls by storingchild_node.funcin a variableThe optimization is most effective when
gathercalls exist in the code, as it can terminate search early rather than examining the entire AST structure.✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
To test or edit this optimization locally
git merge codeflash/optimize-pr617-2025-09-02T17.32.23Click to see suggested changes