diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 3a6f7dba2..5edff57a0 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -123,6 +123,7 @@ def process_pyproject_config(args: Namespace) -> Namespace: "disable_telemetry", "disable_imports_sorting", "git_remote", + "override_fixtures", ] for key in supported_keys: if key in pyproject_config and ( diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index eb367bdfa..932053fc6 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -5,10 +5,14 @@ from functools import lru_cache from typing import TYPE_CHECKING, Optional, TypeVar +import isort import libcst as cst +import libcst.matchers as m from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module +from codeflash.code_utils.config_parser import find_conftest_files +from codeflash.code_utils.line_profile_utils import ImportAdder from codeflash.models.models import FunctionParent if TYPE_CHECKING: @@ -33,6 +37,142 @@ def normalize_code(code: str) -> str: return ast.unparse(normalize_node(ast.parse(code))) +class PytestMarkAdder(cst.CSTTransformer): + """Transformer that adds pytest marks to test functions.""" + + def __init__(self, mark_name: str) -> None: + super().__init__() + self.mark_name = mark_name + self.has_pytest_import = False + + def visit_Module(self, node: cst.Module) -> None: + """Check if pytest is already imported.""" + for statement in node.body: + if isinstance(statement, cst.SimpleStatementLine): + for stmt in statement.body: + if isinstance(stmt, cst.Import): + for import_alias in stmt.names: + if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest": + self.has_pytest_import = True + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002 + """Add pytest import if not present.""" + if not self.has_pytest_import: + # Create import statement + import_stmt = cst.SimpleStatementLine(body=[cst.Import(names=[cst.ImportAlias(name=cst.Name("pytest"))])]) + # Add import at the beginning + updated_node = updated_node.with_changes(body=[import_stmt, *updated_node.body]) + return updated_node + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002 + """Add pytest mark to test functions.""" + # Check if the mark already exists + for decorator in updated_node.decorators: + if self._is_pytest_mark(decorator.decorator, self.mark_name): + return updated_node + + # Create the pytest mark decorator + mark_decorator = self._create_pytest_mark() + + # Add the decorator + new_decorators = [*list(updated_node.decorators), mark_decorator] + return updated_node.with_changes(decorators=new_decorators) + + def _is_pytest_mark(self, decorator: cst.BaseExpression, mark_name: str) -> bool: + """Check if a decorator is a specific pytest mark.""" + if isinstance(decorator, cst.Attribute): + if ( + isinstance(decorator.value, cst.Attribute) + and isinstance(decorator.value.value, cst.Name) + and decorator.value.value.value == "pytest" + and decorator.value.attr.value == "mark" + and decorator.attr.value == mark_name + ): + return True + elif isinstance(decorator, cst.Call) and isinstance(decorator.func, cst.Attribute): + return self._is_pytest_mark(decorator.func, mark_name) + return False + + def _create_pytest_mark(self) -> cst.Decorator: + """Create a pytest mark decorator.""" + # Base: pytest.mark.{mark_name} + mark_attr = cst.Attribute( + value=cst.Attribute(value=cst.Name("pytest"), attr=cst.Name("mark")), attr=cst.Name(self.mark_name) + ) + decorator = mark_attr + return cst.Decorator(decorator=decorator) + + +class AutouseFixtureModifier(cst.CSTTransformer): + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + # Matcher for '@fixture' or '@pytest.fixture' + fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture")) + + for decorator in original_node.decorators: + if m.matches( + decorator, + m.Decorator( + decorator=m.Call( + func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))] + ) + ), + ): + # Found a matching fixture with autouse=True + + # 1. The original body of the function will become the 'else' block. + # updated_node.body is an IndentedBlock, which is what cst.Else expects. + else_block = cst.Else(body=updated_node.body) + + # 2. Create the new 'if' block that will exit the fixture early. + if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")') + yield_statement = cst.parse_statement("yield") + if_body = cst.IndentedBlock(body=[yield_statement]) + + # 3. Construct the full if/else statement. + new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block) + + # 4. Replace the entire function's body with our new single statement. + return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement])) + return updated_node + + +def disable_autouse(test_path: Path) -> str: + file_content = test_path.read_text(encoding="utf-8") + module = cst.parse_module(file_content) + disable_autouse_fixture = AutouseFixtureModifier() + modified_module = module.visit(disable_autouse_fixture) + test_path.write_text(modified_module.code, encoding="utf-8") + return file_content + + +def modify_autouse_fixture(test_paths: list[Path]) -> dict[Path, list[str]]: + # find fixutre definition in conftetst.py (the one closest to the test) + # get fixtures present in override-fixtures in pyproject.toml + # add if marker closest return + file_content_map = {} + conftest_files = find_conftest_files(test_paths) + for cf_file in conftest_files: + # iterate over all functions in the file + # if function has autouse fixture, modify function to bypass with custom marker + original_content = disable_autouse(cf_file) + file_content_map[cf_file] = original_content + return file_content_map + + +# # reuse line profiler utils to add decorator and import to test fns +def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None: + for test_path in test_paths: + # read file + file_content = test_path.read_text(encoding="utf-8") + module = cst.parse_module(file_content) + importadder = ImportAdder("import pytest") + modified_module = module.visit(importadder) + modified_module = cst.parse_module(isort.code(modified_module.code, float_to_top=True)) + pytest_mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = modified_module.visit(pytest_mark_adder) + test_path.write_text(modified_module.code, encoding="utf-8") + + class OptimFunctionCollector(cst.CSTVisitor): METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,) diff --git a/codeflash/code_utils/code_utils.py b/codeflash/code_utils/code_utils.py index 5fc9bd9e9..6a9de176b 100644 --- a/codeflash/code_utils/code_utils.py +++ b/codeflash/code_utils/code_utils.py @@ -208,3 +208,8 @@ def cleanup_paths(paths: list[Path]) -> None: shutil.rmtree(path, ignore_errors=True) else: path.unlink(missing_ok=True) + + +def restore_conftest(path_to_content_map: dict[Path, str]) -> None: + for path, file_content in path_to_content_map.items(): + path.write_text(file_content, encoding="utf8") diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index 7b6243a75..13813cfc1 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -31,6 +31,21 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path: raise ValueError(msg) +def find_conftest_files(test_paths: list[Path]) -> list[Path]: + list_of_conftest_files = set() + for test_path in test_paths: + # Find the conftest file on the root of the project + dir_path = Path.cwd() + cur_path = test_path + while cur_path != dir_path: + config_file = cur_path / "conftest.py" + if config_file.exists(): + list_of_conftest_files.add(config_file) + # Search for conftest.py in the parent directories + cur_path = cur_path.parent + return list(list_of_conftest_files) + + def parse_config_file( config_file_path: Path | None = None, override_formatter_check: bool = False, # noqa: FBT001, FBT002 @@ -56,7 +71,12 @@ def parse_config_file( path_keys = ["module-root", "tests-root", "benchmarks-root"] path_list_keys = ["ignore-paths"] str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"} - bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False, "benchmark": False} + bool_keys = { + "override-fixtures": False, + "disable-telemetry": False, + "disable-imports-sorting": False, + "benchmark": False, + } list_str_keys = {"formatter-cmds": ["black $file"]} for key, default_value in str_keys.items(): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index caa1daf7d..5922d6c1c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -21,7 +21,11 @@ from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils -from codeflash.code_utils.code_replacer import replace_function_definitions_in_module +from codeflash.code_utils.code_replacer import ( + add_custom_marker_to_all_tests, + modify_autouse_fixture, + replace_function_definitions_in_module, +) from codeflash.code_utils.code_utils import ( ImportErrorPattern, cleanup_paths, @@ -29,6 +33,7 @@ get_run_tmp_file, has_any_async_functions, module_name_from_file_path, + restore_conftest, ) from codeflash.code_utils.config_consts import ( INDIVIDUAL_TESTCASE_TIMEOUT, @@ -212,6 +217,11 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 for key in set(self.function_to_tests) | set(function_to_concolic_tests) } instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests) + if self.args.override_fixtures: + logger.info("Disabling all autouse fixtures associated with the generated test files") + original_conftest_content = modify_autouse_fixture(generated_test_paths + generated_perf_test_paths) + logger.info("Add custom marker to generated test files") + add_custom_marker_to_all_tests(generated_test_paths + generated_perf_test_paths) # Get a dict of file_path_to_classes of fto and helpers_of_fto file_path_to_helper_classes = defaultdict(set) @@ -234,6 +244,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 ) if not is_successful(baseline_result): + if self.args.override_fixtures: + restore_conftest(original_conftest_content) cleanup_paths(paths_to_cleanup) return Failure(baseline_result.failure()) @@ -241,6 +253,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 if isinstance(original_code_baseline, OriginalCodeBaseline) and not coverage_critic( original_code_baseline.coverage_results, self.args.test_framework ): + if self.args.override_fixtures: + restore_conftest(original_conftest_content) cleanup_paths(paths_to_cleanup) return Failure("The threshold for test coverage was not met.") # request for new optimizations but don't block execution, check for completion later @@ -356,6 +370,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 ) self.log_successful_optimization(explanation, generated_tests, exp_type) + if self.args.override_fixtures: + restore_conftest(original_conftest_content) if not best_optimization: return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}") return Success(best_optimization) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 2b7722fd7..1dab67c97 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -1,5 +1,6 @@ from __future__ import annotations - +import libcst as cst +from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder import dataclasses import os from collections import defaultdict @@ -819,7 +820,7 @@ def main_method(self): ) func_optimizer = FunctionOptimizer(function_to_optimize=func_top_optimize, test_cfg=test_config) code_context = func_optimizer.get_code_optimization_context().unwrap() - assert code_context.testgen_context_code == get_code_output + assert code_context.testgen_context_code.rstrip() == get_code_output.rstrip() def test_code_replacement11() -> None: @@ -2134,3 +2135,474 @@ def new_function2(value): new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) assert new_code.rstrip() == expected_code.rstrip() + + +class TestAutouseFixtureModifier: + """Test cases for AutouseFixtureModifier class.""" + + def test_modifies_autouse_fixture_with_pytest_decorator(self): + """Test that autouse fixture with @pytest.fixture is modified correctly.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def my_fixture(request): + print("setup") + yield + print("teardown") +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def my_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + print("setup") + yield + print("teardown") +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Parse expected to normalize formatting + expected_module = cst.parse_module(expected_code) + assert modified_module.code.strip() == expected_module.code.strip() + + def test_modifies_autouse_fixture_with_fixture_decorator(self): + """Test that autouse fixture with @fixture is modified correctly.""" + source_code = ''' +from pytest import fixture + +@fixture(autouse=True) +def my_fixture(request): + setup_code() + yield "value" + cleanup_code() +''' + expected_code = ''' +from pytest import fixture + +@fixture(autouse=True) +def my_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + setup_code() + yield "value" + cleanup_code() +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Check that the if statement was added + assert modified_module.code.strip() == expected_code.strip() + + def test_ignores_non_autouse_fixture(self): + """Test that non-autouse fixtures are not modified.""" + source_code = ''' +import pytest + +@pytest.fixture +def my_fixture(request): + return "test_value" + +@pytest.fixture(scope="session") +def session_fixture(): + return "session_value" +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Code should remain unchanged + assert modified_module.code == source_code + + def test_ignores_regular_functions(self): + """Test that regular functions are not modified.""" + source_code = ''' +def regular_function(): + return "not a fixture" + +@some_other_decorator +def decorated_function(): + return "also not a fixture" +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Code should remain unchanged + assert modified_module.code == source_code + + def test_handles_multiple_autouse_fixtures(self): + """Test that multiple autouse fixtures in the same file are all modified.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def fixture_one(request): + yield "one" + +@pytest.fixture(autouse=True) +def fixture_two(request): + yield "two" +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def fixture_one(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "one" + +@pytest.fixture(autouse=True) +def fixture_two(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "two" +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + # Both fixtures should be modified + code = modified_module.code + assert code==expected_code + + def test_preserves_fixture_with_complex_body(self): + """Test that fixtures with complex bodies are handled correctly.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def complex_fixture(request): + try: + setup_database() + configure_logging() + yield get_test_client() + finally: + cleanup_database() + reset_logging() +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def complex_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + try: + setup_database() + configure_logging() + yield get_test_client() + finally: + cleanup_database() + reset_logging() +''' + module = cst.parse_module(source_code) + modifier = AutouseFixtureModifier() + modified_module = module.visit(modifier) + + code = modified_module.code + assert code.rstrip()==expected_code.rstrip() + + +class TestPytestMarkAdder: + """Test cases for PytestMarkAdder class.""" + + def test_adds_pytest_import_when_missing(self): + """Test that pytest import is added when not present.""" + source_code = ''' +def test_something(): + assert True +''' + expected_code = ''' +import pytest +@pytest.mark.codeflash_no_autouse +def test_something(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + assert code==expected_code + + def test_skips_pytest_import_when_present(self): + """Test that pytest import is not duplicated when already present.""" + source_code = ''' +import pytest + +def test_something(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse +def test_something(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # Should only have one import pytest line + assert code==expected_code + + def test_handles_from_pytest_import(self): + """Test that existing 'from pytest import ...' is recognized.""" + source_code = ''' +from pytest import fixture + +def test_something(): + assert True +''' + expected_code = ''' +import pytest +from pytest import fixture + +@pytest.mark.codeflash_no_autouse +def test_something(): + assert True + ''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # Should not add import pytest since pytest is already imported + assert code.strip()==expected_code.strip() + + def test_adds_mark_to_all_functions(self): + """Test that marks are added to all functions in the module.""" + source_code = ''' +import pytest + +def test_first(): + assert True + +def test_second(): + assert False + +def helper_function(): + return "not a test" +''' + expected_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse +def test_first(): + assert True + +@pytest.mark.codeflash_no_autouse +def test_second(): + assert False + +@pytest.mark.codeflash_no_autouse +def helper_function(): + return "not a test" +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # All functions should get the mark + assert code==expected_code + + def test_skips_existing_mark(self): + """Test that existing marks are not duplicated.""" + source_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse +def test_already_marked(): + assert True + +def test_needs_mark(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse +def test_already_marked(): + assert True + +@pytest.mark.codeflash_no_autouse +def test_needs_mark(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # Should have exactly 2 marks total (one existing, one added) + assert code==expected_code + + def test_handles_different_mark_names(self): + """Test that different mark names work correctly.""" + source_code = ''' +import pytest + +def test_something(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.slow +def test_something(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("slow") + modified_module = module.visit(mark_adder) + + code = modified_module.code + assert code==expected_code + + def test_preserves_existing_decorators(self): + """Test that existing decorators are preserved.""" + source_code = ''' +import pytest + +@pytest.mark.parametrize("value", [1, 2, 3]) +@pytest.fixture +def test_with_decorators(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.parametrize("value", [1, 2, 3]) +@pytest.fixture +@pytest.mark.codeflash_no_autouse +def test_with_decorators(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + assert code==expected_code + + def test_handles_call_style_existing_marks(self): + """Test recognition of existing marks in call style (with parentheses).""" + source_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse() +def test_with_call_mark(): + assert True + +def test_needs_mark(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.mark.codeflash_no_autouse() +def test_with_call_mark(): + assert True + +@pytest.mark.codeflash_no_autouse +def test_needs_mark(): + assert True +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + # Should recognize the existing call-style mark and not duplicate + assert code==expected_code + + def test_empty_module(self): + """Test handling of empty module.""" + source_code = '' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + # Should just add the import + code = modified_module.code + assert code =='import pytest' + + def test_module_with_only_imports(self): + """Test handling of module with only imports.""" + source_code = ''' +import os +import sys +from pathlib import Path +''' + expected_code = ''' +import pytest +import os +import sys +from pathlib import Path +''' + module = cst.parse_module(source_code) + mark_adder = PytestMarkAdder("codeflash_no_autouse") + modified_module = module.visit(mark_adder) + + code = modified_module.code + assert code==expected_code + + +class TestIntegration: + """Integration tests for both transformers working together.""" + + def test_both_transformers_together(self): + """Test that both transformers can work on the same code.""" + source_code = ''' +import pytest + +@pytest.fixture(autouse=True) +def my_fixture(request): + yield "value" + +def test_something(): + assert True +''' + expected_code = ''' +import pytest + +@pytest.fixture(autouse=True) +@pytest.mark.codeflash_no_autouse +def my_fixture(request): + if request.node.get_closest_marker("codeflash_no_autouse"): + yield + else: + yield "value" + +@pytest.mark.codeflash_no_autouse +def test_something(): + assert True +''' + # First apply AutouseFixtureModifier + module = cst.parse_module(source_code) + autouse_modifier = AutouseFixtureModifier() + modified_module = module.visit(autouse_modifier) + + # Then apply PytestMarkAdder + mark_adder = PytestMarkAdder("codeflash_no_autouse") + final_module = modified_module.visit(mark_adder) + + code = final_module.code + # Should have both modifications + assert code==expected_code +