Skip to content

Commit 0a57afa

Browse files
committed
first pass at wrapper deco behavioral
1 parent 207612c commit 0a57afa

File tree

7 files changed

+1702
-629
lines changed

7 files changed

+1702
-629
lines changed
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from __future__ import annotations
2+
3+
import gc
4+
import inspect
5+
import os
6+
import pickle
7+
import time
8+
from functools import wraps
9+
from pathlib import Path
10+
from typing import Any, Callable, TypeVar
11+
12+
from codeflash.code_utils.code_utils import get_run_tmp_file
13+
14+
F = TypeVar("F", bound=Callable[..., Any])
15+
16+
17+
def extract_test_context_from_frame() -> tuple[str, str | None, str]:
18+
frame = inspect.currentframe()
19+
try:
20+
while frame:
21+
frame = frame.f_back
22+
if frame and frame.f_code.co_name.startswith("test_"):
23+
test_name = frame.f_code.co_name
24+
test_module_name = frame.f_globals.get("__name__", "unknown_module")
25+
test_class_name = None
26+
if "self" in frame.f_locals:
27+
test_class_name = frame.f_locals["self"].__class__.__name__
28+
29+
return test_module_name, test_class_name, test_name
30+
raise RuntimeError("No test function found in call stack")
31+
finally:
32+
del frame
33+
34+
35+
def codeflash_behavior_async(func: F) -> F:
36+
function_name = func.__name__
37+
line_id = f"{func.__name__}_{func.__code__.co_firstlineno}"
38+
loop_index = int(os.environ.get("CODEFLASH_LOOP_INDEX", "1"))
39+
40+
@wraps(func)
41+
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
42+
test_module_name, test_class_name, test_name = extract_test_context_from_frame()
43+
44+
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
45+
46+
if not hasattr(async_wrapper, "index"):
47+
async_wrapper.index = {}
48+
if test_id in async_wrapper.index:
49+
async_wrapper.index[test_id] += 1
50+
else:
51+
async_wrapper.index[test_id] = 0
52+
53+
codeflash_test_index = async_wrapper.index[test_id]
54+
invocation_id = f"{line_id}_{codeflash_test_index}"
55+
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
56+
57+
print(f"!$######{test_stdout_tag}######$!")
58+
59+
exception = None
60+
gc.disable()
61+
try:
62+
counter = time.perf_counter_ns()
63+
ret = func(*args, **kwargs)
64+
65+
if inspect.isawaitable(ret):
66+
counter = time.perf_counter_ns()
67+
return_value = await ret
68+
else:
69+
return_value = ret
70+
71+
codeflash_duration = time.perf_counter_ns() - counter
72+
except Exception as e:
73+
codeflash_duration = time.perf_counter_ns() - counter
74+
exception = e
75+
finally:
76+
gc.enable()
77+
78+
print(f"!######{test_stdout_tag}######!")
79+
80+
iteration = os.environ.get("CODEFLASH_TEST_ITERATION", "0")
81+
82+
codeflash_run_tmp_dir = get_run_tmp_file(Path()).as_posix()
83+
84+
output_file = Path(codeflash_run_tmp_dir) / f"test_return_values_{iteration}.bin"
85+
86+
with output_file.open("ab") as f:
87+
pickled_values = (
88+
pickle.dumps((args, kwargs, exception)) if exception else pickle.dumps((args, kwargs, return_value))
89+
)
90+
_test_name = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{line_id}".encode(
91+
"ascii"
92+
)
93+
94+
f.write(len(_test_name).to_bytes(4, byteorder="big"))
95+
f.write(_test_name)
96+
f.write(codeflash_duration.to_bytes(8, byteorder="big"))
97+
f.write(len(pickled_values).to_bytes(4, byteorder="big"))
98+
f.write(pickled_values)
99+
f.write(loop_index.to_bytes(8, byteorder="big"))
100+
f.write(len(invocation_id).to_bytes(4, byteorder="big"))
101+
f.write(invocation_id.encode("ascii"))
102+
103+
if exception:
104+
raise exception
105+
return return_value
106+
107+
return async_wrapper
108+
109+
110+
def codeflash_performance_async(func: F) -> F:
111+
function_name = func.__name__
112+
line_id = f"{func.__name__}_{func.__code__.co_firstlineno}"
113+
loop_index = int(os.environ.get("CODEFLASH_LOOP_INDEX", "1"))
114+
115+
@wraps(func)
116+
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
117+
test_module_name, test_class_name, test_name = extract_test_context_from_frame()
118+
119+
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"
120+
121+
if not hasattr(async_wrapper, "index"):
122+
async_wrapper.index = {}
123+
if test_id in async_wrapper.index:
124+
async_wrapper.index[test_id] += 1
125+
else:
126+
async_wrapper.index[test_id] = 0
127+
128+
codeflash_test_index = async_wrapper.index[test_id]
129+
invocation_id = f"{line_id}_{codeflash_test_index}"
130+
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
131+
132+
print(f"!$######{test_stdout_tag}######$!")
133+
134+
exception = None
135+
gc.disable()
136+
try:
137+
counter = time.perf_counter_ns()
138+
ret = func(*args, **kwargs)
139+
140+
if inspect.isawaitable(ret):
141+
counter = time.perf_counter_ns()
142+
return_value = await ret
143+
else:
144+
return_value = ret
145+
146+
codeflash_duration = time.perf_counter_ns() - counter
147+
except Exception as e:
148+
codeflash_duration = time.perf_counter_ns() - counter
149+
exception = e
150+
finally:
151+
gc.enable()
152+
153+
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
154+
155+
if exception:
156+
raise exception
157+
return return_value
158+
159+
return async_wrapper

0 commit comments

Comments
 (0)