Skip to content

Commit 26b738f

Browse files
authored
Merge branch 'main' into mp-test-processing
2 parents c90c6b8 + 2d7db72 commit 26b738f

File tree

6 files changed

+658
-4
lines changed

6 files changed

+658
-4
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
123123
"disable_telemetry",
124124
"disable_imports_sorting",
125125
"git_remote",
126+
"override_fixtures",
126127
]
127128
for key in supported_keys:
128129
if key in pyproject_config and (

codeflash/code_utils/code_replacer.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
from functools import lru_cache
66
from typing import TYPE_CHECKING, Optional, TypeVar
77

8+
import isort
89
import libcst as cst
10+
import libcst.matchers as m
911

1012
from codeflash.cli_cmds.console import logger
1113
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
14+
from codeflash.code_utils.config_parser import find_conftest_files
15+
from codeflash.code_utils.line_profile_utils import ImportAdder
1216
from codeflash.models.models import FunctionParent
1317

1418
if TYPE_CHECKING:
@@ -33,6 +37,142 @@ def normalize_code(code: str) -> str:
3337
return ast.unparse(normalize_node(ast.parse(code)))
3438

3539

40+
class PytestMarkAdder(cst.CSTTransformer):
41+
"""Transformer that adds pytest marks to test functions."""
42+
43+
def __init__(self, mark_name: str) -> None:
44+
super().__init__()
45+
self.mark_name = mark_name
46+
self.has_pytest_import = False
47+
48+
def visit_Module(self, node: cst.Module) -> None:
49+
"""Check if pytest is already imported."""
50+
for statement in node.body:
51+
if isinstance(statement, cst.SimpleStatementLine):
52+
for stmt in statement.body:
53+
if isinstance(stmt, cst.Import):
54+
for import_alias in stmt.names:
55+
if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest":
56+
self.has_pytest_import = True
57+
58+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
59+
"""Add pytest import if not present."""
60+
if not self.has_pytest_import:
61+
# Create import statement
62+
import_stmt = cst.SimpleStatementLine(body=[cst.Import(names=[cst.ImportAlias(name=cst.Name("pytest"))])])
63+
# Add import at the beginning
64+
updated_node = updated_node.with_changes(body=[import_stmt, *updated_node.body])
65+
return updated_node
66+
67+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
68+
"""Add pytest mark to test functions."""
69+
# Check if the mark already exists
70+
for decorator in updated_node.decorators:
71+
if self._is_pytest_mark(decorator.decorator, self.mark_name):
72+
return updated_node
73+
74+
# Create the pytest mark decorator
75+
mark_decorator = self._create_pytest_mark()
76+
77+
# Add the decorator
78+
new_decorators = [*list(updated_node.decorators), mark_decorator]
79+
return updated_node.with_changes(decorators=new_decorators)
80+
81+
def _is_pytest_mark(self, decorator: cst.BaseExpression, mark_name: str) -> bool:
82+
"""Check if a decorator is a specific pytest mark."""
83+
if isinstance(decorator, cst.Attribute):
84+
if (
85+
isinstance(decorator.value, cst.Attribute)
86+
and isinstance(decorator.value.value, cst.Name)
87+
and decorator.value.value.value == "pytest"
88+
and decorator.value.attr.value == "mark"
89+
and decorator.attr.value == mark_name
90+
):
91+
return True
92+
elif isinstance(decorator, cst.Call) and isinstance(decorator.func, cst.Attribute):
93+
return self._is_pytest_mark(decorator.func, mark_name)
94+
return False
95+
96+
def _create_pytest_mark(self) -> cst.Decorator:
97+
"""Create a pytest mark decorator."""
98+
# Base: pytest.mark.{mark_name}
99+
mark_attr = cst.Attribute(
100+
value=cst.Attribute(value=cst.Name("pytest"), attr=cst.Name("mark")), attr=cst.Name(self.mark_name)
101+
)
102+
decorator = mark_attr
103+
return cst.Decorator(decorator=decorator)
104+
105+
106+
class AutouseFixtureModifier(cst.CSTTransformer):
107+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
108+
# Matcher for '@fixture' or '@pytest.fixture'
109+
fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture"))
110+
111+
for decorator in original_node.decorators:
112+
if m.matches(
113+
decorator,
114+
m.Decorator(
115+
decorator=m.Call(
116+
func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))]
117+
)
118+
),
119+
):
120+
# Found a matching fixture with autouse=True
121+
122+
# 1. The original body of the function will become the 'else' block.
123+
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
124+
else_block = cst.Else(body=updated_node.body)
125+
126+
# 2. Create the new 'if' block that will exit the fixture early.
127+
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
128+
yield_statement = cst.parse_statement("yield")
129+
if_body = cst.IndentedBlock(body=[yield_statement])
130+
131+
# 3. Construct the full if/else statement.
132+
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)
133+
134+
# 4. Replace the entire function's body with our new single statement.
135+
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
136+
return updated_node
137+
138+
139+
def disable_autouse(test_path: Path) -> str:
140+
file_content = test_path.read_text(encoding="utf-8")
141+
module = cst.parse_module(file_content)
142+
disable_autouse_fixture = AutouseFixtureModifier()
143+
modified_module = module.visit(disable_autouse_fixture)
144+
test_path.write_text(modified_module.code, encoding="utf-8")
145+
return file_content
146+
147+
148+
def modify_autouse_fixture(test_paths: list[Path]) -> dict[Path, list[str]]:
149+
# find fixutre definition in conftetst.py (the one closest to the test)
150+
# get fixtures present in override-fixtures in pyproject.toml
151+
# add if marker closest return
152+
file_content_map = {}
153+
conftest_files = find_conftest_files(test_paths)
154+
for cf_file in conftest_files:
155+
# iterate over all functions in the file
156+
# if function has autouse fixture, modify function to bypass with custom marker
157+
original_content = disable_autouse(cf_file)
158+
file_content_map[cf_file] = original_content
159+
return file_content_map
160+
161+
162+
# # reuse line profiler utils to add decorator and import to test fns
163+
def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
164+
for test_path in test_paths:
165+
# read file
166+
file_content = test_path.read_text(encoding="utf-8")
167+
module = cst.parse_module(file_content)
168+
importadder = ImportAdder("import pytest")
169+
modified_module = module.visit(importadder)
170+
modified_module = cst.parse_module(isort.code(modified_module.code, float_to_top=True))
171+
pytest_mark_adder = PytestMarkAdder("codeflash_no_autouse")
172+
modified_module = modified_module.visit(pytest_mark_adder)
173+
test_path.write_text(modified_module.code, encoding="utf-8")
174+
175+
36176
class OptimFunctionCollector(cst.CSTVisitor):
37177
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)
38178

codeflash/code_utils/code_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,3 +208,8 @@ def cleanup_paths(paths: list[Path]) -> None:
208208
shutil.rmtree(path, ignore_errors=True)
209209
else:
210210
path.unlink(missing_ok=True)
211+
212+
213+
def restore_conftest(path_to_content_map: dict[Path, str]) -> None:
214+
for path, file_content in path_to_content_map.items():
215+
path.write_text(file_content, encoding="utf8")

codeflash/code_utils/config_parser.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,21 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path:
3131
raise ValueError(msg)
3232

3333

34+
def find_conftest_files(test_paths: list[Path]) -> list[Path]:
35+
list_of_conftest_files = set()
36+
for test_path in test_paths:
37+
# Find the conftest file on the root of the project
38+
dir_path = Path.cwd()
39+
cur_path = test_path
40+
while cur_path != dir_path:
41+
config_file = cur_path / "conftest.py"
42+
if config_file.exists():
43+
list_of_conftest_files.add(config_file)
44+
# Search for conftest.py in the parent directories
45+
cur_path = cur_path.parent
46+
return list(list_of_conftest_files)
47+
48+
3449
def parse_config_file(
3550
config_file_path: Path | None = None,
3651
override_formatter_check: bool = False, # noqa: FBT001, FBT002
@@ -56,7 +71,12 @@ def parse_config_file(
5671
path_keys = ["module-root", "tests-root", "benchmarks-root"]
5772
path_list_keys = ["ignore-paths"]
5873
str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"}
59-
bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False, "benchmark": False}
74+
bool_keys = {
75+
"override-fixtures": False,
76+
"disable-telemetry": False,
77+
"disable-imports-sorting": False,
78+
"benchmark": False,
79+
}
6080
list_str_keys = {"formatter-cmds": ["black $file"]}
6181

6282
for key, default_value in str_keys.items():

codeflash/optimization/function_optimizer.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,19 @@
2121
from codeflash.benchmarking.utils import process_benchmark_data
2222
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
2323
from codeflash.code_utils import env_utils
24-
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
24+
from codeflash.code_utils.code_replacer import (
25+
add_custom_marker_to_all_tests,
26+
modify_autouse_fixture,
27+
replace_function_definitions_in_module,
28+
)
2529
from codeflash.code_utils.code_utils import (
2630
ImportErrorPattern,
2731
cleanup_paths,
2832
file_name_from_test_module_name,
2933
get_run_tmp_file,
3034
has_any_async_functions,
3135
module_name_from_file_path,
36+
restore_conftest,
3237
)
3338
from codeflash.code_utils.config_consts import (
3439
INDIVIDUAL_TESTCASE_TIMEOUT,
@@ -212,6 +217,11 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
212217
for key in set(self.function_to_tests) | set(function_to_concolic_tests)
213218
}
214219
instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests)
220+
if self.args.override_fixtures:
221+
logger.info("Disabling all autouse fixtures associated with the generated test files")
222+
original_conftest_content = modify_autouse_fixture(generated_test_paths + generated_perf_test_paths)
223+
logger.info("Add custom marker to generated test files")
224+
add_custom_marker_to_all_tests(generated_test_paths + generated_perf_test_paths)
215225

216226
# Get a dict of file_path_to_classes of fto and helpers_of_fto
217227
file_path_to_helper_classes = defaultdict(set)
@@ -234,13 +244,17 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
234244
)
235245

236246
if not is_successful(baseline_result):
247+
if self.args.override_fixtures:
248+
restore_conftest(original_conftest_content)
237249
cleanup_paths(paths_to_cleanup)
238250
return Failure(baseline_result.failure())
239251

240252
original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
241253
if isinstance(original_code_baseline, OriginalCodeBaseline) and not coverage_critic(
242254
original_code_baseline.coverage_results, self.args.test_framework
243255
):
256+
if self.args.override_fixtures:
257+
restore_conftest(original_conftest_content)
244258
cleanup_paths(paths_to_cleanup)
245259
return Failure("The threshold for test coverage was not met.")
246260
# request for new optimizations but don't block execution, check for completion later
@@ -356,6 +370,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
356370
)
357371
self.log_successful_optimization(explanation, generated_tests, exp_type)
358372

373+
if self.args.override_fixtures:
374+
restore_conftest(original_conftest_content)
359375
if not best_optimization:
360376
return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}")
361377
return Success(best_optimization)

0 commit comments

Comments
 (0)