Skip to content

Commit 5c49694

Browse files
committed
start cleaning up and testing
1 parent b065cbc commit 5c49694

File tree

6 files changed

+95
-92
lines changed

6 files changed

+95
-92
lines changed

code_to_optimize/tests/pytest/conftest.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

code_to_optimize/tests/pytest/test_bubble_sort.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from code_to_optimize.bubble_sort import sorter
2-
2+
import pytest
33

44
def test_sort():
55
input = [5, 4, 3, 2, 1, 0]

codeflash/cli_cmds/cmd_init.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import tomlkit
1717
from git import InvalidGitRepositoryError, Repo
1818
from pydantic.dataclasses import dataclass
19+
from tomlkit import table
1920

2021
from codeflash.api.cfapi import is_github_app_installed_on_repo
2122
from codeflash.cli_cmds.cli_common import apologize_and_exit, inquirer_wrapper, inquirer_wrapper_path
@@ -728,11 +729,16 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
728729
tool_section = pyproject_data.get("tool", tomlkit.table())
729730
tool_section["codeflash"] = codeflash_section
730731
pyproject_data["tool"] = tool_section
731-
if "tool.pytest.ini_options" not in pyproject_data:
732-
pyproject_data["tool.pytest.ini_options"] = {}
733-
if "markers" not in pyproject_data["tool.pytest.ini_options"]:
734-
pyproject_data["tool.pytest.ini_options"]["markers"] = []
735-
pyproject_data["tool.pytest.ini_options"]["markers"].append("codeflash_no_autouse")
732+
# Create [tool.pytest.ini_options] if it doesn't exist
733+
tool_section = pyproject_data.get("tool", table())
734+
pytest_section = tool_section.get("pytest", table())
735+
ini_options = pytest_section.get("ini_options", table())
736+
# Define or overwrite the 'markers' array
737+
ini_options["markers"] = ["codeflash_no_autouse"]
738+
# Set updated sections back
739+
pytest_section["ini_options"] = ini_options
740+
tool_section["pytest"] = pytest_section
741+
pyproject_data["tool"] = tool_section
736742
with toml_path.open("w", encoding="utf8") as pyproject_file:
737743
pyproject_file.write(tomlkit.dumps(pyproject_data))
738744
click.echo(f"✅ Added Codeflash configuration to {toml_path}")

codeflash/code_utils/code_extractor.py

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,9 @@
55
from typing import TYPE_CHECKING, Optional
66

77
import libcst as cst
8-
from libcst import MetadataWrapper
98
from libcst.codemod import CodemodContext
109
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
1110
from libcst.helpers import calculate_module_and_package
12-
from libcst.metadata import FullyQualifiedNameProvider
1311

1412
from codeflash.cli_cmds.console import logger
1513
from codeflash.models.models import FunctionParent
@@ -23,53 +21,6 @@
2321
from codeflash.models.models import FunctionSource
2422

2523

26-
class FunctionNameCollector(cst.CSTVisitor):
27-
"""A LibCST visitor that collects the fully qualified names of all functions."""
28-
29-
METADATA_DEPENDENCIES = (FullyQualifiedNameProvider,)
30-
31-
def __init__(self) -> None:
32-
super().__init__()
33-
self.qualified_names: set[str] = set()
34-
35-
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
36-
"""Visits a function definition node and extracts its qualified name."""
37-
try:
38-
q_names = self.get_metadata(FullyQualifiedNameProvider, node)
39-
for q_name in q_names:
40-
self.qualified_names.add(q_name.name)
41-
except KeyError:
42-
# This can happen for functions defined in scopes where a qualified
43-
# name cannot be determined.
44-
pass
45-
46-
47-
def get_function_qualified_names(file_path: Path) -> list[str]:
48-
"""Parse a Python file and returns a list of fully qualified function names.
49-
50-
Args:
51-
file_path: The path to the Python file.
52-
53-
Returns:
54-
A list of string representations of the qualified function names.
55-
56-
"""
57-
with file_path.open("r") as f:
58-
source_code = f.read()
59-
60-
# Parse the source code into a CST
61-
module = cst.parse_module(source_code)
62-
63-
# Wrap the module with a metadata wrapper to enable name resolution
64-
wrapper = MetadataWrapper(module)
65-
66-
# Create an instance of the visitor and visit the wrapped module
67-
visitor = FunctionNameCollector()
68-
wrapper.visit(visitor)
69-
70-
return list(visitor.qualified_names)
71-
72-
7324
class GlobalAssignmentCollector(cst.CSTVisitor):
7425
"""Collects all global assignment statements."""
7526

codeflash/code_utils/code_replacer.py

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import ast
4-
import contextlib
54
from collections import defaultdict
65
from functools import lru_cache
76
from typing import TYPE_CHECKING, Optional, TypeVar
@@ -11,13 +10,9 @@
1110
import libcst.matchers as m
1211

1312
from codeflash.cli_cmds.console import logger
14-
from codeflash.code_utils.code_extractor import (
15-
add_global_assignments,
16-
add_needed_imports_from_module,
17-
get_function_qualified_names,
18-
)
13+
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
1914
from codeflash.code_utils.config_parser import find_conftest_files
20-
from codeflash.code_utils.line_profile_utils import ImportAdder, add_decorator_to_qualified_function
15+
from codeflash.code_utils.line_profile_utils import ImportAdder
2116
from codeflash.models.models import FunctionParent
2217

2318
if TYPE_CHECKING:
@@ -42,6 +37,74 @@ def normalize_code(code: str) -> str:
4237
return ast.unparse(normalize_node(ast.parse(code)))
4338

4439

40+
class PytestMarkAdder(cst.CSTTransformer):
41+
"""Transformer that adds pytest marks to test functions."""
42+
43+
def __init__(self, mark_name: str) -> None:
44+
super().__init__()
45+
self.mark_name = mark_name
46+
self.has_pytest_import = False
47+
48+
def visit_Module(self, node: cst.Module) -> None:
49+
"""Check if pytest is already imported."""
50+
for statement in node.body:
51+
if isinstance(statement, cst.SimpleStatementLine):
52+
for stmt in statement.body:
53+
if isinstance(stmt, cst.Import):
54+
for import_alias in stmt.names:
55+
if isinstance(import_alias, cst.ImportAlias) and import_alias.name.value == "pytest":
56+
self.has_pytest_import = True
57+
elif isinstance(stmt, cst.ImportFrom) and stmt.module and stmt.module.value == "pytest":
58+
self.has_pytest_import = True
59+
60+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: # noqa: ARG002
61+
"""Add pytest import if not present."""
62+
if not self.has_pytest_import:
63+
# Create import statement
64+
import_stmt = cst.SimpleStatementLine(body=[cst.Import(names=[cst.ImportAlias(name=cst.Name("pytest"))])])
65+
# Add import at the beginning
66+
updated_node = updated_node.with_changes(body=[import_stmt, *updated_node.body])
67+
return updated_node
68+
69+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
70+
"""Add pytest mark to test functions."""
71+
# Check if the mark already exists
72+
for decorator in updated_node.decorators:
73+
if self._is_pytest_mark(decorator.decorator, self.mark_name):
74+
return updated_node
75+
76+
# Create the pytest mark decorator
77+
mark_decorator = self._create_pytest_mark()
78+
79+
# Add the decorator
80+
new_decorators = [*list(updated_node.decorators), mark_decorator]
81+
return updated_node.with_changes(decorators=new_decorators)
82+
83+
def _is_pytest_mark(self, decorator: cst.BaseExpression, mark_name: str) -> bool:
84+
"""Check if a decorator is a specific pytest mark."""
85+
if isinstance(decorator, cst.Attribute):
86+
if (
87+
isinstance(decorator.value, cst.Attribute)
88+
and isinstance(decorator.value.value, cst.Name)
89+
and decorator.value.value.value == "pytest"
90+
and decorator.value.attr.value == "mark"
91+
and decorator.attr.value == mark_name
92+
):
93+
return True
94+
elif isinstance(decorator, cst.Call) and isinstance(decorator.func, cst.Attribute):
95+
return self._is_pytest_mark(decorator.func, mark_name)
96+
return False
97+
98+
def _create_pytest_mark(self) -> cst.Decorator:
99+
"""Create a pytest mark decorator."""
100+
# Base: pytest.mark.{mark_name}
101+
mark_attr = cst.Attribute(
102+
value=cst.Attribute(value=cst.Name("pytest"), attr=cst.Name("mark")), attr=cst.Name(self.mark_name)
103+
)
104+
decorator = mark_attr
105+
return cst.Decorator(decorator=decorator)
106+
107+
45108
class AutouseFixtureModifier(cst.CSTTransformer):
46109
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
47110
# Matcher for '@fixture' or '@pytest.fixture'
@@ -63,7 +126,7 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
63126
else_block = cst.Else(body=updated_node.body)
64127

65128
# 2. Create the new 'if' block that will exit the fixture early.
66-
if_test = cst.parse_expression('request.node.get_closest_marker("no_autouse")')
129+
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
67130
yield_statement = cst.parse_statement("yield")
68131
if_body = cst.IndentedBlock(body=[yield_statement])
69132

@@ -75,7 +138,6 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
75138
return updated_node
76139

77140

78-
@contextlib.contextmanager
79141
def disable_autouse(test_path: Path) -> str:
80142
file_content = test_path.read_text(encoding="utf-8")
81143
module = cst.parse_module(file_content)
@@ -107,13 +169,9 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
107169
module = cst.parse_module(file_content)
108170
importadder = ImportAdder("import pytest")
109171
modified_module = module.visit(importadder)
110-
modified_module = isort.code(modified_module.code, float_to_top=True)
111-
qualified_fn_names = get_function_qualified_names(test_path)
112-
for fn_name in qualified_fn_names:
113-
modified_module = add_decorator_to_qualified_function(
114-
modified_module, fn_name, "pytest.mark.codeflash_no_autouse"
115-
)
116-
# write the modified module back to the file
172+
modified_module = cst.parse_module(isort.code(modified_module.code, float_to_top=True))
173+
pytest_mark_adder = PytestMarkAdder("codeflash_no_autouse")
174+
modified_module = modified_module.visit(pytest_mark_adder)
117175
test_path.write_text(modified_module.code, encoding="utf-8")
118176

119177

codeflash/optimization/function_optimizer.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,15 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
213213
for key in set(self.function_to_tests) | set(function_to_concolic_tests)
214214
}
215215
instrumented_unittests_created_for_function = self.instrument_existing_tests(function_to_all_tests)
216-
# logger.debug("disabling all autouse fixtures associated with the test files")
217-
original_conftest_content = modify_autouse_fixture(list(instrumented_unittests_created_for_function))
216+
logger.debug("disabling all autouse fixtures associated with the test files")
217+
original_conftest_content = modify_autouse_fixture(
218+
generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function)
219+
)
220+
logger.debug("add custom marker to all tests")
221+
add_custom_marker_to_all_tests(
222+
generated_test_paths + generated_perf_test_paths + list(instrumented_unittests_created_for_function)
223+
)
224+
218225
# Get a dict of file_path_to_classes of fto and helpers_of_fto
219226
file_path_to_helper_classes = defaultdict(set)
220227
for function_source in code_context.helper_functions:
@@ -750,8 +757,6 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi
750757
f"{concolic_coverage_test_files_count} concolic coverage test file"
751758
f"{'s' if concolic_coverage_test_files_count != 1 else ''} for {func_qualname}"
752759
)
753-
logger.debug("add custom marker to all tests")
754-
add_custom_marker_to_all_tests(list(unique_instrumented_test_files))
755760
return unique_instrumented_test_files
756761

757762
def generate_tests_and_optimizations(

0 commit comments

Comments
 (0)