Skip to content

Commit b153989

Browse files
committed
first pass at wrapper deco behavioral
1 parent 207612c commit b153989

File tree

7 files changed

+1711
-629
lines changed

7 files changed

+1711
-629
lines changed
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
from __future__ import annotations
2+
3+
import gc
4+
import inspect
5+
import os
6+
import pickle
7+
import time
8+
from functools import wraps
9+
from pathlib import Path
10+
from typing import Any, Callable, TypeVar
11+
12+
from codeflash.code_utils.code_utils import get_run_tmp_file
13+
14+
F = TypeVar("F", bound=Callable[..., Any])
15+
16+
17+
def extract_test_context_from_frame() -> tuple[str, str | None, str]:
18+
frame = inspect.currentframe()
19+
try:
20+
while frame:
21+
frame = frame.f_back
22+
if frame and frame.f_code.co_name.startswith("test_"):
23+
test_name = frame.f_code.co_name
24+
test_module_name = frame.f_globals.get("__name__", "unknown_module")
25+
test_class_name = None
26+
if "self" in frame.f_locals:
27+
test_class_name = frame.f_locals["self"].__class__.__name__
28+
29+
return test_module_name, test_class_name, test_name
30+
raise RuntimeError("No test function found in call stack")
31+
finally:
32+
del frame
33+
34+
35+
def codeflash_behavior_async(func: F) -> F:
36+
"""Async decorator for behavior analysis - collects timing data and function inputs/outputs."""
37+
function_name = func.__name__
38+
line_id = f"{func.__name__}_{func.__code__.co_firstlineno}"
39+
loop_index = int(os.environ.get("CODEFLASH_LOOP_INDEX", "1"))
40+
41+
@wraps(func)
42+
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
43+
test_module_name, test_class_name, test_name = extract_test_context_from_frame()
44+
45+
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
46+
47+
if not hasattr(async_wrapper, "index"):
48+
async_wrapper.index = {}
49+
if test_id in async_wrapper.index:
50+
async_wrapper.index[test_id] += 1
51+
else:
52+
async_wrapper.index[test_id] = 0
53+
54+
codeflash_test_index = async_wrapper.index[test_id]
55+
invocation_id = f"{line_id}_{codeflash_test_index}"
56+
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
57+
58+
print(f"!$######{test_stdout_tag}######$!")
59+
60+
exception = None
61+
gc.disable()
62+
try:
63+
counter = time.perf_counter_ns()
64+
ret = func(*args, **kwargs)
65+
66+
if inspect.isawaitable(ret):
67+
counter = time.perf_counter_ns()
68+
return_value = await ret
69+
else:
70+
return_value = ret
71+
72+
codeflash_duration = time.perf_counter_ns() - counter
73+
except Exception as e:
74+
codeflash_duration = time.perf_counter_ns() - counter
75+
exception = e
76+
finally:
77+
gc.enable()
78+
79+
print(f"!######{test_stdout_tag}######!")
80+
81+
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
82+
83+
codeflash_run_tmp_dir = get_run_tmp_file(Path()).as_posix()
84+
85+
output_file = Path(codeflash_run_tmp_dir) / f"test_return_values_{iteration}.bin"
86+
87+
with output_file.open("ab") as f:
88+
pickled_values = (
89+
pickle.dumps((args, kwargs, exception)) if exception else pickle.dumps((args, kwargs, return_value))
90+
)
91+
_test_name = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{line_id}".encode(
92+
"ascii"
93+
)
94+
95+
f.write(len(_test_name).to_bytes(4, byteorder="big"))
96+
f.write(_test_name)
97+
f.write(codeflash_duration.to_bytes(8, byteorder="big"))
98+
f.write(len(pickled_values).to_bytes(4, byteorder="big"))
99+
f.write(pickled_values)
100+
f.write(loop_index.to_bytes(8, byteorder="big"))
101+
f.write(len(invocation_id).to_bytes(4, byteorder="big"))
102+
f.write(invocation_id.encode("ascii"))
103+
104+
if exception:
105+
raise exception
106+
return return_value
107+
108+
return async_wrapper
109+
110+
111+
def codeflash_performance_async(func: F) -> F:
112+
"""Async decorator for performance analysis - lightweight timing measurements only."""
113+
function_name = func.__name__
114+
line_id = f"{func.__name__}_{func.__code__.co_firstlineno}"
115+
loop_index = int(os.environ.get("CODEFLASH_LOOP_INDEX", "1"))
116+
117+
@wraps(func)
118+
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
119+
test_module_name, test_class_name, test_name = extract_test_context_from_frame()
120+
121+
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
122+
123+
if not hasattr(async_wrapper, "index"):
124+
async_wrapper.index = {}
125+
if test_id in async_wrapper.index:
126+
async_wrapper.index[test_id] += 1
127+
else:
128+
async_wrapper.index[test_id] = 0
129+
130+
codeflash_test_index = async_wrapper.index[test_id]
131+
invocation_id = f"{line_id}_{codeflash_test_index}"
132+
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
133+
134+
print(f"!$######{test_stdout_tag}######$!")
135+
136+
exception = None
137+
gc.disable()
138+
try:
139+
counter = time.perf_counter_ns()
140+
ret = func(*args, **kwargs)
141+
142+
if inspect.isawaitable(ret):
143+
counter = time.perf_counter_ns()
144+
return_value = await ret
145+
else:
146+
return_value = ret
147+
148+
codeflash_duration = time.perf_counter_ns() - counter
149+
except Exception as e:
150+
codeflash_duration = time.perf_counter_ns() - counter
151+
exception = e
152+
finally:
153+
gc.enable()
154+
155+
# For performance mode, include timing in the output tag like sync functions do
156+
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
157+
158+
if exception:
159+
raise exception
160+
return return_value
161+
162+
return async_wrapper

0 commit comments

Comments
 (0)