Skip to content

Commit e8699ec

Browse files
committed
temp
1 parent 593b7f2 commit e8699ec

File tree

3 files changed

+172
-0
lines changed

3 files changed

+172
-0
lines changed

code_to_optimize/async_adder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import asyncio
2+
3+
4+
async def async_add(a, b):
5+
"""Simple async function that adds two numbers."""
6+
await asyncio.sleep(0.001) # Simulate some async work
7+
print(f"codeflash stdout: Adding {a} + {b}")
8+
result = a + b
9+
print(f"result: {result}")
10+
return result

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ def find_and_update_line_node(
7373
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
7474
) -> Iterable[ast.stmt] | None:
7575
call_node = None
76+
await_node = None
77+
7678
for node in ast.walk(test_node):
7779
if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
7880
call_node = node
@@ -121,6 +123,60 @@ def find_and_update_line_node(
121123
]
122124
node.keywords = call_node.keywords
123125
break
126+
127+
# Check for awaited function calls
128+
elif isinstance(node, ast.Await) and isinstance(node.value, ast.Call) and node_in_call_position(node.value, self.call_positions):
129+
call_node = node.value
130+
await_node = node
131+
if isinstance(call_node.func, ast.Name):
132+
function_name = call_node.func.id
133+
call_node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
134+
call_node.args = [
135+
ast.Name(id=function_name, ctx=ast.Load()),
136+
ast.Constant(value=self.module_path),
137+
ast.Constant(value=test_class_name or None),
138+
ast.Constant(value=node_name),
139+
ast.Constant(value=self.function_object.qualified_name),
140+
ast.Constant(value=index),
141+
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
142+
*(
143+
[ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())]
144+
if self.mode == TestingMode.BEHAVIOR
145+
else []
146+
),
147+
*call_node.args,
148+
]
149+
call_node.keywords = call_node.keywords
150+
# Keep the await wrapper around the modified call
151+
await_node.value = call_node
152+
break
153+
if isinstance(call_node.func, ast.Attribute):
154+
function_to_test = call_node.func.attr
155+
if function_to_test == self.function_object.function_name:
156+
function_name = ast.unparse(call_node.func)
157+
call_node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
158+
call_node.args = [
159+
ast.Name(id=function_name, ctx=ast.Load()),
160+
ast.Constant(value=self.module_path),
161+
ast.Constant(value=test_class_name or None),
162+
ast.Constant(value=node_name),
163+
ast.Constant(value=self.function_object.qualified_name),
164+
ast.Constant(value=index),
165+
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
166+
*(
167+
[
168+
ast.Name(id="codeflash_cur", ctx=ast.Load()),
169+
ast.Name(id="codeflash_con", ctx=ast.Load()),
170+
]
171+
if self.mode == TestingMode.BEHAVIOR
172+
else []
173+
),
174+
*call_node.args,
175+
]
176+
call_node.keywords = call_node.keywords
177+
# Keep the await wrapper around the modified call
178+
await_node.value = call_node
179+
break
124180

125181
if call_node is None:
126182
return None
@@ -131,9 +187,35 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
131187
for inner_node in ast.walk(node):
132188
if isinstance(inner_node, ast.FunctionDef):
133189
self.visit_FunctionDef(inner_node, node.name)
190+
elif isinstance(inner_node, ast.AsyncFunctionDef):
191+
self.visit_AsyncFunctionDef(inner_node, node.name)
134192

135193
return node
136194

195+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef, test_class_name: str | None = None) -> ast.AsyncFunctionDef:
196+
"""Handle async function definitions by converting to sync and back."""
197+
# Convert to sync FunctionDef, process it, then convert back
198+
sync_node = ast.FunctionDef(
199+
name=node.name,
200+
args=node.args,
201+
body=node.body,
202+
decorator_list=node.decorator_list,
203+
returns=node.returns,
204+
lineno=node.lineno,
205+
col_offset=node.col_offset if hasattr(node, 'col_offset') else 0
206+
)
207+
processed_sync = self.visit_FunctionDef(sync_node, test_class_name)
208+
# Convert back to AsyncFunctionDef
209+
return ast.AsyncFunctionDef(
210+
name=processed_sync.name,
211+
args=processed_sync.args,
212+
body=processed_sync.body,
213+
decorator_list=processed_sync.decorator_list,
214+
returns=processed_sync.returns,
215+
lineno=processed_sync.lineno,
216+
col_offset=processed_sync.col_offset if hasattr(processed_sync, 'col_offset') else 0
217+
)
218+
137219
def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None = None) -> ast.FunctionDef:
138220
if node.name.startswith("test_"):
139221
did_update = False

tests/test_instrument_all_and_run.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,86 @@ def test_sort():
255255
test_path_perf.unlink(missing_ok=True)
256256

257257

258+
def test_async_function_behavior_results() -> None:
259+
"""Test that async_codeflash_wrap_string is used for async functions."""
260+
code = """import asyncio
261+
from code_to_optimize.async_adder import async_add
262+
263+
264+
async def test_async_add():
265+
result = await async_add(2, 3)
266+
assert result == 5"""
267+
268+
expected = (
269+
"""import gc
270+
import os
271+
import sqlite3
272+
import time
273+
274+
import dill as pickle
275+
276+
import asyncio
277+
from code_to_optimize.async_adder import async_add
278+
279+
280+
"""
281+
+ async_codeflash_wrap_string
282+
+ """
283+
async def test_async_add():
284+
codeflash_loop_index = int(os.environ['CODEFLASH_LOOP_INDEX'])
285+
codeflash_iteration = os.environ['CODEFLASH_TEST_ITERATION']
286+
codeflash_con = sqlite3.connect(f'{tmp_dir_path}_{{codeflash_iteration}}.sqlite')
287+
codeflash_cur = codeflash_con.cursor()
288+
codeflash_cur.execute('CREATE TABLE IF NOT EXISTS test_results (test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, runtime INTEGER, return_value BLOB, verification_type TEXT)')
289+
result = await codeflash_wrap(async_add, '{module_path}', None, 'test_async_add', 'async_add', '1', codeflash_loop_index, codeflash_cur, codeflash_con, 2, 3)
290+
assert result == 5
291+
codeflash_con.close()
292+
"""
293+
)
294+
295+
test_path = (
296+
Path(__file__).parent.resolve()
297+
/ "../code_to_optimize/tests/pytest/test_async_adder_behavior_temp.py"
298+
).resolve()
299+
test_path_perf = (
300+
Path(__file__).parent.resolve()
301+
/ "../code_to_optimize/tests/pytest/test_async_adder_perf_temp.py"
302+
).resolve()
303+
fto_path = (Path(__file__).parent.resolve() / "../code_to_optimize/async_adder.py").resolve()
304+
original_code = fto_path.read_text("utf-8")
305+
306+
try:
307+
with test_path.open("w") as f:
308+
f.write(code)
309+
310+
tests_root = (Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/").resolve()
311+
project_root_path = (Path(__file__).parent / "..").resolve()
312+
original_cwd = Path.cwd()
313+
run_cwd = Path(__file__).parent.parent.resolve()
314+
func = FunctionToOptimize(function_name="async_add", parents=[], file_path=Path(fto_path), is_async=True)
315+
os.chdir(run_cwd)
316+
success, new_test = inject_profiling_into_existing_test(
317+
test_path,
318+
[CodePosition(6, 19)],
319+
func,
320+
project_root_path,
321+
"pytest",
322+
mode=TestingMode.BEHAVIOR,
323+
is_async=True,
324+
)
325+
os.chdir(original_cwd)
326+
assert success
327+
assert new_test is not None
328+
assert "await wrapped(*args, **kwargs)" in new_test
329+
assert "async def codeflash_wrap" in new_test
330+
assert "await codeflash_wrap(async_add" in new_test
331+
332+
finally:
333+
fto_path.write_text(original_code, "utf-8")
334+
test_path.unlink(missing_ok=True)
335+
test_path_perf.unlink(missing_ok=True)
336+
337+
258338
def test_class_method_full_instrumentation() -> None:
259339
code = """from code_to_optimize.bubble_sort_method import BubbleSorter
260340

0 commit comments

Comments
 (0)