Skip to content

Commit 7c53b19

Browse files
committed
first pass async wrapper
1 parent d1bc897 commit 7c53b19

File tree

6 files changed

+581
-225
lines changed

6 files changed

+581
-225
lines changed

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ def inject_profiling_into_existing_test(
345345
tree = InjectPerfOnly(func, test_module_path, test_framework, call_positions, mode=mode).visit(tree)
346346
new_imports = [
347347
ast.Import(names=[ast.alias(name="time")]),
348+
ast.Import(names=[ast.alias(name="inspect")]),
348349
ast.Import(names=[ast.alias(name="gc")]),
349350
ast.Import(names=[ast.alias(name="os")]),
350351
]
@@ -354,7 +355,7 @@ def inject_profiling_into_existing_test(
354355
)
355356
if test_framework == "unittest":
356357
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
357-
tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
358+
tree.body = [*new_imports, create_wrapper_function(mode), create_async_wrapper_inner(), *tree.body]
358359
return True, isort.code(ast.unparse(tree), float_to_top=True)
359360

360361

@@ -534,13 +535,39 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
534535
),
535536
lineno=lineno + 11,
536537
),
537-
ast.Assign(
538-
targets=[ast.Name(id="return_value", ctx=ast.Store())],
539-
value=ast.Call(
540-
func=ast.Name(id="wrapped", ctx=ast.Load()),
541-
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
542-
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
538+
ast.If(
539+
test=ast.Call(
540+
func=ast.Attribute(
541+
value=ast.Name(id="inspect", ctx=ast.Load()), attr="iscoroutinefunction", ctx=ast.Load()
542+
),
543+
args=[ast.Name(id="wrapped", ctx=ast.Load())],
544+
keywords=[],
543545
),
546+
body=[
547+
ast.Assign(
548+
targets=[ast.Name(id="return_value", ctx=ast.Store())],
549+
value=ast.Call(
550+
func=ast.Name(id="codeflash_async_wrap_inner", ctx=ast.Load()),
551+
args=[
552+
ast.Name(id="wrapped", ctx=ast.Load()),
553+
ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load()),
554+
],
555+
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
556+
),
557+
lineno=lineno + 12,
558+
)
559+
],
560+
orelse=[
561+
ast.Assign(
562+
targets=[ast.Name(id="return_value", ctx=ast.Store())],
563+
value=ast.Call(
564+
func=ast.Name(id="wrapped", ctx=ast.Load()),
565+
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
566+
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
567+
),
568+
lineno=lineno + 12,
569+
)
570+
],
544571
lineno=lineno + 12,
545572
),
546573
ast.Assign(
@@ -729,3 +756,32 @@ def create_wrapper_function(mode: TestingMode = TestingMode.BEHAVIOR) -> ast.Fun
729756
decorator_list=[],
730757
returns=None,
731758
)
759+
760+
761+
def create_async_wrapper_inner() -> ast.AsyncFunctionDef:
762+
return ast.AsyncFunctionDef(
763+
name="codeflash_async_wrap_inner",
764+
args=ast.arguments(
765+
args=[ast.arg(arg="wrapped", annotation=None)],
766+
vararg=ast.arg(arg="args"),
767+
kwarg=ast.arg(arg="kwargs"),
768+
posonlyargs=[],
769+
kwonlyargs=[],
770+
kw_defaults=[],
771+
defaults=[],
772+
),
773+
body=[
774+
ast.Return(
775+
value=ast.Await(
776+
value=ast.Call(
777+
func=ast.Name(id="wrapped", ctx=ast.Load()),
778+
args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
779+
keywords=[ast.keyword(arg=None, value=ast.Name(id="kwargs", ctx=ast.Load()))],
780+
)
781+
)
782+
)
783+
],
784+
decorator_list=[],
785+
returns=None,
786+
lineno=1,
787+
)

codeflash/discovery/functions_to_optimize.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,10 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None:
421421
# iterate over the class methods
422422
if node.name == self.class_name:
423423
for body_node in node.body:
424-
if isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef)) and body_node.name == self.function_name:
424+
if (
425+
isinstance(body_node, (ast.FunctionDef, ast.AsyncFunctionDef))
426+
and body_node.name == self.function_name
427+
):
425428
self.is_top_level = True
426429
if any(
427430
isinstance(decorator, ast.Name) and decorator.id == "classmethod"

codeflash/optimization/function_optimizer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
diff_length,
3636
file_name_from_test_module_name,
3737
get_run_tmp_file,
38-
has_any_async_functions,
3938
module_name_from_file_path,
4039
restore_conftest,
4140
)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ dev = [
7474
"types-unidiff>=0.7.0.20240505,<0.8",
7575
"uv>=0.6.2",
7676
"pre-commit>=4.2.0,<5",
77+
"pytest-asyncio>=1.1.0",
7778
]
7879

7980
[tool.hatch.build.targets.sdist]
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
import ast
2+
import asyncio
3+
import textwrap
4+
from pathlib import Path
5+
6+
import pytest
7+
8+
from codeflash.code_utils.instrument_existing_tests import (
9+
InjectPerfOnly,
10+
create_async_wrapper_inner,
11+
create_wrapper_function,
12+
inject_profiling_into_existing_test,
13+
)
14+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
15+
from codeflash.models.models import CodePosition, TestingMode
16+
17+
18+
def test_create_async_wrapper_inner():
19+
async_wrapper = create_async_wrapper_inner()
20+
21+
assert isinstance(async_wrapper, ast.AsyncFunctionDef)
22+
assert async_wrapper.name == "codeflash_async_wrap_inner"
23+
24+
assert len(async_wrapper.body) == 1
25+
assert isinstance(async_wrapper.body[0], ast.Return)
26+
assert isinstance(async_wrapper.body[0].value, ast.Await)
27+
28+
29+
def test_wrapper_function_includes_async_check():
30+
wrapper = create_wrapper_function(TestingMode.PERFORMANCE)
31+
32+
async_check_found = False
33+
for node in ast.walk(wrapper):
34+
if isinstance(node, ast.If):
35+
if isinstance(node.test, ast.Call):
36+
if (
37+
isinstance(node.test.func, ast.Attribute)
38+
and node.test.func.attr == "iscoroutinefunction"
39+
):
40+
async_check_found = True
41+
assert len(node.body) > 0
42+
assert len(node.orelse) > 0
43+
for stmt in node.body:
44+
if isinstance(stmt, ast.Assign):
45+
assert hasattr(stmt, "lineno")
46+
for stmt in node.orelse:
47+
if isinstance(stmt, ast.Assign):
48+
assert hasattr(stmt, "lineno")
49+
break
50+
51+
assert async_check_found, "Async check not found in wrapper function"
52+
53+
54+
def test_inject_profiling_with_async_function():
55+
test_code = textwrap.dedent("""
56+
import asyncio
57+
from my_module import async_process_data
58+
59+
async def test_async_function():
60+
result = await async_process_data("test")
61+
assert result == "processed"
62+
""")
63+
64+
test_file = Path("/tmp/test_async.py")
65+
test_file.write_text(test_code)
66+
67+
# Create function to optimize
68+
function = FunctionToOptimize(
69+
function_name="async_process_data",
70+
parents=[],
71+
file_path=Path("my_module.py"),
72+
starting_line=1,
73+
ending_line=10,
74+
)
75+
76+
call_positions = [CodePosition(line_no=5, col_no=19)]
77+
78+
success, modified_code = inject_profiling_into_existing_test(
79+
test_file,
80+
call_positions,
81+
function,
82+
Path("/tmp"),
83+
"pytest",
84+
TestingMode.PERFORMANCE,
85+
)
86+
87+
assert success
88+
assert modified_code is not None
89+
90+
assert "codeflash_async_wrap_inner" in modified_code
91+
assert "inspect.iscoroutinefunction" in modified_code
92+
assert "import inspect" in modified_code
93+
94+
try:
95+
ast.parse(modified_code)
96+
except SyntaxError:
97+
pytest.fail(f"Modified code has syntax errors:\n{modified_code}")
98+
99+
test_file.unlink()
100+
101+
102+
def test_async_wrapper_preserves_return_value():
103+
test_code = textwrap.dedent("""
104+
import asyncio
105+
106+
async def async_function(x):
107+
await asyncio.sleep(0.01)
108+
return x * 2
109+
110+
async def test_async_return():
111+
result = await async_function(5)
112+
assert result == 10
113+
""")
114+
115+
tree = ast.parse(test_code)
116+
117+
function = FunctionToOptimize(
118+
function_name="async_function",
119+
parents=[],
120+
file_path=Path("test.py"),
121+
starting_line=3,
122+
ending_line=5,
123+
)
124+
125+
call_positions = [CodePosition(line_no=8, col_no=19)]
126+
127+
visitor = InjectPerfOnly(
128+
function, "test_module", "pytest", call_positions, TestingMode.PERFORMANCE
129+
)
130+
131+
modified_tree = visitor.visit(tree)
132+
133+
# Add the wrapper functions
134+
modified_tree.body = [
135+
ast.Import(names=[ast.alias(name="inspect")]),
136+
create_wrapper_function(TestingMode.PERFORMANCE),
137+
create_async_wrapper_inner(),
138+
*modified_tree.body,
139+
]
140+
141+
try:
142+
modified_code = ast.unparse(modified_tree)
143+
assert "codeflash_async_wrap_inner" in modified_code
144+
except AttributeError as e:
145+
pytest.fail(f"AST unparsing failed with AttributeError: {e}")
146+
147+
148+
@pytest.mark.asyncio
149+
async def test_async_wrapper_execution():
150+
"""Test that the async wrapper can be executed correctly."""
151+
152+
# Create a simple async function to wrap
153+
async def test_func(x, y=10):
154+
await asyncio.sleep(0.01)
155+
return x + y
156+
157+
# Create the wrapper code dynamically
158+
wrapper_code = textwrap.dedent("""
159+
import asyncio
160+
import inspect
161+
162+
async def codeflash_async_wrap_inner(wrapped, *args, **kwargs):
163+
return await wrapped(*args, **kwargs)
164+
165+
async def test_wrapper():
166+
result = await codeflash_async_wrap_inner(test_func, 5, y=15)
167+
return result
168+
""")
169+
170+
# Execute the wrapper
171+
namespace = {"test_func": test_func}
172+
exec(wrapper_code, namespace)
173+
174+
# Run the test
175+
result = await namespace["test_wrapper"]()
176+
assert result == 20
177+
178+
179+
def test_mixed_sync_async_instrumentation():
180+
"""Test that both sync and async functions can be instrumented in the same test."""
181+
test_code = textwrap.dedent("""
182+
import asyncio
183+
184+
def sync_function(x):
185+
return x * 2
186+
187+
async def async_function(x):
188+
await asyncio.sleep(0.01)
189+
return x * 3
190+
191+
async def test_mixed():
192+
sync_result = sync_function(5)
193+
async_result = await async_function(5)
194+
assert sync_result == 10
195+
assert async_result == 15
196+
""")
197+
198+
tree = ast.parse(test_code)
199+
200+
sync_function = FunctionToOptimize(
201+
function_name="sync_function",
202+
parents=[],
203+
file_path=Path("test.py"),
204+
starting_line=3,
205+
ending_line=4,
206+
)
207+
208+
call_positions = [
209+
CodePosition(line_no=11, col_no=19),
210+
CodePosition(line_no=12, col_no=25),
211+
]
212+
213+
visitor = InjectPerfOnly(
214+
sync_function,
215+
"test_module",
216+
"pytest",
217+
call_positions,
218+
TestingMode.PERFORMANCE,
219+
)
220+
221+
modified_tree = visitor.visit(tree)
222+
223+
modified_tree.body = [
224+
ast.Import(names=[ast.alias(name="time")]),
225+
ast.Import(names=[ast.alias(name="inspect")]),
226+
ast.Import(names=[ast.alias(name="gc")]),
227+
ast.Import(names=[ast.alias(name="os")]),
228+
create_wrapper_function(TestingMode.PERFORMANCE),
229+
create_async_wrapper_inner(),
230+
*modified_tree.body,
231+
]
232+
233+
modified_code = ast.unparse(modified_tree)
234+
# Both wrapper functions should be present
235+
assert "codeflash_wrap" in modified_code
236+
assert "codeflash_async_wrap_inner" in modified_code
237+
assert "inspect.iscoroutinefunction" in modified_code

0 commit comments

Comments
 (0)