Skip to content

Commit 771ba90

Browse files
committed
fix for conftest issue
1 parent 5651629 commit 771ba90

File tree

1 file changed

+32
-1
lines changed

1 file changed

+32
-1
lines changed

codeflash/code_utils/code_replacer.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import isort
99
import libcst as cst
1010
import libcst.matchers as m
11+
from libcst.metadata import PositionProvider
1112

1213
from codeflash.cli_cmds.console import logger
1314
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
@@ -37,6 +38,34 @@ def normalize_code(code: str) -> str:
3738
return ast.unparse(normalize_node(ast.parse(code)))
3839

3940

41+
class AddRequestArgument(cst.CSTTransformer):
42+
METADATA_DEPENDENCIES = (PositionProvider,)
43+
44+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
45+
args = updated_node.params.params
46+
arg_names = {arg.name.value for arg in args}
47+
48+
# Skip if 'request' is already present
49+
if "request" in arg_names:
50+
return updated_node
51+
52+
# Create a new 'request' param
53+
request_param = cst.Param(name=cst.Name("request"))
54+
55+
# Add 'request' as the first argument (after 'self' or 'cls' if needed)
56+
if args:
57+
first_arg = args[0].name.value
58+
if first_arg in {"self", "cls"}:
59+
new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005
60+
else:
61+
new_params = [request_param] + list(args) # noqa: RUF005
62+
else:
63+
new_params = [request_param]
64+
65+
new_param_list = updated_node.params.with_changes(params=new_params)
66+
return updated_node.with_changes(params=new_param_list)
67+
68+
4069
class PytestMarkAdder(cst.CSTTransformer):
4170
"""Transformer that adds pytest marks to test functions."""
4271

@@ -139,8 +168,10 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
139168
def disable_autouse(test_path: Path) -> str:
140169
file_content = test_path.read_text(encoding="utf-8")
141170
module = cst.parse_module(file_content)
171+
add_request_argument = AddRequestArgument()
142172
disable_autouse_fixture = AutouseFixtureModifier()
143-
modified_module = module.visit(disable_autouse_fixture)
173+
modified_module = module.visit(add_request_argument)
174+
modified_module = modified_module.visit(disable_autouse_fixture)
144175
test_path.write_text(modified_module.code, encoding="utf-8")
145176
return file_content
146177

0 commit comments

Comments
 (0)