Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 32 additions & 1 deletion codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import isort
import libcst as cst
import libcst.matchers as m
from libcst.metadata import PositionProvider

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
Expand Down Expand Up @@ -37,6 +38,34 @@ def normalize_code(code: str) -> str:
return ast.unparse(normalize_node(ast.parse(code)))


class AddRequestArgument(cst.CSTTransformer):
METADATA_DEPENDENCIES = (PositionProvider,)

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
args = updated_node.params.params
arg_names = {arg.name.value for arg in args}

# Skip if 'request' is already present
if "request" in arg_names:
return updated_node

# Create a new 'request' param
request_param = cst.Param(name=cst.Name("request"))

# Add 'request' as the first argument (after 'self' or 'cls' if needed)
if args:
first_arg = args[0].name.value
if first_arg in {"self", "cls"}:
new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005
else:
new_params = [request_param] + list(args) # noqa: RUF005
else:
new_params = [request_param]

new_param_list = updated_node.params.with_changes(params=new_params)
return updated_node.with_changes(params=new_param_list)


class PytestMarkAdder(cst.CSTTransformer):
"""Transformer that adds pytest marks to test functions."""

Expand Down Expand Up @@ -139,8 +168,10 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
def disable_autouse(test_path: Path) -> str:
file_content = test_path.read_text(encoding="utf-8")
module = cst.parse_module(file_content)
add_request_argument = AddRequestArgument()
disable_autouse_fixture = AutouseFixtureModifier()
modified_module = module.visit(disable_autouse_fixture)
modified_module = module.visit(add_request_argument)
modified_module = modified_module.visit(disable_autouse_fixture)
test_path.write_text(modified_module.code, encoding="utf-8")
return file_content

Expand Down
Loading