Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions codeflash/cli_cmds/cmd_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import tomlkit
from git import InvalidGitRepositoryError, Repo
from pydantic.dataclasses import dataclass
from tomlkit import table

from codeflash.api.cfapi import is_github_app_installed_on_repo
from codeflash.cli_cmds.cli_common import apologize_and_exit, inquirer_wrapper, inquirer_wrapper_path
Expand All @@ -34,7 +35,7 @@
from argparse import Namespace

CODEFLASH_LOGO: str = (
f"{LF}" # noqa: ISC003
f"{LF}" # noqa : ISC003
r" _ ___ _ _ " + f"{LF}"
r" | | / __)| | | | " + f"{LF}"
r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}"
Expand Down Expand Up @@ -723,11 +724,21 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
formatter_cmds.append("disabled")
check_formatter_installed(formatter_cmds, exit_on_failure=False)
codeflash_section["formatter-cmds"] = formatter_cmds
codeflash_section["override-fixtures"] = False # don't override fixtures by default, let the user decide
# Add the 'codeflash' section, ensuring 'tool' section exists
tool_section = pyproject_data.get("tool", tomlkit.table())
tool_section["codeflash"] = codeflash_section
pyproject_data["tool"] = tool_section

# Create [tool.pytest.ini_options] if it doesn't exist
tool_section = pyproject_data.get("tool", table())
pytest_section = tool_section.get("pytest", table())
ini_options = pytest_section.get("ini_options", table())
# Define or overwrite the 'markers' array
ini_options["markers"] = ["codeflash_no_autouse"]
# Set updated sections back
pytest_section["ini_options"] = ini_options
tool_section["pytest"] = pytest_section
pyproject_data["tool"] = tool_section
with toml_path.open("w", encoding="utf8") as pyproject_file:
pyproject_file.write(tomlkit.dumps(pyproject_data))
click.echo(f"✅ Added Codeflash configuration to {toml_path}")
Expand Down
142 changes: 142 additions & 0 deletions codeflash/code_utils/code_replacer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Optional, TypeVar

import isort
import libcst as cst
import libcst.matchers as m

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
from codeflash.code_utils.config_parser import find_conftest_files
from codeflash.code_utils.line_profile_utils import ImportAdder
from codeflash.models.models import FunctionParent

if TYPE_CHECKING:
Expand All @@ -33,6 +37,144 @@ def normalize_code(code: str) -> str:
return ast.unparse(normalize_node(ast.parse(code)))


class PytestMarkAdder(cst.CSTTransformer):
"""Transformer that adds pytest marks to test functions."""

def __init__(self, mark_name: str) -> None:
super().__init__()
self.mark_name = mark_name
self.has_pytest_import = False

def visit_Module(self, node: cst.Module) -> None:
"""Check if pytest is already imported."""
for statement in node.body:
if isinstance(statement, cst.SimpleStatementLine):
for stmt in statement.body:
if isinstance(stmt, cst.Import):
for import_alias in stmt.names:
if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest":
self.has_pytest_import = True
elif isinstance(stmt, cst.ImportFrom) and stmt.module and stmt.module.value == "pytest":
self.has_pytest_import = True

def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
"""Add pytest import if not present."""
if not self.has_pytest_import:
# Create import statement
import_stmt = cst.SimpleStatementLine(body=[cst.Import(names=[cst.ImportAlias(name=cst.Name("pytest"))])])
# Add import at the beginning
updated_node = updated_node.with_changes(body=[import_stmt, *updated_node.body])
return updated_node

def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
"""Add pytest mark to test functions."""
# Check if the mark already exists
for decorator in updated_node.decorators:
if self._is_pytest_mark(decorator.decorator, self.mark_name):
return updated_node

# Create the pytest mark decorator
mark_decorator = self._create_pytest_mark()

# Add the decorator
new_decorators = [*list(updated_node.decorators), mark_decorator]
return updated_node.with_changes(decorators=new_decorators)

def _is_pytest_mark(self, decorator: cst.BaseExpression, mark_name: str) -> bool:
"""Check if a decorator is a specific pytest mark."""
if isinstance(decorator, cst.Attribute):
if (
isinstance(decorator.value, cst.Attribute)
and isinstance(decorator.value.value, cst.Name)
and decorator.value.value.value == "pytest"
and decorator.value.attr.value == "mark"
and decorator.attr.value == mark_name
):
return True
elif isinstance(decorator, cst.Call) and isinstance(decorator.func, cst.Attribute):
return self._is_pytest_mark(decorator.func, mark_name)
return False

def _create_pytest_mark(self) -> cst.Decorator:
"""Create a pytest mark decorator."""
# Base: pytest.mark.{mark_name}
mark_attr = cst.Attribute(
value=cst.Attribute(value=cst.Name("pytest"), attr=cst.Name("mark")), attr=cst.Name(self.mark_name)
)
decorator = mark_attr
return cst.Decorator(decorator=decorator)


class AutouseFixtureModifier(cst.CSTTransformer):
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
# Matcher for '@fixture' or '@pytest.fixture'
fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture"))

for decorator in original_node.decorators:
if m.matches(
decorator,
m.Decorator(
decorator=m.Call(
func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))]
)
),
):
# Found a matching fixture with autouse=True

# 1. The original body of the function will become the 'else' block.
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
else_block = cst.Else(body=updated_node.body)

# 2. Create the new 'if' block that will exit the fixture early.
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
yield_statement = cst.parse_statement("yield")
if_body = cst.IndentedBlock(body=[yield_statement])

# 3. Construct the full if/else statement.
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)

# 4. Replace the entire function's body with our new single statement.
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
return updated_node


def disable_autouse(test_path: Path) -> str:
file_content = test_path.read_text(encoding="utf-8")
module = cst.parse_module(file_content)
disable_autouse_fixture = AutouseFixtureModifier()
modified_module = module.visit(disable_autouse_fixture)
test_path.write_text(modified_module.code, encoding="utf-8")
return file_content


def modify_autouse_fixture(test_paths: list[Path]) -> dict[Path, list[str]]:
# find fixutre definition in conftetst.py (the one closest to the test)
# get fixtures present in override-fixtures in pyproject.toml
# add if marker closest return
file_content_map = {}
conftest_files = find_conftest_files(test_paths)
for cf_file in conftest_files:
# iterate over all functions in the file
# if function has autouse fixture, modify function to bypass with custom marker
original_content = disable_autouse(cf_file)
file_content_map[cf_file] = original_content
return file_content_map


# # reuse line profiler utils to add decorator and import to test fns
def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
for test_path in test_paths:
# read file
file_content = test_path.read_text(encoding="utf-8")
module = cst.parse_module(file_content)
importadder = ImportAdder("import pytest")
modified_module = module.visit(importadder)
modified_module = cst.parse_module(isort.code(modified_module.code, float_to_top=True))
pytest_mark_adder = PytestMarkAdder("codeflash_no_autouse")
modified_module = modified_module.visit(pytest_mark_adder)
test_path.write_text(modified_module.code, encoding="utf-8")


class OptimFunctionCollector(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)

Expand Down
5 changes: 5 additions & 0 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,8 @@ def cleanup_paths(paths: list[Path]) -> None:
shutil.rmtree(path, ignore_errors=True)
else:
path.unlink(missing_ok=True)


def restore_conftest(path_to_content_map: dict[Path, str]) -> None:
for path, file_content in path_to_content_map.items():
path.write_text(file_content, encoding="utf8")
15 changes: 15 additions & 0 deletions codeflash/code_utils/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,21 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path:
raise ValueError(msg)


def find_conftest_files(test_paths: list[Path]) -> list[Path]:
list_of_conftest_files = set()
for test_path in test_paths:
# Find the conftest file on the root of the project
dir_path = Path.cwd()
cur_path = test_path
while cur_path != dir_path:
config_file = cur_path / "conftest.py"
if config_file.exists():
list_of_conftest_files.add(config_file)
# Search for conftest.py in the parent directories
cur_path = cur_path.parent
return list(list_of_conftest_files)


def parse_config_file(
config_file_path: Path | None = None,
override_formatter_check: bool = False, # noqa: FBT001, FBT002
Expand Down
17 changes: 16 additions & 1 deletion codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,19 @@
from codeflash.benchmarking.utils import process_benchmark_data
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
from codeflash.code_utils import env_utils
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.code_utils.code_replacer import (
add_custom_marker_to_all_tests,
modify_autouse_fixture,
replace_function_definitions_in_module,
)
from codeflash.code_utils.code_utils import (
ImportErrorPattern,
cleanup_paths,
file_name_from_test_module_name,
get_run_tmp_file,
has_any_async_functions,
module_name_from_file_path,
restore_conftest,
)
from codeflash.code_utils.config_consts import (
INDIVIDUAL_TESTCASE_TIMEOUT,
Expand Down Expand Up @@ -208,6 +213,14 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
for key in set(self.function_to_tests) | set(function_to_concolic_tests)
}
instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests)
logger.debug("disabling all autouse fixtures associated with the test files")
original_conftest_content = modify_autouse_fixture(
generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function)
)
logger.debug("add custom marker to all tests")
add_custom_marker_to_all_tests(
generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function)
)

# Get a dict of file_path_to_classes of fto and helpers_of_fto
file_path_to_helper_classes = defaultdict(set)
Expand All @@ -230,13 +243,15 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
)

if not is_successful(baseline_result):
restore_conftest(original_conftest_content)
cleanup_paths(paths_to_cleanup)
return Failure(baseline_result.failure())

original_code_baseline, test_functions_to_remove = baseline_result.unwrap()
if isinstance(original_code_baseline, OriginalCodeBaseline) and not coverage_critic(
original_code_baseline.coverage_results, self.args.test_framework
):
restore_conftest(original_conftest_content)
cleanup_paths(paths_to_cleanup)
return Failure("The threshold for test coverage was not met.")
# request for new optimizations but don't block execution, check for completion later
Expand Down
Loading