Skip to content
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def process_pyproject_config(args: Namespace) -> Namespace:
"disable_telemetry",
"disable_imports_sorting",
"git_remote",
"override_fixtures",
]
for key in supported_keys:
if key in pyproject_config and (
Expand Down
140 changes: 140 additions & 0 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Optional, TypeVar

import isort
import libcst as cst
import libcst.matchers as m

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
from codeflash.code_utils.config_parser import find_conftest_files
from codeflash.code_utils.line_profile_utils import ImportAdder
from codeflash.models.models import FunctionParent

if TYPE_CHECKING:
Expand All @@ -33,6 +37,142 @@ def normalize_code(code: str) -> str:
return ast.unparse(normalize_node(ast.parse(code)))


class PytestMarkAdder(cst.CSTTransformer):
"""Transformer that adds pytest marks to test functions."""

def __init__(self, mark_name: str) -> None:
super().__init__()
self.mark_name = mark_name
self.has_pytest_import = False

def visit_Module(self, node: cst.Module) -> None:
"""Check if pytest is already imported."""
for statement in node.body:
if isinstance(statement, cst.SimpleStatementLine):
for stmt in statement.body:
if isinstance(stmt, cst.Import):
for import_alias in stmt.names:
if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest":
self.has_pytest_import = True

def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
"""Add pytest import if not present."""
if not self.has_pytest_import:
# Create import statement
import_stmt = cst.SimpleStatementLine(body=[cst.Import(names=[cst.ImportAlias(name=cst.Name("pytest"))])])
# Add import at the beginning
updated_node = updated_node.with_changes(body=[import_stmt, *updated_node.body])
return updated_node

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
"""Add pytest mark to test functions."""
# Check if the mark already exists
for decorator in updated_node.decorators:
if self._is_pytest_mark(decorator.decorator, self.mark_name):
return updated_node

# Create the pytest mark decorator
mark_decorator = self._create_pytest_mark()

# Add the decorator
new_decorators = [*list(updated_node.decorators), mark_decorator]
return updated_node.with_changes(decorators=new_decorators)

def _is_pytest_mark(self, decorator: cst.BaseExpression, mark_name: str) -> bool:
"""Check if a decorator is a specific pytest mark."""
if isinstance(decorator, cst.Attribute):
if (
isinstance(decorator.value, cst.Attribute)
and isinstance(decorator.value.value, cst.Name)
and decorator.value.value.value == "pytest"
and decorator.value.attr.value == "mark"
and decorator.attr.value == mark_name
):
return True
elif isinstance(decorator, cst.Call) and isinstance(decorator.func, cst.Attribute):
return self._is_pytest_mark(decorator.func, mark_name)
return False

def _create_pytest_mark(self) -> cst.Decorator:
"""Create a pytest mark decorator."""
# Base: pytest.mark.{mark_name}
mark_attr = cst.Attribute(
value=cst.Attribute(value=cst.Name("pytest"), attr=cst.Name("mark")), attr=cst.Name(self.mark_name)
)
decorator = mark_attr
return cst.Decorator(decorator=decorator)


class AutouseFixtureModifier(cst.CSTTransformer):
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
# Matcher for '@fixture' or '@pytest.fixture'
fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture"))

for decorator in original_node.decorators:
if m.matches(
decorator,
m.Decorator(
decorator=m.Call(
func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))]
)
),
):
# Found a matching fixture with autouse=True

# 1. The original body of the function will become the 'else' block.
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
else_block = cst.Else(body=updated_node.body)

# 2. Create the new 'if' block that will exit the fixture early.
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
yield_statement = cst.parse_statement("yield")
if_body = cst.IndentedBlock(body=[yield_statement])

# 3. Construct the full if/else statement.
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)

# 4. Replace the entire function's body with our new single statement.
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
return updated_node


def disable_autouse(test_path: Path) -> str:
file_content = test_path.read_text(encoding="utf-8")
module = cst.parse_module(file_content)
disable_autouse_fixture = AutouseFixtureModifier()
modified_module = module.visit(disable_autouse_fixture)
test_path.write_text(modified_module.code, encoding="utf-8")
return file_content


def modify_autouse_fixture(test_paths: list[Path]) -> dict[Path, list[str]]:
# find fixutre definition in conftetst.py (the one closest to the test)
# get fixtures present in override-fixtures in pyproject.toml
# add if marker closest return
file_content_map = {}
conftest_files = find_conftest_files(test_paths)
for cf_file in conftest_files:
# iterate over all functions in the file
# if function has autouse fixture, modify function to bypass with custom marker
original_content = disable_autouse(cf_file)
file_content_map[cf_file] = original_content
return file_content_map


# # reuse line profiler utils to add decorator and import to test fns
def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
for test_path in test_paths:
# read file
file_content = test_path.read_text(encoding="utf-8")
module = cst.parse_module(file_content)
importadder = ImportAdder("import pytest")
modified_module = module.visit(importadder)
modified_module = cst.parse_module(isort.code(modified_module.code, float_to_top=True))
pytest_mark_adder = PytestMarkAdder("codeflash_no_autouse")
modified_module = modified_module.visit(pytest_mark_adder)
test_path.write_text(modified_module.code, encoding="utf-8")


class OptimFunctionCollector(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)

Expand Down
5 changes: 5 additions & 0 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,8 @@ def cleanup_paths(paths: list[Path]) -> None:
shutil.rmtree(path, ignore_errors=True)
else:
path.unlink(missing_ok=True)


def restore_conftest(path_to_content_map: dict[Path, str]) -> None:
for path, file_content in path_to_content_map.items():
path.write_text(file_content, encoding="utf8")
22 changes: 21 additions & 1 deletion codeflash/code_utils/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path:
raise ValueError(msg)


def find_conftest_files(test_paths: list[Path]) -> list[Path]:
list_of_conftest_files = set()
for test_path in test_paths:
# Find the conftest file on the root of the project
dir_path = Path.cwd()
cur_path = test_path
while cur_path != dir_path:
config_file = cur_path / "conftest.py"
if config_file.exists():
list_of_conftest_files.add(config_file)
# Search for conftest.py in the parent directories
cur_path = cur_path.parent
return list(list_of_conftest_files)


def parse_config_file(
config_file_path: Path | None = None,
override_formatter_check: bool = False, # noqa: FBT001, FBT002
Expand All @@ -56,7 +71,12 @@ def parse_config_file(
path_keys = ["module-root", "tests-root", "benchmarks-root"]
path_list_keys = ["ignore-paths"]
str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"}
bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False, "benchmark": False}
bool_keys = {
"override-fixtures": False,
"disable-telemetry": False,
"disable-imports-sorting": False,
"benchmark": False,
}
list_str_keys = {"formatter-cmds": ["black $file"]}

for key, default_value in str_keys.items():
Expand Down
18 changes: 17 additions & 1 deletion codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
from codeflash.benchmarking.utils import process_benchmark_data
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.code_utils.code_replacer import (
add_custom_marker_to_all_tests,
modify_autouse_fixture,
replace_function_definitions_in_module,
)
from codeflash.code_utils.code_utils import (
ImportErrorPattern,
cleanup_paths,
file_name_from_test_module_name,
get_run_tmp_file,
has_any_async_functions,
module_name_from_file_path,
restore_conftest,
)
from codeflash.code_utils.config_consts import (
INDIVIDUAL_TESTCASE_TIMEOUT,
Expand Down Expand Up @@ -212,6 +217,11 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
for key in set(self.function_to_tests) | set(function_to_concolic_tests)
}
instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests)
if self.args.override_fixtures:
logger.info("Disabling all autouse fixtures associated with the generated test files")
original_conftest_content = modify_autouse_fixture(generated_test_paths + generated_perf_test_paths)
logger.info("Add custom marker to generated test files")
add_custom_marker_to_all_tests(generated_test_paths + generated_perf_test_paths)

# Get a dict of file_path_to_classes of fto and helpers_of_fto
file_path_to_helper_classes = defaultdict(set)
Expand All @@ -234,13 +244,17 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
)

if not is_successful(baseline_result):
if self.args.override_fixtures:
restore_conftest(original_conftest_content)
cleanup_paths(paths_to_cleanup)
return Failure(baseline_result.failure())

original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
if isinstance(original_code_baseline, OriginalCodeBaseline) and not coverage_critic(
original_code_baseline.coverage_results, self.args.test_framework
):
if self.args.override_fixtures:
restore_conftest(original_conftest_content)
cleanup_paths(paths_to_cleanup)
return Failure("The threshold for test coverage was not met.")
# request for new optimizations but don't block execution, check for completion later
Expand Down Expand Up @@ -356,6 +370,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
)
self.log_successful_optimization(explanation, generated_tests, exp_type)

if self.args.override_fixtures:
restore_conftest(original_conftest_content)
if not best_optimization:
return Failure(f"No best optimizations found for function {self.function_to_optimize.qualified_name}")
return Success(best_optimization)
Expand Down
Loading
Loading