Skip to content

Commit 23a1aad

Browse files
committed
remove test fns with the benchmark fixture
1 parent 10e8a13 commit 23a1aad

File tree

3 files changed

+484
-2
lines changed

3 files changed

+484
-2
lines changed

codeflash/code_utils/code_replacer.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import ast
44
from collections import defaultdict
55
from functools import lru_cache
6-
from typing import TYPE_CHECKING, Optional, TypeVar
6+
from typing import TYPE_CHECKING, Optional, TypeVar, Union
77

88
import isort
99
import libcst as cst
@@ -16,6 +16,7 @@
1616
from codeflash.models.models import FunctionParent
1717

1818
if TYPE_CHECKING:
19+
from _ast import AST
1920
from pathlib import Path
2021

2122
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
@@ -24,6 +25,115 @@
2425
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
2526

2627

28+
class BenchmarkFunctionRemover(ast.NodeTransformer):
29+
"""AST transformer that removes functions using pytest-benchmark fixture."""
30+
31+
def _uses_benchmark_fixture(self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]) -> bool:
32+
"""Check if a function uses the benchmark fixture."""
33+
# Check function arguments for 'benchmark' parameter
34+
for arg in node.args.args:
35+
if arg.arg == "benchmark":
36+
return True
37+
38+
# Check for pytest markers that might indicate benchmarking
39+
for decorator in node.decorator_list:
40+
if self._is_benchmark_marker(decorator):
41+
return True
42+
43+
# Check function body for benchmark usage
44+
return any(isinstance(stmt, ast.Call) and self._is_benchmark_call(stmt) for stmt in ast.walk(node))
45+
46+
def _is_benchmark_marker(self, decorator: ast.expr) -> bool:
47+
"""Check if decorator is a benchmark-related pytest marker."""
48+
if isinstance(decorator, ast.Call):
49+
if isinstance(decorator.func, ast.Attribute):
50+
# Check for @pytest.mark.benchmark
51+
if (
52+
isinstance(decorator.func.value, ast.Attribute)
53+
and isinstance(decorator.func.value.value, ast.Name)
54+
and decorator.func.value.value.id == "pytest"
55+
and decorator.func.value.attr == "mark"
56+
and decorator.func.attr == "benchmark"
57+
):
58+
return True
59+
elif isinstance(decorator.func, ast.Name) and decorator.func.id == "benchmark":
60+
return True
61+
elif isinstance(decorator, ast.Attribute):
62+
# Check for @pytest.mark.benchmark (without call)
63+
if (
64+
isinstance(decorator.value, ast.Attribute)
65+
and isinstance(decorator.value.value, ast.Name)
66+
and decorator.value.value.id == "pytest"
67+
and decorator.value.attr == "mark"
68+
and decorator.attr == "benchmark"
69+
):
70+
return True
71+
elif isinstance(decorator, ast.Name) and decorator.id == "benchmark":
72+
return True
73+
74+
return False
75+
76+
@staticmethod
77+
def _is_benchmark_call(call: ast.Call) -> bool:
78+
"""Check if a call is using the benchmark fixture."""
79+
if isinstance(call.func, ast.Name) and call.func.id == "benchmark":
80+
return True
81+
return bool(
82+
isinstance(call.func, ast.Attribute)
83+
and call.func.attr in ["benchmark", "__call__"]
84+
and isinstance(call.func.value, ast.Name)
85+
and call.func.value.id == "benchmark"
86+
)
87+
88+
def visit_FunctionDef(self, node: ast.FunctionDef) -> Optional[AST]:
89+
"""Visit function definitions and remove if they use benchmark fixture."""
90+
if self._uses_benchmark_fixture(node):
91+
return None # Remove the function
92+
return self.generic_visit(node)
93+
94+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Optional[AST]:
95+
"""Visit async function definitions and remove if they use benchmark fixture."""
96+
if self._uses_benchmark_fixture(node):
97+
return None # Remove the function
98+
return self.generic_visit(node)
99+
100+
def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef:
101+
"""Visit class definitions and remove benchmark methods."""
102+
original_body = node.body[:]
103+
new_body = []
104+
105+
for item in original_body:
106+
if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
107+
if not self._uses_benchmark_fixture(item):
108+
new_body.append(self.visit(item))
109+
110+
else:
111+
new_body.append(self.visit(item))
112+
113+
node.body = new_body
114+
return node
115+
116+
117+
def remove_benchmark_functions(tree: AST) -> AST:
118+
"""Remove benchmark functions from Python source code.
119+
120+
Args:
121+
tree: Python source code as ast module
122+
123+
Returns:
124+
Tuple of (modified_source_code, set_of_removed_function_names)
125+
126+
"""
127+
try:
128+
# Create and apply the transformer
129+
remover = BenchmarkFunctionRemover()
130+
return remover.visit(tree)
131+
132+
except Exception as e:
133+
print(f"Error processing code: {e}")
134+
return tree
135+
136+
27137
def normalize_node(node: ASTNodeT) -> ASTNodeT:
28138
if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)) and ast.get_docstring(node):
29139
node.body = node.body[1:]

codeflash/code_utils/instrument_existing_tests.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import isort
88

99
from codeflash.cli_cmds.console import logger
10+
from codeflash.code_utils.code_replacer import remove_benchmark_functions
1011
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
1112
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1213
from codeflash.models.models import FunctionParent, TestingMode, VerificationType
@@ -355,6 +356,8 @@ def inject_profiling_into_existing_test(
355356
if test_framework == "unittest":
356357
new_imports.append(ast.Import(names=[ast.alias(name="timeout_decorator")]))
357358
tree.body = [*new_imports, create_wrapper_function(mode), *tree.body]
359+
# remove benchmark functions
360+
tree = remove_benchmark_functions(tree)
358361
return True, isort.code(ast.unparse(tree), float_to_top=True)
359362

360363

0 commit comments

Comments
 (0)