Skip to content

Commit 2a2a3a3

Browse files
committed
add tests for async wrapper
1 parent 4dfda09 commit 2a2a3a3

File tree

1 file changed

+286
-0
lines changed

1 file changed

+286
-0
lines changed
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
from __future__ import annotations
2+
3+
import asyncio
4+
import os
5+
import sqlite3
6+
import tempfile
7+
from pathlib import Path
8+
9+
import pytest
10+
import dill as pickle
11+
12+
from codeflash.code_utils.codeflash_wrap_decorator import (
13+
codeflash_behavior_async,
14+
codeflash_performance_async,
15+
)
16+
from codeflash.verification.codeflash_capture import VerificationType
17+
18+
19+
class TestAsyncWrapperSQLiteValidation:
20+
21+
@pytest.fixture
22+
def test_env_setup(self):
23+
original_env = {}
24+
test_env = {
25+
"CODEFLASH_LOOP_INDEX": "1",
26+
"CODEFLASH_TEST_ITERATION": "0",
27+
}
28+
29+
for key, value in test_env.items():
30+
original_env[key] = os.environ.get(key)
31+
os.environ[key] = value
32+
33+
yield test_env
34+
35+
for key, original_value in original_env.items():
36+
if original_value is None:
37+
os.environ.pop(key, None)
38+
else:
39+
os.environ[key] = original_value
40+
41+
@pytest.fixture
42+
def temp_db_path(self, test_env_setup):
43+
iteration = test_env_setup["CODEFLASH_TEST_ITERATION"]
44+
db_path = Path.cwd() / f"codeflash_test_results_{iteration}.sqlite"
45+
46+
yield db_path
47+
48+
if db_path.exists():
49+
db_path.unlink()
50+
51+
@pytest.mark.asyncio
52+
async def test_behavior_async_basic_function(self, test_env_setup, temp_db_path):
53+
54+
@codeflash_behavior_async
55+
async def simple_async_add(a: int, b: int) -> int:
56+
await asyncio.sleep(0.001)
57+
return a + b
58+
59+
result = await simple_async_add(5, 3)
60+
61+
assert result == 8
62+
63+
assert temp_db_path.exists()
64+
65+
con = sqlite3.connect(temp_db_path)
66+
cur = con.cursor()
67+
68+
cur.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='test_results'")
69+
assert cur.fetchone() is not None
70+
71+
cur.execute("SELECT * FROM test_results")
72+
rows = cur.fetchall()
73+
74+
assert len(rows) == 1
75+
row = rows[0]
76+
77+
(test_module_path, test_class_name, test_function_name, function_getting_tested,
78+
loop_index, iteration_id, runtime, return_value_blob, verification_type) = row
79+
80+
assert test_module_path == __name__
81+
assert test_class_name == "TestAsyncWrapperSQLiteValidation"
82+
assert test_function_name == "test_behavior_async_basic_function"
83+
assert function_getting_tested == "simple_async_add"
84+
assert loop_index == 1
85+
# Line ID will be the actual line number from the source code, not a simple counter
86+
assert iteration_id.startswith("simple_async_add_") and iteration_id.endswith("_0")
87+
assert runtime > 0
88+
assert verification_type == VerificationType.FUNCTION_CALL.value
89+
90+
unpickled_data = pickle.loads(return_value_blob)
91+
args, kwargs, return_val = unpickled_data
92+
93+
assert args == (5, 3)
94+
assert kwargs == {}
95+
assert return_val == 8
96+
97+
con.close()
98+
99+
@pytest.mark.asyncio
100+
async def test_behavior_async_exception_handling(self, test_env_setup, temp_db_path):
101+
102+
@codeflash_behavior_async
103+
async def async_divide(a: int, b: int) -> float:
104+
await asyncio.sleep(0.001)
105+
if b == 0:
106+
raise ValueError("Cannot divide by zero")
107+
return a / b
108+
109+
result = await async_divide(10, 2)
110+
assert result == 5.0
111+
112+
with pytest.raises(ValueError, match="Cannot divide by zero"):
113+
await async_divide(10, 0)
114+
115+
con = sqlite3.connect(temp_db_path)
116+
cur = con.cursor()
117+
cur.execute("SELECT * FROM test_results ORDER BY iteration_id")
118+
rows = cur.fetchall()
119+
120+
assert len(rows) == 2
121+
122+
success_row = rows[0]
123+
success_data = pickle.loads(success_row[7]) # return_value_blob
124+
args, kwargs, return_val = success_data
125+
assert args == (10, 2)
126+
assert return_val == 5.0
127+
128+
# Check exception record
129+
exception_row = rows[1]
130+
exception_data = pickle.loads(exception_row[7]) # return_value_blob
131+
assert isinstance(exception_data, ValueError)
132+
assert str(exception_data) == "Cannot divide by zero"
133+
134+
con.close()
135+
136+
@pytest.mark.asyncio
137+
async def test_performance_async_no_database_storage(self, test_env_setup, temp_db_path, capsys):
138+
"""Test performance async decorator doesn't store to database."""
139+
140+
@codeflash_performance_async
141+
async def async_multiply(a: int, b: int) -> int:
142+
"""Async function for performance testing."""
143+
await asyncio.sleep(0.002)
144+
return a * b
145+
146+
result = await async_multiply(4, 7)
147+
148+
assert result == 28
149+
150+
assert not temp_db_path.exists()
151+
152+
captured = capsys.readouterr()
153+
output_lines = captured.out.strip().split('\n')
154+
155+
assert len([line for line in output_lines if "!$######" in line]) == 1
156+
assert len([line for line in output_lines if "!######" in line and "######!" in line]) == 1
157+
158+
closing_tag = [line for line in output_lines if "!######" in line and "######!" in line][0]
159+
assert "async_multiply" in closing_tag
160+
161+
timing_part = closing_tag.split(":")[-1].replace("######!", "")
162+
timing_value = int(timing_part)
163+
assert timing_value > 0 # Should have positive timing
164+
165+
@pytest.mark.asyncio
166+
async def test_multiple_calls_indexing(self, test_env_setup, temp_db_path):
167+
168+
@codeflash_behavior_async
169+
async def async_increment(value: int) -> int:
170+
await asyncio.sleep(0.001)
171+
return value + 1
172+
173+
# Call the function multiple times
174+
results = []
175+
for i in range(3):
176+
result = await async_increment(i)
177+
results.append(result)
178+
179+
assert results == [1, 2, 3]
180+
181+
con = sqlite3.connect(temp_db_path)
182+
cur = con.cursor()
183+
cur.execute("SELECT iteration_id, return_value FROM test_results ORDER BY iteration_id")
184+
rows = cur.fetchall()
185+
186+
assert len(rows) == 3
187+
188+
actual_ids = [row[0] for row in rows]
189+
assert len(actual_ids) == 3
190+
191+
base_pattern = actual_ids[0].rsplit('_', 1)[0] # e.g., "async_increment_199"
192+
expected_pattern = [f"{base_pattern}_{i}" for i in range(3)]
193+
assert actual_ids == expected_pattern
194+
195+
for i, (_, return_value_blob) in enumerate(rows):
196+
args, kwargs, return_val = pickle.loads(return_value_blob)
197+
assert args == (i,)
198+
assert return_val == i + 1
199+
200+
con.close()
201+
202+
@pytest.mark.asyncio
203+
async def test_complex_async_function_with_kwargs(self, test_env_setup, temp_db_path):
204+
205+
@codeflash_behavior_async
206+
async def complex_async_func(
207+
pos_arg: str,
208+
*args: int,
209+
keyword_arg: str = "default",
210+
**kwargs: str
211+
) -> dict:
212+
await asyncio.sleep(0.001)
213+
return {
214+
"pos_arg": pos_arg,
215+
"args": args,
216+
"keyword_arg": keyword_arg,
217+
"kwargs": kwargs,
218+
}
219+
220+
result = await complex_async_func(
221+
"hello",
222+
1, 2, 3,
223+
keyword_arg="custom",
224+
extra1="value1",
225+
extra2="value2"
226+
)
227+
228+
expected_result = {
229+
"pos_arg": "hello",
230+
"args": (1, 2, 3),
231+
"keyword_arg": "custom",
232+
"kwargs": {"extra1": "value1", "extra2": "value2"}
233+
}
234+
235+
assert result == expected_result
236+
237+
con = sqlite3.connect(temp_db_path)
238+
cur = con.cursor()
239+
cur.execute("SELECT return_value FROM test_results")
240+
row = cur.fetchone()
241+
242+
stored_args, stored_kwargs, stored_result = pickle.loads(row[0])
243+
244+
assert stored_args == ("hello", 1, 2, 3)
245+
assert stored_kwargs == {"keyword_arg": "custom", "extra1": "value1", "extra2": "value2"}
246+
assert stored_result == expected_result
247+
248+
con.close()
249+
250+
@pytest.mark.asyncio
251+
async def test_database_schema_validation(self, test_env_setup, temp_db_path):
252+
253+
@codeflash_behavior_async
254+
async def schema_test_func() -> str:
255+
return "schema_test"
256+
257+
await schema_test_func()
258+
259+
con = sqlite3.connect(temp_db_path)
260+
cur = con.cursor()
261+
262+
cur.execute("PRAGMA table_info(test_results)")
263+
columns = cur.fetchall()
264+
265+
expected_columns = [
266+
(0, 'test_module_path', 'TEXT', 0, None, 0),
267+
(1, 'test_class_name', 'TEXT', 0, None, 0),
268+
(2, 'test_function_name', 'TEXT', 0, None, 0),
269+
(3, 'function_getting_tested', 'TEXT', 0, None, 0),
270+
(4, 'loop_index', 'INTEGER', 0, None, 0),
271+
(5, 'iteration_id', 'TEXT', 0, None, 0),
272+
(6, 'runtime', 'INTEGER', 0, None, 0),
273+
(7, 'return_value', 'BLOB', 0, None, 0),
274+
(8, 'verification_type', 'TEXT', 0, None, 0)
275+
]
276+
277+
assert columns == expected_columns
278+
con.close()
279+
280+
def test_sync_test_context_extraction(self):
281+
from codeflash.code_utils.codeflash_wrap_decorator import extract_test_context_from_frame
282+
283+
test_module, test_class, test_func = extract_test_context_from_frame()
284+
assert test_module == __name__
285+
assert test_class == "TestAsyncWrapperSQLiteValidation"
286+
assert test_func == "test_sync_test_context_extraction"

0 commit comments

Comments
 (0)