Skip to content

Commit 4514dd6

Browse files
committed
address feedback re: using existing logic from tracer & cf capture
1 parent b569d44 commit 4514dd6

File tree

1 file changed

+84
-50
lines changed

1 file changed

+84
-50
lines changed

codeflash/code_utils/codeflash_wrap_decorator.py

Lines changed: 84 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import contextlib
34
import gc
45
import inspect
56
import os
@@ -8,10 +9,13 @@
89
from enum import Enum
910
from functools import wraps
1011
from pathlib import Path
11-
from typing import Any, Callable, TypeVar
12+
from typing import TYPE_CHECKING, Any, Callable, TypeVar
1213

1314
import dill as pickle
1415

16+
if TYPE_CHECKING:
17+
from types import FrameType
18+
1519

1620
class VerificationType(str, Enum): # moved from codeflash/verification/codeflash_capture.py
1721
FUNCTION_CALL = (
@@ -24,63 +28,93 @@ class VerificationType(str, Enum): # moved from codeflash/verification/codeflas
2428
F = TypeVar("F", bound=Callable[..., Any])
2529

2630

31+
def _extract_class_name_tracer(frame_locals: dict[str, Any]) -> str | None:
32+
try:
33+
self_arg = frame_locals.get("self")
34+
if self_arg is not None:
35+
try:
36+
return self_arg.__class__.__name__
37+
except AttributeError:
38+
cls_arg = frame_locals.get("cls")
39+
if cls_arg is not None:
40+
with contextlib.suppress(AttributeError):
41+
return cls_arg.__name__
42+
else:
43+
cls_arg = frame_locals.get("cls")
44+
if cls_arg is not None:
45+
with contextlib.suppress(AttributeError):
46+
return cls_arg.__name__
47+
except: # noqa: E722
48+
# Handle cases where getattr is overridden and raises exceptions (e.g., wrapt)
49+
return None
50+
return None
51+
52+
53+
def _get_module_name_cf_tracer(frame: FrameType | None) -> str:
54+
with contextlib.suppress(Exception):
55+
test_module = inspect.getmodule(frame)
56+
if test_module and hasattr(test_module, "__name__"):
57+
return test_module.__name__
58+
59+
if frame is not None:
60+
return frame.f_globals.get("__name__", "unknown_module")
61+
return "unknown_module"
62+
63+
2764
def extract_test_context_from_frame() -> tuple[str, str | None, str]:
2865
frame = inspect.currentframe()
2966
try:
3067
potential_tests = []
3168

32-
while frame:
69+
if frame is not None:
3370
frame = frame.f_back
34-
if not frame:
35-
break
36-
37-
function_name = frame.f_code.co_name
38-
filename = frame.f_code.co_filename
39-
filename_path = Path(filename)
40-
41-
if function_name.startswith("test_"):
42-
test_name = function_name
43-
test_module_name = frame.f_globals.get("__name__", "unknown_module")
44-
test_class_name = None
45-
46-
if "self" in frame.f_locals:
47-
self_obj = frame.f_locals["self"]
48-
if hasattr(self_obj, "__class__") and hasattr(self_obj.__class__, "__name__"):
49-
test_class_name = self_obj.__class__.__name__
50-
51-
return test_module_name, test_class_name, test_name
52-
53-
if (
54-
frame.f_globals.get("__name__", "").startswith("test_")
55-
or filename_path.stem.startswith("test_")
56-
or "test" in filename_path.parts
57-
):
58-
test_module_name = frame.f_globals.get("__name__", "unknown_module")
59-
60-
if "self" in frame.f_locals:
61-
self_obj = frame.f_locals["self"]
62-
if hasattr(self_obj, "__class__") and hasattr(self_obj.__class__, "__name__"):
63-
class_name = self_obj.__class__.__name__
71+
72+
while frame is not None:
73+
try:
74+
function_name = frame.f_code.co_name
75+
filename = frame.f_code.co_filename
76+
filename_path = Path(filename)
77+
frame_locals = frame.f_locals
78+
if function_name.startswith("test_"):
79+
test_name = function_name
80+
test_module_name = _get_module_name_cf_tracer(frame)
81+
test_class_name = _extract_class_name_tracer(frame_locals)
82+
return test_module_name, test_class_name, test_name
83+
84+
if (
85+
frame.f_globals.get("__name__", "").startswith("test_")
86+
or filename_path.stem.startswith("test_")
87+
or "test" in filename_path.parts
88+
):
89+
test_module_name = _get_module_name_cf_tracer(frame)
90+
class_name = _extract_class_name_tracer(frame_locals)
91+
92+
if class_name:
6493
if class_name.startswith("Test") or class_name.endswith("Test") or "test" in class_name.lower():
6594
potential_tests.append((test_module_name, class_name, function_name))
66-
67-
elif "test" in test_module_name or filename_path.stem.startswith("test_"):
68-
potential_tests.append((test_module_name, None, function_name))
69-
70-
if (
71-
function_name in ["runTest", "_runTest", "run", "_testMethodName"]
72-
or "pytest" in str(frame.f_globals.get("__file__", ""))
73-
or "unittest" in str(frame.f_globals.get("__file__", ""))
74-
):
75-
# This might be a test framework frame, look for test context nearby
76-
test_module_name = frame.f_globals.get("__name__", "unknown_module")
77-
if "self" in frame.f_locals:
78-
self_obj = frame.f_locals["self"]
79-
if hasattr(self_obj, "__class__"):
80-
class_name = self_obj.__class__.__name__
81-
if class_name.startswith("Test") or "test" in class_name.lower():
82-
test_method = getattr(self_obj, "_testMethodName", function_name)
83-
potential_tests.append((test_module_name, class_name, test_method))
95+
elif "test" in test_module_name or filename_path.stem.startswith("test_"):
96+
potential_tests.append((test_module_name, None, function_name))
97+
98+
if (
99+
function_name in ["runTest", "_runTest", "run", "_testMethodName"]
100+
or "pytest" in str(frame.f_globals.get("__file__", ""))
101+
or "unittest" in str(frame.f_globals.get("__file__", ""))
102+
):
103+
test_module_name = _get_module_name_cf_tracer(frame)
104+
class_name = _extract_class_name_tracer(frame_locals)
105+
106+
if class_name and (class_name.startswith("Test") or "test" in class_name.lower()):
107+
test_method = function_name
108+
if "self" in frame_locals:
109+
with contextlib.suppress(AttributeError, TypeError):
110+
test_method = getattr(frame_locals["self"], "_testMethodName", function_name)
111+
potential_tests.append((test_module_name, class_name, test_method))
112+
113+
frame = frame.f_back
114+
115+
except Exception:
116+
frame = frame.f_back
117+
continue
84118

85119
if potential_tests:
86120
for test_module, test_class, test_func in potential_tests:

0 commit comments

Comments
 (0)