diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 932053fc6..f0964aae7 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -7,7 +7,7 @@ import isort import libcst as cst -import libcst.matchers as m +from libcst.metadata import PositionProvider from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module @@ -37,6 +37,55 @@ def normalize_code(code: str) -> str: return ast.unparse(normalize_node(ast.parse(code))) +class AddRequestArgument(cst.CSTTransformer): + METADATA_DEPENDENCIES = (PositionProvider,) + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + # Matcher for '@fixture' or '@pytest.fixture' + for decorator in original_node.decorators: + dec = decorator.decorator + + if isinstance(dec, cst.Call): + func_name = "" + if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name): + if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest": + func_name = "pytest.fixture" + elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture": + func_name = "fixture" + + if func_name: + for arg in dec.args: + if ( + arg.keyword + and arg.keyword.value == "autouse" + and isinstance(arg.value, cst.Name) + and arg.value.value == "True" + ): + args = updated_node.params.params + arg_names = {arg.name.value for arg in args} + + # Skip if 'request' is already present + if "request" in arg_names: + return updated_node + + # Create a new 'request' param + request_param = cst.Param(name=cst.Name("request")) + + # Add 'request' as the first argument (after 'self' or 'cls' if needed) + if args: + first_arg = args[0].name.value + if first_arg in {"self", "cls"}: + new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005 + else: + new_params = [request_param] + list(args) # noqa: RUF005 + else: + new_params = [request_param] + + new_param_list = updated_node.params.with_changes(params=new_params) + return updated_node.with_changes(params=new_param_list) + return updated_node + + class PytestMarkAdder(cst.CSTTransformer): """Transformer that adds pytest marks to test functions.""" @@ -106,41 +155,51 @@ def _create_pytest_mark(self) -> cst.Decorator: class AutouseFixtureModifier(cst.CSTTransformer): def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # Matcher for '@fixture' or '@pytest.fixture' - fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture")) - for decorator in original_node.decorators: - if m.matches( - decorator, - m.Decorator( - decorator=m.Call( - func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))] - ) - ), - ): - # Found a matching fixture with autouse=True - - # 1. The original body of the function will become the 'else' block. - # updated_node.body is an IndentedBlock, which is what cst.Else expects. - else_block = cst.Else(body=updated_node.body) - - # 2. Create the new 'if' block that will exit the fixture early. - if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")') - yield_statement = cst.parse_statement("yield") - if_body = cst.IndentedBlock(body=[yield_statement]) - - # 3. Construct the full if/else statement. - new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block) - - # 4. Replace the entire function's body with our new single statement. - return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement])) + dec = decorator.decorator + + if isinstance(dec, cst.Call): + func_name = "" + if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name): + if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest": + func_name = "pytest.fixture" + elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture": + func_name = "fixture" + + if func_name: + for arg in dec.args: + if ( + arg.keyword + and arg.keyword.value == "autouse" + and isinstance(arg.value, cst.Name) + and arg.value.value == "True" + ): + # Found a matching fixture with autouse=True + + # 1. The original body of the function will become the 'else' block. + # updated_node.body is an IndentedBlock, which is what cst.Else expects. + else_block = cst.Else(body=updated_node.body) + + # 2. Create the new 'if' block that will exit the fixture early. + if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")') + yield_statement = cst.parse_statement("yield") + if_body = cst.IndentedBlock(body=[yield_statement]) + + # 3. Construct the full if/else statement. + new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block) + + # 4. Replace the entire function's body with our new single statement. + return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement])) return updated_node def disable_autouse(test_path: Path) -> str: file_content = test_path.read_text(encoding="utf-8") module = cst.parse_module(file_content) + add_request_argument = AddRequestArgument() disable_autouse_fixture = AutouseFixtureModifier() - modified_module = module.visit(disable_autouse_fixture) + modified_module = module.visit(add_request_argument) + modified_module = modified_module.visit(disable_autouse_fixture) test_path.write_text(modified_module.code, encoding="utf-8") return file_content diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 1dab67c97..363dbaee4 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -1,6 +1,6 @@ from __future__ import annotations import libcst as cst -from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder +from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder, AddRequestArgument import dataclasses import os from collections import defaultdict @@ -2564,16 +2564,61 @@ def test_module_with_only_imports(self): class TestIntegration: - """Integration tests for both transformers working together.""" + """Integration tests for all transformers working together.""" - def test_both_transformers_together(self): - """Test that both transformers can work on the same code.""" + def test_all_transformers_together(self): + """Test that all three transformers can work on the same code.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def my_fixture(): + yield "value" + +def test_something(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +@pytest.mark.codeflash_no_autouse +def my_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "value" + +@pytest.mark.codeflash_no_autouse +def test_something(): + assert True +''' + # First apply AddRequestArgument + module = cst.parse_module(source_code) + request_adder = AddRequestArgument() + modified_module = module.visit(request_adder) + + # Then apply AutouseFixtureModifier + autouse_modifier = AutouseFixtureModifier() + modified_module = modified_module.visit(autouse_modifier) + + # Finally apply PytestMarkAdder + mark_adder = PytestMarkAdder("codeflash_no_autouse") + final_module = modified_module.visit(mark_adder) + + # Compare complete strings + assert final_module.code == expected_code + + def test_transformers_with_existing_request_parameter(self): + """Test transformers when request parameter already exists.""" source_code = ''' import pytest @pytest.fixture(autouse=True) def my_fixture(request): + setup_code() yield "value" + cleanup_code() def test_something(): assert True @@ -2587,22 +2632,442 @@ def my_fixture(request): if request.node.get_closest_marker("codeflash_no_autouse"): yield else: + setup_code() yield "value" + cleanup_code() @pytest.mark.codeflash_no_autouse def test_something(): assert True ''' - # First apply AutouseFixtureModifier + # Apply all transformers in sequence module = cst.parse_module(source_code) + request_adder = AddRequestArgument() + modified_module = module.visit(request_adder) + autouse_modifier = AutouseFixtureModifier() - modified_module = module.visit(autouse_modifier) + modified_module = modified_module.visit(autouse_modifier) - # Then apply PytestMarkAdder mark_adder = PytestMarkAdder("codeflash_no_autouse") final_module = modified_module.visit(mark_adder) - code = final_module.code - # Should have both modifications - assert code==expected_code + # Compare complete strings + assert final_module.code == expected_code + + def test_transformers_with_self_parameter(self): + """Test transformers when fixture has self parameter.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def my_fixture(self): + yield "value" + +def test_something(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +@pytest.mark.codeflash_no_autouse +def my_fixture(self, request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "value" + +@pytest.mark.codeflash_no_autouse +def test_something(): + assert True +''' + # Apply all transformers in sequence + module = cst.parse_module(source_code) + request_adder = AddRequestArgument() + modified_module = module.visit(request_adder) + + autouse_modifier = AutouseFixtureModifier() + modified_module = modified_module.visit(autouse_modifier) + + mark_adder = PytestMarkAdder("codeflash_no_autouse") + final_module = modified_module.visit(mark_adder) + + # Compare complete strings + assert final_module.code == expected_code + + def test_transformers_with_multiple_fixtures(self): + """Test transformers with multiple autouse fixtures.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def fixture_one(): + yield "one" + +@pytest.fixture(autouse=True) +def fixture_two(self, param): + yield "two" + +@pytest.fixture +def regular_fixture(): + return "regular" + +def test_something(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +@pytest.mark.codeflash_no_autouse +def fixture_one(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "one" + +@pytest.fixture(autouse=True) +@pytest.mark.codeflash_no_autouse +def fixture_two(self, request, param): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "two" + +@pytest.fixture +@pytest.mark.codeflash_no_autouse +def regular_fixture(): + return "regular" + +@pytest.mark.codeflash_no_autouse +def test_something(): + assert True +''' + # Apply all transformers in sequence + module = cst.parse_module(source_code) + request_adder = AddRequestArgument() + modified_module = module.visit(request_adder) + + autouse_modifier = AutouseFixtureModifier() + modified_module = modified_module.visit(autouse_modifier) + + mark_adder = PytestMarkAdder("codeflash_no_autouse") + final_module = modified_module.visit(mark_adder) + + # Compare complete strings + assert final_module.code == expected_code + + + + +class TestAddRequestArgument: + """Test cases for AddRequestArgument transformer.""" + + def test_adds_request_to_autouse_fixture_no_existing_args(self): + """Test adding request argument to autouse fixture with no existing arguments.""" + source_code = ''' +@fixture(autouse=True) +def my_fixture(): + pass +''' + expected = ''' +@fixture(autouse=True) +def my_fixture(request): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_adds_request_to_pytest_fixture_autouse(self): + """Test adding request argument to pytest.fixture with autouse=True.""" + source_code = ''' +@pytest.fixture(autouse=True) +def my_fixture(): + pass +''' + expected = ''' +@pytest.fixture(autouse=True) +def my_fixture(request): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_adds_request_after_self_parameter(self): + """Test adding request argument after self parameter.""" + source_code = ''' +@fixture(autouse=True) +def my_fixture(self): + pass +''' + expected = ''' +@fixture(autouse=True) +def my_fixture(self, request): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_adds_request_after_cls_parameter(self): + """Test adding request argument after cls parameter.""" + source_code = ''' +@fixture(autouse=True) +def my_fixture(cls): + pass +''' + expected = ''' +@fixture(autouse=True) +def my_fixture(cls, request): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_adds_request_before_other_parameters(self): + """Test adding request argument before other parameters (not self/cls).""" + source_code = ''' +@fixture(autouse=True) +def my_fixture(param1, param2): + pass +''' + expected = ''' +@fixture(autouse=True) +def my_fixture(request, param1, param2): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_adds_request_after_self_with_other_parameters(self): + """Test adding request argument after self with other parameters.""" + source_code = ''' +@fixture(autouse=True) +def my_fixture(self, param1, param2): + pass +''' + expected = ''' +@fixture(autouse=True) +def my_fixture(self, request, param1, param2): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_skips_when_request_already_present(self): + """Test that request argument is not added when already present.""" + source_code = ''' +@fixture(autouse=True) +def my_fixture(request): + pass +''' + expected = ''' +@fixture(autouse=True) +def my_fixture(request): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_skips_when_request_present_with_other_args(self): + """Test that request argument is not added when already present with other args.""" + source_code = ''' +@fixture(autouse=True) +def my_fixture(self, request, param1): + pass +''' + expected = ''' +@fixture(autouse=True) +def my_fixture(self, request, param1): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_ignores_non_autouse_fixture(self): + """Test that non-autouse fixtures are not modified.""" + source_code = ''' +@fixture +def my_fixture(): + pass +''' + expected = ''' +@fixture +def my_fixture(): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_ignores_fixture_with_autouse_false(self): + """Test that fixtures with autouse=False are not modified.""" + source_code = ''' +@fixture(autouse=False) +def my_fixture(): + pass +''' + expected = ''' +@fixture(autouse=False) +def my_fixture(): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_ignores_regular_function(self): + """Test that regular functions are not modified.""" + source_code = ''' +def my_function(): + pass +''' + expected = ''' +def my_function(): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_handles_multiple_autouse_fixtures(self): + """Test handling multiple autouse fixtures in the same module.""" + source_code = ''' +@fixture(autouse=True) +def fixture1(): + pass + +@pytest.fixture(autouse=True) +def fixture2(self): + pass + +@fixture(autouse=True) +def fixture3(request): + pass +''' + expected = ''' +@fixture(autouse=True) +def fixture1(request): + pass + +@pytest.fixture(autouse=True) +def fixture2(self, request): + pass + +@fixture(autouse=True) +def fixture3(request): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_handles_fixture_with_other_decorators(self): + """Test handling fixture with other decorators.""" + source_code = ''' +@some_decorator +@fixture(autouse=True) +@another_decorator +def my_fixture(): + pass +''' + expected = ''' +@some_decorator +@fixture(autouse=True) +@another_decorator +def my_fixture(request): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_preserves_function_body_and_docstring(self): + """Test that function body and docstring are preserved.""" + source_code = ''' +@fixture(autouse=True) +def my_fixture(): + """This is a docstring.""" + x = 1 + y = 2 + return x + y +''' + expected = ''' +@fixture(autouse=True) +def my_fixture(request): + """This is a docstring.""" + x = 1 + y = 2 + return x + y +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + + assert modified_module.code.strip() == expected.strip() + + def test_handles_fixture_with_additional_arguments(self): + """Test handling fixture with additional keyword arguments.""" + source_code = ''' +@fixture(autouse=True, scope="session") +def my_fixture(): + pass +''' + expected = ''' +@fixture(autouse=True, scope="session") +def my_fixture(request): + pass +''' + + module = cst.parse_module(source_code) + transformer = AddRequestArgument() + modified_module = module.visit(transformer) + assert modified_module.code.strip() == expected.strip()