Skip to content

Commit 14353d4

Browse files
committed
Update codeflash_wrap_decorator.py
1 parent ad1d085 commit 14353d4

File tree

1 file changed

+56
-7
lines changed

1 file changed

+56
-7
lines changed

codeflash/code_utils/codeflash_wrap_decorator.py

Lines changed: 56 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,28 +17,77 @@
1717
def extract_test_context_from_frame() -> tuple[str, str | None, str]:
1818
frame = inspect.currentframe()
1919
try:
20+
potential_tests = []
21+
2022
while frame:
2123
frame = frame.f_back
22-
if frame and frame.f_code.co_name.startswith("test_"):
23-
test_name = frame.f_code.co_name
24+
if not frame:
25+
break
26+
27+
function_name = frame.f_code.co_name
28+
filename = frame.f_code.co_filename
29+
30+
if function_name.startswith("test_"):
31+
test_name = function_name
2432
test_module_name = frame.f_globals.get("__name__", "unknown_module")
2533
test_class_name = None
34+
2635
if "self" in frame.f_locals:
27-
test_class_name = frame.f_locals["self"].__class__.__name__
36+
self_obj = frame.f_locals["self"]
37+
if hasattr(self_obj, "__class__") and hasattr(self_obj.__class__, "__name__"):
38+
test_class_name = self_obj.__class__.__name__
2839

2940
return test_module_name, test_class_name, test_name
41+
42+
if (
43+
frame.f_globals.get("__name__", "").startswith("test_")
44+
or Path(filename).stem.startswith("test_")
45+
or "test" in Path(filename).parts
46+
):
47+
test_module_name = frame.f_globals.get("__name__", "unknown_module")
48+
49+
if "self" in frame.f_locals:
50+
self_obj = frame.f_locals["self"]
51+
if hasattr(self_obj, "__class__") and hasattr(self_obj.__class__, "__name__"):
52+
class_name = self_obj.__class__.__name__
53+
if class_name.startswith("Test") or class_name.endswith("Test") or "test" in class_name.lower():
54+
potential_tests.append((test_module_name, class_name, function_name))
55+
56+
elif "test" in test_module_name or Path(filename).stem.startswith("test_"):
57+
potential_tests.append((test_module_name, None, function_name))
58+
59+
if (
60+
function_name in ["runTest", "_runTest", "run", "_testMethodName"]
61+
or "pytest" in str(frame.f_globals.get("__file__", ""))
62+
or "unittest" in str(frame.f_globals.get("__file__", ""))
63+
):
64+
# This might be a test framework frame, look for test context nearby
65+
test_module_name = frame.f_globals.get("__name__", "unknown_module")
66+
if "self" in frame.f_locals:
67+
self_obj = frame.f_locals["self"]
68+
if hasattr(self_obj, "__class__"):
69+
class_name = self_obj.__class__.__name__
70+
if class_name.startswith("Test") or "test" in class_name.lower():
71+
test_method = getattr(self_obj, "_testMethodName", function_name)
72+
potential_tests.append((test_module_name, class_name, test_method))
73+
74+
if potential_tests:
75+
for test_module, test_class, test_func in potential_tests:
76+
if test_func.startswith("test_"):
77+
return test_module, test_class, test_func
78+
return potential_tests[0]
79+
3080
raise RuntimeError("No test function found in call stack")
3181
finally:
3282
del frame
3383

3484

3585
def codeflash_behavior_async(func: F) -> F:
36-
function_name = func.__name__
37-
line_id = f"{func.__name__}_{func.__code__.co_firstlineno}"
38-
loop_index = int(os.environ.get("CODEFLASH_LOOP_INDEX", "1"))
39-
4086
@wraps(func)
4187
async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
88+
function_name = func.__name__
89+
line_id = f"{func.__name__}_{func.__code__.co_firstlineno}"
90+
loop_index = os.environ["CODEFLASH_LOOP_INDEX"]
4291
test_module_name, test_class_name, test_name = extract_test_context_from_frame()
4392

4493
test_id = f"{test_module_name}:{test_class_name}:{test_name}:{line_id}:{loop_index}"

0 commit comments

Comments
 (0)