Skip to content

Commit d1bc897

Browse files
committed
discover async functions properly
1 parent 353484f commit d1bc897

File tree

2 files changed

+301
-2
lines changed

2 files changed

+301
-2
lines changed

codeflash/discovery/functions_to_optimize.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,11 +404,24 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
404404
)
405405
)
406406

407+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
408+
if self.class_name is None and node.name == self.function_name:
409+
self.is_top_level = True
410+
self.function_has_args = any(
411+
(
412+
bool(node.args.args),
413+
bool(node.args.kwonlyargs),
414+
bool(node.args.kwarg),
415+
bool(node.args.posonlyargs),
416+
bool(node.args.vararg),
417+
)
418+
)
419+
407420
def visit_ClassDef(self, node: ast.ClassDef) -> None:
408421
# iterate over the class methods
409422
if node.name == self.class_name:
410423
for body_node in node.body:
411-
if isinstance(body_node, ast.FunctionDef) and body_node.name == self.function_name:
424+
if isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and body_node.name == self.function_name:
412425
self.is_top_level = True
413426
if any(
414427
isinstance(decorator, ast.Name) and decorator.id == "classmethod"
@@ -426,7 +439,7 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
426439
# This way, if we don't have the class name, we can still find the static method
427440
for body_node in node.body:
428441
if (
429-
isinstance(body_node, ast.FunctionDef)
442+
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
430443
and body_node.name == self.function_name
431444
and body_node.lineno in {self.line_no, self.line_no + 1}
432445
and any(
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
import tempfile
2+
from pathlib import Path
3+
import pytest
4+
5+
from codeflash.discovery.functions_to_optimize import (
6+
find_all_functions_in_file,
7+
get_functions_to_optimize,
8+
inspect_top_level_functions_or_methods,
9+
)
10+
from codeflash.verification.verification_utils import TestConfig
11+
12+
13+
@pytest.fixture
14+
def temp_dir():
15+
with tempfile.TemporaryDirectory() as temp:
16+
yield Path(temp)
17+
18+
19+
def test_async_function_detection(temp_dir):
20+
async_function = """
21+
async def async_function_with_return():
22+
await some_async_operation()
23+
return 42
24+
25+
async def async_function_without_return():
26+
await some_async_operation()
27+
print("No return")
28+
29+
def regular_function():
30+
return 10
31+
"""
32+
33+
file_path = temp_dir / "test_file.py"
34+
file_path.write_text(async_function)
35+
functions_found = find_all_functions_in_file(file_path)
36+
37+
function_names = [fn.function_name for fn in functions_found[file_path]]
38+
39+
assert "async_function_with_return" in function_names
40+
assert "regular_function" in function_names
41+
assert "async_function_without_return" not in function_names
42+
43+
44+
def test_async_method_in_class(temp_dir):
45+
code_with_async_method = """
46+
class AsyncClass:
47+
async def async_method(self):
48+
await self.do_something()
49+
return "result"
50+
51+
async def async_method_no_return(self):
52+
await self.do_something()
53+
pass
54+
55+
def sync_method(self):
56+
return "sync result"
57+
"""
58+
59+
file_path = temp_dir / "test_file.py"
60+
file_path.write_text(code_with_async_method)
61+
functions_found = find_all_functions_in_file(file_path)
62+
63+
found_functions = functions_found[file_path]
64+
function_names = [fn.function_name for fn in found_functions]
65+
qualified_names = [fn.qualified_name for fn in found_functions]
66+
67+
assert "async_method" in function_names
68+
assert "AsyncClass.async_method" in qualified_names
69+
70+
assert "sync_method" in function_names
71+
assert "AsyncClass.sync_method" in qualified_names
72+
73+
assert "async_method_no_return" not in function_names
74+
75+
76+
def test_nested_async_functions(temp_dir):
77+
nested_async = """
78+
async def outer_async():
79+
async def inner_async():
80+
return "inner"
81+
82+
result = await inner_async()
83+
return result
84+
85+
def outer_sync():
86+
async def inner_async():
87+
return "inner from sync"
88+
89+
return inner_async
90+
"""
91+
92+
file_path = temp_dir / "test_file.py"
93+
file_path.write_text(nested_async)
94+
functions_found = find_all_functions_in_file(file_path)
95+
96+
function_names = [fn.function_name for fn in functions_found[file_path]]
97+
98+
assert "outer_async" in function_names
99+
assert "outer_sync" in function_names
100+
assert "inner_async" not in function_names
101+
102+
103+
def test_async_staticmethod_and_classmethod(temp_dir):
104+
async_decorators = """
105+
class MyClass:
106+
@staticmethod
107+
async def async_static_method():
108+
await some_operation()
109+
return "static result"
110+
111+
@classmethod
112+
async def async_class_method(cls):
113+
await cls.some_operation()
114+
return "class result"
115+
116+
@property
117+
async def async_property(self):
118+
return await self.get_value()
119+
"""
120+
121+
file_path = temp_dir / "test_file.py"
122+
file_path.write_text(async_decorators)
123+
functions_found = find_all_functions_in_file(file_path)
124+
125+
function_names = [fn.function_name for fn in functions_found[file_path]]
126+
127+
assert "async_static_method" in function_names
128+
assert "async_class_method" in function_names
129+
130+
assert "async_property" not in function_names
131+
132+
133+
def test_async_generator_functions(temp_dir):
134+
async_generators = """
135+
async def async_generator_with_return():
136+
for i in range(10):
137+
yield i
138+
return "done"
139+
140+
async def async_generator_no_return():
141+
for i in range(10):
142+
yield i
143+
144+
async def regular_async_with_return():
145+
result = await compute()
146+
return result
147+
"""
148+
149+
file_path = temp_dir / "test_file.py"
150+
file_path.write_text(async_generators)
151+
functions_found = find_all_functions_in_file(file_path)
152+
153+
function_names = [fn.function_name for fn in functions_found[file_path]]
154+
155+
assert "async_generator_with_return" in function_names
156+
assert "regular_async_with_return" in function_names
157+
assert "async_generator_no_return" not in function_names
158+
159+
160+
def test_inspect_async_top_level_functions(temp_dir):
161+
code = """
162+
async def top_level_async():
163+
return 42
164+
165+
class AsyncContainer:
166+
async def async_method(self):
167+
async def nested_async():
168+
return 1
169+
return await nested_async()
170+
171+
@staticmethod
172+
async def async_static():
173+
return "static"
174+
175+
@classmethod
176+
async def async_classmethod(cls):
177+
return "classmethod"
178+
"""
179+
180+
file_path = temp_dir / "test_file.py"
181+
file_path.write_text(code)
182+
183+
result = inspect_top_level_functions_or_methods(file_path, "top_level_async")
184+
assert result.is_top_level
185+
186+
result = inspect_top_level_functions_or_methods(file_path, "async_method", class_name="AsyncContainer")
187+
assert result.is_top_level
188+
189+
result = inspect_top_level_functions_or_methods(file_path, "nested_async", class_name="AsyncContainer")
190+
assert not result.is_top_level
191+
192+
result = inspect_top_level_functions_or_methods(file_path, "async_static", class_name="AsyncContainer")
193+
assert result.is_top_level
194+
assert result.is_staticmethod
195+
196+
result = inspect_top_level_functions_or_methods(file_path, "async_classmethod", class_name="AsyncContainer")
197+
assert result.is_top_level
198+
assert result.is_classmethod
199+
200+
201+
def test_get_functions_to_optimize_with_async(temp_dir):
202+
mixed_code = """
203+
async def async_func_one():
204+
return await operation_one()
205+
206+
def sync_func_one():
207+
return operation_one()
208+
209+
async def async_func_two():
210+
print("no return")
211+
212+
class MixedClass:
213+
async def async_method(self):
214+
return await self.operation()
215+
216+
def sync_method(self):
217+
return self.operation()
218+
"""
219+
220+
file_path = temp_dir / "test_file.py"
221+
file_path.write_text(mixed_code)
222+
223+
test_config = TestConfig(
224+
tests_root="tests",
225+
project_root_path=".",
226+
test_framework="pytest",
227+
tests_project_rootdir=Path()
228+
)
229+
230+
functions, functions_count, _ = get_functions_to_optimize(
231+
optimize_all=None,
232+
replay_test=None,
233+
file=file_path,
234+
only_get_this_function=None,
235+
test_cfg=test_config,
236+
ignore_paths=[],
237+
project_root=file_path.parent,
238+
module_root=file_path.parent,
239+
)
240+
241+
assert functions_count == 4
242+
243+
function_names = [fn.function_name for fn in functions[file_path]]
244+
assert "async_func_one" in function_names
245+
assert "sync_func_one" in function_names
246+
assert "async_method" in function_names
247+
assert "sync_method" in function_names
248+
249+
assert "async_func_two" not in function_names
250+
251+
252+
def test_async_function_parents(temp_dir):
253+
complex_structure = """
254+
class OuterClass:
255+
async def outer_method(self):
256+
return 1
257+
258+
class InnerClass:
259+
async def inner_method(self):
260+
return 2
261+
262+
async def module_level_async():
263+
class LocalClass:
264+
async def local_method(self):
265+
return 3
266+
return LocalClass()
267+
"""
268+
269+
file_path = temp_dir / "test_file.py"
270+
file_path.write_text(complex_structure)
271+
functions_found = find_all_functions_in_file(file_path)
272+
273+
found_functions = functions_found[file_path]
274+
275+
for fn in found_functions:
276+
if fn.function_name == "outer_method":
277+
assert len(fn.parents) == 1
278+
assert fn.parents[0].name == "OuterClass"
279+
assert fn.qualified_name == "OuterClass.outer_method"
280+
elif fn.function_name == "inner_method":
281+
assert len(fn.parents) == 2
282+
assert fn.parents[0].name == "OuterClass"
283+
assert fn.parents[1].name == "InnerClass"
284+
elif fn.function_name == "module_level_async":
285+
assert len(fn.parents) == 0
286+
assert fn.qualified_name == "module_level_async"

0 commit comments

Comments
 (0)