Skip to content

Commit fadbdf7

Browse files
committed
better impl
1 parent a42c63a commit fadbdf7

File tree

2 files changed

+34
-174
lines changed

2 files changed

+34
-174
lines changed

codeflash/code_utils/codeflash_wrap_decorator.py

Lines changed: 12 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,17 @@
11
from __future__ import annotations
22

33
import asyncio
4-
import contextlib
54
import gc
6-
import inspect
75
import os
86
import sqlite3
9-
import time
107
from enum import Enum
118
from functools import wraps
129
from pathlib import Path
1310
from tempfile import TemporaryDirectory
14-
from typing import TYPE_CHECKING, Any, Callable, TypeVar
11+
from typing import Any, Callable, TypeVar
1512

1613
import dill as pickle
1714

18-
if TYPE_CHECKING:
19-
from types import FrameType
20-
2115

2216
class VerificationType(str, Enum): # moved from codeflash/verification/codeflash_capture.py
2317
FUNCTION_CALL = (
@@ -36,175 +30,19 @@ def get_run_tmp_file(file_path: Path) -> Path: # moved from codeflash/code_util
3630
return Path(get_run_tmp_file.tmpdir.name) / file_path
3731

3832

39-
def _extract_class_name_tracer(frame_locals: dict[str, Any]) -> str | None:
40-
try:
41-
self_arg = frame_locals.get("self")
42-
if self_arg is not None:
43-
try:
44-
return self_arg.__class__.__name__
45-
except (AttributeError, Exception):
46-
cls_arg = frame_locals.get("cls")
47-
if cls_arg is not None:
48-
with contextlib.suppress(AttributeError, Exception):
49-
return cls_arg.__name__
50-
else:
51-
cls_arg = frame_locals.get("cls")
52-
if cls_arg is not None:
53-
with contextlib.suppress(AttributeError, Exception):
54-
return cls_arg.__name__
55-
except Exception:
56-
return None
57-
return None
58-
59-
60-
def _get_module_name_cf_tracer(frame: FrameType | None) -> str:
61-
try:
62-
test_module = inspect.getmodule(frame)
63-
except Exception:
64-
test_module = None
65-
66-
if test_module is not None:
67-
module_name = getattr(test_module, "__name__", None)
68-
if module_name is not None:
69-
return module_name
70-
71-
if frame is not None:
72-
return frame.f_globals.get("__name__", "unknown_module")
73-
return "unknown_module"
33+
def extract_test_context_from_frame() -> tuple[str, str | None, str]:
34+
# test_module = os.environ.get("CODEFLASH_TEST_MODULE")
35+
test_module = os.environ["CODEFLASH_TEST_MODULE"]
36+
test_class = os.environ.get("CODEFLASH_TEST_CLASS", None)
37+
# test_function = os.environ.get("CODEFLASH_TEST_FUNCTION")
38+
test_function = os.environ["CODEFLASH_TEST_FUNCTION"]
7439

40+
if test_module and test_function:
41+
return (test_module, test_class if test_class else None, test_function)
7542

76-
def extract_test_context_from_frame() -> tuple[str, str | None, str]:
77-
frame = inspect.currentframe()
78-
# optimize?
79-
try:
80-
frames_info = []
81-
potential_tests = []
82-
83-
# First pass: collect all frame information
84-
if frame is not None:
85-
frame = frame.f_back
86-
87-
while frame is not None:
88-
try:
89-
function_name = frame.f_code.co_name
90-
filename = frame.f_code.co_filename
91-
filename_path = Path(filename)
92-
frame_locals = frame.f_locals
93-
test_module_name = _get_module_name_cf_tracer(frame)
94-
class_name = _extract_class_name_tracer(frame_locals)
95-
96-
frames_info.append(
97-
{
98-
"function_name": function_name,
99-
"filename_path": filename_path,
100-
"frame_locals": frame_locals,
101-
"test_module_name": test_module_name,
102-
"class_name": class_name,
103-
"frame": frame,
104-
}
105-
)
106-
107-
except Exception: # noqa: S112
108-
continue
109-
110-
frame = frame.f_back
111-
112-
# Second pass: analyze frames with full context
113-
test_class_candidates = []
114-
for frame_info in frames_info:
115-
function_name = frame_info["function_name"]
116-
filename_path = frame_info["filename_path"]
117-
frame_locals = frame_info["frame_locals"]
118-
test_module_name = frame_info["test_module_name"]
119-
class_name = frame_info["class_name"]
120-
frame_obj = frame_info["frame"]
121-
122-
# Keep track of test classes
123-
if class_name and (
124-
class_name.startswith("Test") or class_name.endswith("Test") or "test" in class_name.lower()
125-
):
126-
test_class_candidates.append((class_name, test_module_name))
127-
128-
# Now process frames again looking for test functions with full candidates list
129-
# Collect all test functions to prioritize outer ones over nested ones
130-
test_functions = []
131-
for frame_info in frames_info:
132-
function_name = frame_info["function_name"]
133-
filename_path = frame_info["filename_path"]
134-
frame_locals = frame_info["frame_locals"]
135-
test_module_name = frame_info["test_module_name"]
136-
class_name = frame_info["class_name"]
137-
frame_obj = frame_info["frame"]
138-
139-
# Collect test functions
140-
if function_name.startswith("test_"):
141-
test_class_name = class_name
142-
143-
# If no class found in current frame, check if we have any test class candidates
144-
# Prefer the innermost (first) test class candidate which is more specific
145-
if test_class_name is None and test_class_candidates:
146-
test_class_name = test_class_candidates[0][0]
147-
148-
test_functions.append((test_module_name, test_class_name, function_name))
149-
150-
# Prioritize test functions with class context, then innermost
151-
if test_functions:
152-
# First prefer test functions with class context
153-
for test_func in test_functions:
154-
if test_func[1] is not None: # has class_name
155-
return test_func
156-
# If no test function has class context, return the outermost (most likely the actual test method)
157-
return test_functions[-1]
158-
159-
# If no direct test functions found, look for other test patterns
160-
for frame_info in frames_info:
161-
function_name = frame_info["function_name"]
162-
filename_path = frame_info["filename_path"]
163-
frame_locals = frame_info["frame_locals"]
164-
test_module_name = frame_info["test_module_name"]
165-
class_name = frame_info["class_name"]
166-
frame_obj = frame_info["frame"]
167-
168-
# Test file/module detection
169-
if (
170-
frame_obj.f_globals.get("__name__", "").startswith("test_")
171-
or filename_path.stem.startswith("test_")
172-
or "test" in filename_path.parts
173-
):
174-
if class_name and (
175-
class_name.startswith("Test") or class_name.endswith("Test") or "test" in class_name.lower()
176-
):
177-
potential_tests.append((test_module_name, class_name, function_name))
178-
elif "test" in test_module_name or filename_path.stem.startswith("test_"):
179-
# For functions without class context, try to find the most recent test class
180-
best_class = test_class_candidates[0][0] if test_class_candidates else None
181-
potential_tests.append((test_module_name, best_class, function_name))
182-
183-
# Framework integration detection
184-
if (
185-
(
186-
function_name in ["runTest", "_runTest", "run", "_testMethodName"]
187-
or "pytest" in str(frame_obj.f_globals.get("__file__", ""))
188-
or "unittest" in str(frame_obj.f_globals.get("__file__", ""))
189-
)
190-
and class_name
191-
and (class_name.startswith("Test") or "test" in class_name.lower())
192-
):
193-
test_method = function_name
194-
if "self" in frame_locals:
195-
with contextlib.suppress(AttributeError, TypeError):
196-
test_method = getattr(frame_locals["self"], "_testMethodName", function_name)
197-
potential_tests.append((test_module_name, class_name, test_method))
198-
199-
if potential_tests:
200-
for test_module, test_class, test_func in potential_tests:
201-
if test_func.startswith("test_"):
202-
return test_module, test_class, test_func
203-
return potential_tests[0]
204-
205-
raise RuntimeError("No test function found in call stack")
206-
finally:
207-
del frame
43+
raise RuntimeError(
44+
"Test context environment variables not set - ensure tests are run through codeflash test runner"
45+
)
20846

20947

21048
def codeflash_behavior_async(func: F) -> F:

codeflash/verification/pytest_plugin.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,3 +450,25 @@ def make_progress_id(i: int, n: int = count) -> str:
450450
metafunc.parametrize(
451451
"__pytest_loop_step_number", range(count), indirect=True, ids=make_progress_id, scope=scope
452452
)
453+
454+
@pytest.hookimpl(tryfirst=True)
455+
def pytest_runtest_setup(self, item: pytest.Item) -> None:
456+
test_module_name = item.module.__name__ if item.module else "unknown_module"
457+
458+
test_class_name = None
459+
if item.cls:
460+
test_class_name = item.cls.__name__
461+
462+
test_function_name = item.name
463+
if "[" in test_function_name:
464+
test_function_name = test_function_name.split("[", 1)[0]
465+
466+
os.environ["CODEFLASH_TEST_MODULE"] = test_module_name
467+
os.environ["CODEFLASH_TEST_CLASS"] = test_class_name or ""
468+
os.environ["CODEFLASH_TEST_FUNCTION"] = test_function_name
469+
470+
@pytest.hookimpl(trylast=True)
471+
def pytest_runtest_teardown(self, _: pytest.Item) -> None:
472+
"""Clean up test context environment variables after each test."""
473+
for var in ["CODEFLASH_TEST_MODULE", "CODEFLASH_TEST_CLASS", "CODEFLASH_TEST_FUNCTION"]:
474+
os.environ.pop(var, None)

0 commit comments

Comments
 (0)