Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 87 additions & 28 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import isort
import libcst as cst
import libcst.matchers as m
from libcst.metadata import PositionProvider

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
Expand Down Expand Up @@ -37,6 +37,55 @@ def normalize_code(code: str) -> str:
return ast.unparse(normalize_node(ast.parse(code)))


class AddRequestArgument(cst.CSTTransformer):
METADATA_DEPENDENCIES = (PositionProvider,)

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
# Matcher for '@fixture' or '@pytest.fixture'
for decorator in original_node.decorators:
dec = decorator.decorator

if isinstance(dec, cst.Call):
func_name = ""
if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name):
if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest":
func_name = "pytest.fixture"
elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture":
func_name = "fixture"

if func_name:
for arg in dec.args:
if (
arg.keyword
and arg.keyword.value == "autouse"
and isinstance(arg.value, cst.Name)
and arg.value.value == "True"
):
args = updated_node.params.params
arg_names = {arg.name.value for arg in args}

# Skip if 'request' is already present
if "request" in arg_names:
return updated_node

# Create a new 'request' param
request_param = cst.Param(name=cst.Name("request"))

# Add 'request' as the first argument (after 'self' or 'cls' if needed)
if args:
first_arg = args[0].name.value
if first_arg in {"self", "cls"}:
new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005
else:
new_params = [request_param] + list(args) # noqa: RUF005
else:
new_params = [request_param]

new_param_list = updated_node.params.with_changes(params=new_params)
return updated_node.with_changes(params=new_param_list)
return updated_node


class PytestMarkAdder(cst.CSTTransformer):
"""Transformer that adds pytest marks to test functions."""

Expand Down Expand Up @@ -106,41 +155,51 @@ def _create_pytest_mark(self) -> cst.Decorator:
class AutouseFixtureModifier(cst.CSTTransformer):
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
# Matcher for '@fixture' or '@pytest.fixture'
fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture"))

for decorator in original_node.decorators:
if m.matches(
decorator,
m.Decorator(
decorator=m.Call(
func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))]
)
),
):
# Found a matching fixture with autouse=True

# 1. The original body of the function will become the 'else' block.
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
else_block = cst.Else(body=updated_node.body)

# 2. Create the new 'if' block that will exit the fixture early.
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
yield_statement = cst.parse_statement("yield")
if_body = cst.IndentedBlock(body=[yield_statement])

# 3. Construct the full if/else statement.
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)

# 4. Replace the entire function's body with our new single statement.
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
dec = decorator.decorator

if isinstance(dec, cst.Call):
func_name = ""
if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name):
if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest":
func_name = "pytest.fixture"
elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture":
func_name = "fixture"

if func_name:
for arg in dec.args:
if (
arg.keyword
and arg.keyword.value == "autouse"
and isinstance(arg.value, cst.Name)
and arg.value.value == "True"
):
# Found a matching fixture with autouse=True

# 1. The original body of the function will become the 'else' block.
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
else_block = cst.Else(body=updated_node.body)

# 2. Create the new 'if' block that will exit the fixture early.
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
yield_statement = cst.parse_statement("yield")
if_body = cst.IndentedBlock(body=[yield_statement])

# 3. Construct the full if/else statement.
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)

# 4. Replace the entire function's body with our new single statement.
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
return updated_node


def disable_autouse(test_path: Path) -> str:
file_content = test_path.read_text(encoding="utf-8")
module = cst.parse_module(file_content)
add_request_argument = AddRequestArgument()
disable_autouse_fixture = AutouseFixtureModifier()
modified_module = module.visit(disable_autouse_fixture)
modified_module = module.visit(add_request_argument)
modified_module = modified_module.visit(disable_autouse_fixture)
test_path.write_text(modified_module.code, encoding="utf-8")
return file_content

Expand Down
Loading
Loading