diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 932053fc6..17ba821b1 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -3,7 +3,7 @@ import ast from collections import defaultdict from functools import lru_cache -from typing import TYPE_CHECKING, Optional, TypeVar +from typing import TYPE_CHECKING, Optional, TypeVar, Union import isort import libcst as cst @@ -16,6 +16,7 @@ from codeflash.models.models import FunctionParent if TYPE_CHECKING: + from _ast import AST from pathlib import Path from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -24,6 +25,116 @@ ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST) +class BenchmarkFunctionRemover(ast.NodeTransformer): + """AST transformer that removes functions using pytest-benchmark fixture.""" + + def _uses_benchmark_fixture(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> bool: + """Check if a function uses the benchmark fixture.""" + # Check function arguments for 'benchmark' parameter + for arg in node.args.args: + if arg.arg == "benchmark": + return True + + # Check for pytest markers that might indicate benchmarking + for decorator in node.decorator_list: + if self._is_benchmark_marker(decorator): + return True + + # Check function body for benchmark usage + return any(isinstance(stmt, ast.Call) and self._is_benchmark_call(stmt) for stmt in ast.walk(node)) + + @staticmethod + def _is_benchmark_marker(decorator: ast.expr) -> bool: + """Check if decorator is a benchmark-related pytest marker.""" + if isinstance(decorator, ast.Call): + if isinstance(decorator.func, ast.Attribute): + # Check for @pytest.mark.benchmark + if ( + isinstance(decorator.func.value, ast.Attribute) + and isinstance(decorator.func.value.value, ast.Name) + and decorator.func.value.value.id == "pytest" + and decorator.func.value.attr == "mark" + and decorator.func.attr == "benchmark" + ): + return True + elif 
isinstance(decorator.func, ast.Name) and decorator.func.id == "benchmark": + return True + elif isinstance(decorator, ast.Attribute): + # Check for @pytest.mark.benchmark (without call) + if ( + isinstance(decorator.value, ast.Attribute) + and isinstance(decorator.value.value, ast.Name) + and decorator.value.value.id == "pytest" + and decorator.value.attr == "mark" + and decorator.attr == "benchmark" + ): + return True + elif isinstance(decorator, ast.Name) and decorator.id == "benchmark": + return True + + return False + + @staticmethod + def _is_benchmark_call(call: ast.Call) -> bool: + """Check if a call is using the benchmark fixture.""" + if isinstance(call.func, ast.Name) and call.func.id == "benchmark": + return True + return bool( + isinstance(call.func, ast.Attribute) + and call.func.attr in ["benchmark", "__call__"] + and isinstance(call.func.value, ast.Name) + and call.func.value.id == "benchmark" + ) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> Optional[AST]: + """Visit function definitions and remove if they use benchmark fixture.""" + if self._uses_benchmark_fixture(node): + return None # Remove the function + return self.generic_visit(node) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Optional[AST]: + """Visit async function definitions and remove if they use benchmark fixture.""" + if self._uses_benchmark_fixture(node): + return None # Remove the function + return self.generic_visit(node) + + def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: + """Visit class definitions and remove benchmark methods.""" + original_body = node.body[:] + new_body = [] + + for item in original_body: + if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)): + if not self._uses_benchmark_fixture(item): + new_body.append(self.visit(item)) + + else: + new_body.append(self.visit(item)) + + node.body = new_body + return node + + +def remove_benchmark_functions(tree: AST) -> AST: + """Remove benchmark functions from Python 
source code. + + Args: + tree: Parsed AST of the module to transform + + Returns: + The modified AST with benchmark functions removed (the input tree is returned unchanged on error) + + """ + try: + # Create and apply the transformer + remover = BenchmarkFunctionRemover() + return remover.visit(tree) + + except Exception as e: + print(f"Error processing code: {e}") + return tree + + def normalize_node(node: ASTNodeT) -> ASTNodeT: if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and ast.get_docstring(node): node.body = node.body[1:] diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index 6eac52809..33a011d8b 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -7,6 +7,7 @@ import isort from codeflash.cli_cmds.console import logger +from codeflash.code_utils.code_replacer import remove_benchmark_functions from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent, TestingMode, VerificationType @@ -355,6 +356,8 @@ def inject_profiling_into_existing_test( if test_framework == "unittest": new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")])) tree.body = [*new_imports, create_wrapper_function(mode), *tree.body] + # remove benchmark functions + tree = remove_benchmark_functions(tree) return True, isort.code(ast.unparse(tree), float_to_top=True) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 1dab67c97..aa098a435 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -1,6 +1,7 @@ from __future__ import annotations import libcst as cst -from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder +from codeflash.code_utils.code_replacer import AutouseFixtureModifier, PytestMarkAdder,
BenchmarkFunctionRemover, \ + remove_benchmark_functions import dataclasses import os from collections import defaultdict @@ -2606,3 +2607,378 @@ def test_something(): # Should have both modifications assert code==expected_code + +import ast +import pytest +from typing import Set + + +class TestBenchmarkFunctionRemover: + """Test cases for BenchmarkFunctionRemover class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.remover = BenchmarkFunctionRemover() + + def test_removes_function_with_benchmark_parameter(self): + """Test that functions with 'benchmark' parameter are removed.""" + source = """ +def test_performance(benchmark): + result = benchmark(some_function) + assert result is not None + +def test_normal(): + assert True +""" + expected = """def test_normal(): + assert True""" + tree = ast.parse(source) + result = self.remover.visit(tree) + + assert ast.unparse(result) == expected + + def test_removes_async_function_with_benchmark_parameter(self): + """Test that async functions with 'benchmark' parameter are removed.""" + source = """ +async def test_async_performance(benchmark): + result = benchmark(some_async_function) + assert result is not None + +async def test_async_normal(): + assert True +""" + expected = """async def test_async_normal(): + assert True""" + tree = ast.parse(source) + result = self.remover.visit(tree) + + # Should only have one async function left + assert ast.unparse(result) == expected + + def test_removes_function_with_pytest_mark_benchmark_decorator(self): + """Test that functions with @pytest.mark.benchmark decorator are removed.""" + source = """ +import pytest + +@pytest.mark.benchmark +def test_with_benchmark_marker(): + pass + +def test_normal(): + pass +""" + expected = """import pytest + +def test_normal(): + pass""" + tree = ast.parse(source) + result = self.remover.visit(tree) + + assert ast.unparse(result) == expected + + def test_removes_function_with_benchmark_decorator_call(self): + """Test that 
functions with @pytest.mark.benchmark() decorator are removed.""" + source = """ +import pytest + +@pytest.mark.benchmark() +def test_with_benchmark_marker_call(): + pass + +@pytest.mark.parametrize("x", [1, 2, 3]) +def test_normal_with_marker(): + pass +""" + expected = """import pytest + +@pytest.mark.parametrize('x', [1, 2, 3]) +def test_normal_with_marker(): + pass""" + tree = ast.parse(source) + result = self.remover.visit(tree) + + assert ast.unparse(result) == expected + + def test_removes_function_with_simple_benchmark_decorator(self): + """Test that functions with @benchmark decorator are removed.""" + source = """ +@benchmark +def test_simple_benchmark(): + pass + +def test_normal(): + pass +""" + expected = """def test_normal(): + pass""" + tree = ast.parse(source) + result = self.remover.visit(tree) + + assert ast.unparse(result) == expected + + def test_removes_function_with_benchmark_call_in_body(self): + """Test that functions calling benchmark() in body are removed.""" + source = """ +def test_with_benchmark_call(): + result = benchmark(some_function) + assert result + +def test_normal(): + some_other_function() + assert True +""" + expected = """def test_normal(): + some_other_function() + assert True""" + tree = ast.parse(source) + result = self.remover.visit(tree) + + assert ast.dump(result) == ast.dump(ast.parse(expected)) + + def test_removes_benchmark_methods_from_class(self): + """Test that benchmark methods are removed from classes.""" + source = """ +class TestClass: + def test_normal_method(self): + assert True + + def test_benchmark_method(self, benchmark): + result = benchmark(some_function) + assert result + + @pytest.mark.benchmark + def test_decorated_benchmark(self): + pass +""" + expected = """class TestClass:\n \n def test_normal_method(self):\n assert True""" + tree = ast.parse(source) + result = self.remover.visit(tree) + + assert ast.dump(result) == ast.dump(ast.parse(expected)) + + def 
test_preserves_non_benchmark_functions(self): + """Test that non-benchmark functions are preserved.""" + source = """ +def test_normal_function(): + assert True + +def helper_function(param1, param2): + return param1 + param2 + +@pytest.mark.parametrize("x", [1, 2, 3]) +def test_parametrized(x): + assert x > 0 +""" + expected = """ +def test_normal_function(): + assert True + +def helper_function(param1, param2): + return param1 + param2 + +@pytest.mark.parametrize("x", [1, 2, 3]) +def test_parametrized(x): + assert x > 0 +""" + tree = ast.parse(source) + + result = self.remover.visit(tree) + + assert ast.dump(result) == ast.dump(ast.parse(expected)) + + def test_handles_empty_class(self): + """Test handling of classes that become empty after removing benchmark methods.""" + source = """ +class TestBenchmarks: + @pytest.mark.benchmark + def test_only_benchmark(self): + pass +""" + expected = """class TestBenchmarks:""" + tree = ast.parse(source) + result = self.remover.visit(tree) + + assert ast.unparse(result) == expected + + def test_handles_mixed_decorators(self): + """Test functions with multiple decorators including benchmark.""" + source = """ +@pytest.mark.parametrize("x", [1, 2]) +@pytest.mark.benchmark +def test_multiple_decorators(x): + pass + +@pytest.mark.parametrize("y", [3, 4]) +def test_normal_with_decorator(y): + pass +""" + expected = """@pytest.mark.parametrize('y', [3, 4]) +def test_normal_with_decorator(y): + pass""" + tree = ast.parse(source) + result = self.remover.visit(tree) + + assert ast.unparse(result) == expected + + +class TestRemoveBenchmarkFunctions: + """Test cases for the remove_benchmark_functions function.""" + + def test_remove_benchmark_functions_success(self): + """Test successful removal of benchmark functions.""" + source = """ +def test_normal(): + assert True + +def test_benchmark(benchmark): + result = benchmark(some_function) + assert result +""" + expected = """ +def test_normal(): + assert True +""" + tree = 
ast.parse(source) + result = remove_benchmark_functions(tree) + + assert ast.dump(result) == ast.dump(ast.parse(expected)) + + def test_remove_benchmark_functions_handles_exception(self, capsys): + """Test that exceptions are handled gracefully.""" + # Create a malformed tree that might cause issues + tree = ast.parse("def test(): pass") + + # Mock the BenchmarkFunctionRemover to raise an exception + original_visit = BenchmarkFunctionRemover.visit + + def mock_visit(self, node): + raise ValueError("Test exception") + + BenchmarkFunctionRemover.visit = mock_visit + + try: + result = remove_benchmark_functions(tree) + # Should return original tree on exception + assert result == tree + + # Check that error was printed + captured = capsys.readouterr() + assert "Error processing code: Test exception" in captured.out + finally: + # Restore original method + BenchmarkFunctionRemover.visit = original_visit + + def test_remove_benchmark_functions_with_complex_code(self): + """Test with more complex code structure.""" + source = """ +import pytest +from some_module import some_function + +class TestPerformance: + def setup_method(self): + self.data = [1, 2, 3, 4, 5] + + def test_normal_operation(self): + assert len(self.data) == 5 + + @pytest.mark.benchmark + def test_benchmark_operation(self): + result = some_function(self.data) + assert result is not None + + def test_with_benchmark_param(self, benchmark): + result = benchmark(some_function, self.data) + assert result + +def standalone_function(): + return "not a test" + +@pytest.mark.benchmark +async def test_async_benchmark(): + await some_async_function() +""" + expected = """ +import pytest +from some_module import some_function + +class TestPerformance: + def setup_method(self): + self.data = [1, 2, 3, 4, 5] + + def test_normal_operation(self): + assert len(self.data) == 5 + +def standalone_function(): + return "not a test" +""" + tree = ast.parse(source) + result = remove_benchmark_functions(tree) + + assert 
ast.dump(result) == ast.dump(ast.parse(expected)) + +class TestBenchmarkDetectionMethods: + """Test the individual detection methods.""" + + def setup_method(self): + self.remover = BenchmarkFunctionRemover() + + def test_is_benchmark_marker_with_various_decorators(self): + """Test _is_benchmark_marker with different decorator types.""" + # Test @pytest.mark.benchmark + decorator_code = "pytest.mark.benchmark" + decorator_ast = ast.parse(decorator_code, mode='eval').body + assert self.remover._is_benchmark_marker(decorator_ast) + + # Test @benchmark + decorator_code = "benchmark" + decorator_ast = ast.parse(decorator_code, mode='eval').body + assert self.remover._is_benchmark_marker(decorator_ast) + + # Test @pytest.mark.parametrize (should return False) + decorator_code = "pytest.mark.parametrize" + decorator_ast = ast.parse(decorator_code, mode='eval').body + assert not self.remover._is_benchmark_marker(decorator_ast) + + def test_is_benchmark_call_detection(self): + """Test _is_benchmark_call with various call patterns.""" + # Test benchmark() + call_code = "benchmark(some_func)" + call_ast = ast.parse(call_code, mode='eval').body + assert self.remover._is_benchmark_call(call_ast) + + # Test benchmark.__call__() + call_code = "benchmark.__call__(some_func)" + call_ast = ast.parse(call_code, mode='eval').body + assert self.remover._is_benchmark_call(call_ast) + + # Test other_function() (should return False) + call_code = "other_function()" + call_ast = ast.parse(call_code, mode='eval').body + assert not self.remover._is_benchmark_call(call_ast) + + def test_uses_benchmark_fixture_comprehensive(self): + """Test _uses_benchmark_fixture with comprehensive scenarios.""" + # Function with benchmark parameter + func_code = """ +def test_func(benchmark, other_param): + pass +""" + func_ast = ast.parse(func_code).body[0] + assert self.remover._uses_benchmark_fixture(func_ast) + + # Function with benchmark call in body + func_code = """ +def test_func(): + result = 
benchmark(some_function) + return result +""" + func_ast = ast.parse(func_code).body[0] + assert self.remover._uses_benchmark_fixture(func_ast) + + # Normal function (should return False) + func_code = """ +def test_func(normal_param): + return normal_param * 2 +""" + func_ast = ast.parse(func_code).body[0] + assert not self.remover._uses_benchmark_fixture(func_ast) \ No newline at end of file