Skip to content

Commit 84fa47d

Browse files
authored
Merge branch 'main' into benchmark-fixture
2 parents fa75373 + 5b33d1d commit 84fa47d

File tree

3 files changed

+585
-38
lines changed

3 files changed

+585
-38
lines changed

codeflash/code_utils/code_replacer.py

Lines changed: 87 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import isort
99
import libcst as cst
10-
import libcst.matchers as m
10+
from libcst.metadata import PositionProvider
1111

1212
from codeflash.cli_cmds.console import logger
1313
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
@@ -37,6 +37,55 @@ def normalize_code(code: str) -> str:
3737
return ast.unparse(normalize_node(ast.parse(code)))
3838

3939

40+
class AddRequestArgument(cst.CSTTransformer):
41+
METADATA_DEPENDENCIES = (PositionProvider,)
42+
43+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
44+
# Matcher for '@fixture' or '@pytest.fixture'
45+
for decorator in original_node.decorators:
46+
dec = decorator.decorator
47+
48+
if isinstance(dec, cst.Call):
49+
func_name = ""
50+
if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name):
51+
if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest":
52+
func_name = "pytest.fixture"
53+
elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture":
54+
func_name = "fixture"
55+
56+
if func_name:
57+
for arg in dec.args:
58+
if (
59+
arg.keyword
60+
and arg.keyword.value == "autouse"
61+
and isinstance(arg.value, cst.Name)
62+
and arg.value.value == "True"
63+
):
64+
args = updated_node.params.params
65+
arg_names = {arg.name.value for arg in args}
66+
67+
# Skip if 'request' is already present
68+
if "request" in arg_names:
69+
return updated_node
70+
71+
# Create a new 'request' param
72+
request_param = cst.Param(name=cst.Name("request"))
73+
74+
# Add 'request' as the first argument (after 'self' or 'cls' if needed)
75+
if args:
76+
first_arg = args[0].name.value
77+
if first_arg in {"self", "cls"}:
78+
new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005
79+
else:
80+
new_params = [request_param] + list(args) # noqa: RUF005
81+
else:
82+
new_params = [request_param]
83+
84+
new_param_list = updated_node.params.with_changes(params=new_params)
85+
return updated_node.with_changes(params=new_param_list)
86+
return updated_node
87+
88+
4089
class PytestMarkAdder(cst.CSTTransformer):
4190
"""Transformer that adds pytest marks to test functions."""
4291

@@ -106,41 +155,51 @@ def _create_pytest_mark(self) -> cst.Decorator:
106155
class AutouseFixtureModifier(cst.CSTTransformer):
107156
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
108157
# Matcher for '@fixture' or '@pytest.fixture'
109-
fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture"))
110-
111158
for decorator in original_node.decorators:
112-
if m.matches(
113-
decorator,
114-
m.Decorator(
115-
decorator=m.Call(
116-
func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))]
117-
)
118-
),
119-
):
120-
# Found a matching fixture with autouse=True
121-
122-
# 1. The original body of the function will become the 'else' block.
123-
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
124-
else_block = cst.Else(body=updated_node.body)
125-
126-
# 2. Create the new 'if' block that will exit the fixture early.
127-
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
128-
yield_statement = cst.parse_statement("yield")
129-
if_body = cst.IndentedBlock(body=[yield_statement])
130-
131-
# 3. Construct the full if/else statement.
132-
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)
133-
134-
# 4. Replace the entire function's body with our new single statement.
135-
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
159+
dec = decorator.decorator
160+
161+
if isinstance(dec, cst.Call):
162+
func_name = ""
163+
if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name):
164+
if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest":
165+
func_name = "pytest.fixture"
166+
elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture":
167+
func_name = "fixture"
168+
169+
if func_name:
170+
for arg in dec.args:
171+
if (
172+
arg.keyword
173+
and arg.keyword.value == "autouse"
174+
and isinstance(arg.value, cst.Name)
175+
and arg.value.value == "True"
176+
):
177+
# Found a matching fixture with autouse=True
178+
179+
# 1. The original body of the function will become the 'else' block.
180+
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
181+
else_block = cst.Else(body=updated_node.body)
182+
183+
# 2. Create the new 'if' block that will exit the fixture early.
184+
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
185+
yield_statement = cst.parse_statement("yield")
186+
if_body = cst.IndentedBlock(body=[yield_statement])
187+
188+
# 3. Construct the full if/else statement.
189+
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)
190+
191+
# 4. Replace the entire function's body with our new single statement.
192+
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
136193
return updated_node
137194

138195

139196
def disable_autouse(test_path: Path) -> str:
140197
file_content = test_path.read_text(encoding="utf-8")
141198
module = cst.parse_module(file_content)
199+
add_request_argument = AddRequestArgument()
142200
disable_autouse_fixture = AutouseFixtureModifier()
143-
modified_module = module.visit(disable_autouse_fixture)
201+
modified_module = module.visit(add_request_argument)
202+
modified_module = modified_module.visit(disable_autouse_fixture)
144203
test_path.write_text(modified_module.code, encoding="utf-8")
145204
return file_content
146205

codeflash/tracer.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
import contextlib
15+
import datetime
1516
import importlib.machinery
1617
import io
1718
import json
@@ -81,6 +82,7 @@ def __init__(
8182
config_file_path: Path | None = None,
8283
max_function_count: int = 256,
8384
timeout: int | None = None, # seconds
85+
command: str | None = None,
8486
) -> None:
8587
"""Use this class to trace function calls.
8688
@@ -91,6 +93,7 @@ def __init__(
9193
:param max_function_count: Maximum number of times to trace one function
9294
:param timeout: Timeout in seconds for the tracer, if the traced code takes more than this time, then tracing
9395
stops and normal execution continues. If this is None then no timeout applies
96+
:param command: The command that initiated the tracing (for metadata storage)
9497
"""
9598
if functions is None:
9699
functions = []
@@ -148,6 +151,9 @@ def __init__(
148151
assert "test_framework" in self.config, "Please specify 'test-framework' in pyproject.toml config file"
149152
self.t = self.timer()
150153

154+
# Store command information for metadata table
155+
self.command = command if command else " ".join(sys.argv)
156+
151157
def __enter__(self) -> None:
152158
if self.disable:
153159
return
@@ -174,6 +180,22 @@ def __enter__(self) -> None:
174180
"CREATE TABLE function_calls(type TEXT, function TEXT, classname TEXT, filename TEXT, "
175181
"line_number INTEGER, last_frame_address INTEGER, time_ns INTEGER, args BLOB)"
176182
)
183+
184+
# Create metadata table to store command information
185+
cur.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)")
186+
187+
# Store command metadata
188+
cur.execute("INSERT INTO metadata VALUES (?, ?)", ("command", self.command))
189+
cur.execute("INSERT INTO metadata VALUES (?, ?)", ("program_name", self.file_being_called_from))
190+
cur.execute(
191+
"INSERT INTO metadata VALUES (?, ?)",
192+
("functions_filter", json.dumps(self.functions) if self.functions else None),
193+
)
194+
cur.execute(
195+
"INSERT INTO metadata VALUES (?, ?)",
196+
("timestamp", datetime.datetime.now(datetime.timezone.utc).isoformat()),
197+
)
198+
cur.execute("INSERT INTO metadata VALUES (?, ?)", ("project_root", str(self.project_root)))
177199
console.rule("Codeflash: Traced Program Output Begin", style="bold blue")
178200
frame = sys._getframe(0) # Get this frame and simulate a call to it # noqa: SLF001
179201
self.dispatch["call"](self, frame, 0)
@@ -842,6 +864,7 @@ def main() -> ArgumentParser:
842864
max_function_count=args.max_function_count,
843865
timeout=args.tracer_timeout,
844866
config_file_path=args.codeflash_config,
867+
command=" ".join(sys.argv),
845868
).runctx(code, globs, None)
846869

847870
except BrokenPipeError as exc:

0 commit comments

Comments
 (0)