Skip to content

Commit c5c9533

Browse files
committed
some progress
1 parent 83f4354 commit c5c9533

File tree

6 files changed

+144
-28
lines changed

6 files changed

+144
-28
lines changed
Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,17 @@
11
import pytest
22
import time
33

4-
# @pytest.fixture(autouse=True)
5-
# def fixture(request):
6-
# if request.node.get_closest_marker("no_autouse"):
7-
# # Skip the fixture logic
8-
# yield
9-
# else:
10-
# start_time = time.time()
11-
# time.sleep(0.1)
12-
# yield
13-
# print(f"Took {time.time() - start_time} seconds")
14-
154

165
@pytest.fixture(autouse=True)
17-
def fixture1(request): # We don't need this fixture during testing
6+
def fixture1(request):
187
start_time = time.time()
198
time.sleep(0.1)
209
yield
2110
print(f"Took {time.time() - start_time} seconds")
2211

2312

2413
@pytest.fixture(autouse=True)
25-
def fixture2(request): # We need it
14+
def fixture2(request): # We don't need this fixture during testing
15+
print("not doing anything")
2616
yield
17+
print("did nothing")

code_to_optimize/tests/pytest/test_bubble_sort.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from code_to_optimize.bubble_sort import sorter
2-
import pytest
32

4-
@pytest.mark.no_autouse
3+
54
def test_sort():
65
input = [5, 4, 3, 2, 1, 0]
76
output = sorter(input)

codeflash/cli_cmds/cmd_init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from argparse import Namespace
3535

3636
CODEFLASH_LOGO: str = (
37-
f"{LF}"
37+
f"{LF}" # noqa : ISC003
3838
r" _ ___ _ _ " + f"{LF}"
3939
r" | | / __)| | | | " + f"{LF}"
4040
r" ____ ___ _ | | ____ | |__ | | ____ ___ | | _ " + f"{LF}"

codeflash/code_utils/code_extractor.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from typing import TYPE_CHECKING, Optional
66

77
import libcst as cst
8+
from libcst import MetadataWrapper
89
from libcst.codemod import CodemodContext
910
from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor
1011
from libcst.helpers import calculate_module_and_package
12+
from libcst.metadata import FullyQualifiedNameProvider
1113

1214
from codeflash.cli_cmds.console import logger
1315
from codeflash.models.models import FunctionParent
@@ -21,6 +23,53 @@
2123
from codeflash.models.models import FunctionSource
2224

2325

26+
class FunctionNameCollector(cst.CSTVisitor):
27+
"""A LibCST visitor that collects the fully qualified names of all functions."""
28+
29+
METADATA_DEPENDENCIES = (FullyQualifiedNameProvider,)
30+
31+
def __init__(self) -> None:
32+
super().__init__()
33+
self.qualified_names: set[str] = set()
34+
35+
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
36+
"""Visits a function definition node and extracts its qualified name."""
37+
try:
38+
q_names = self.get_metadata(FullyQualifiedNameProvider, node)
39+
for q_name in q_names:
40+
self.qualified_names.add(q_name.name)
41+
except KeyError:
42+
# This can happen for functions defined in scopes where a qualified
43+
# name cannot be determined.
44+
pass
45+
46+
47+
def get_function_qualified_names(file_path: Path) -> list[str]:
48+
"""Parse a Python file and returns a list of fully qualified function names.
49+
50+
Args:
51+
file_path: The path to the Python file.
52+
53+
Returns:
54+
A list of string representations of the qualified function names.
55+
56+
"""
57+
with file_path.open("r") as f:
58+
source_code = f.read()
59+
60+
# Parse the source code into a CST
61+
module = cst.parse_module(source_code)
62+
63+
# Wrap the module with a metadata wrapper to enable name resolution
64+
wrapper = MetadataWrapper(module)
65+
66+
# Create an instance of the visitor and visit the wrapped module
67+
visitor = FunctionNameCollector()
68+
wrapper.visit(visitor)
69+
70+
return list(visitor.qualified_names)
71+
72+
2473
class GlobalAssignmentCollector(cst.CSTVisitor):
2574
"""Collects all global assignment statements."""
2675

codeflash/code_utils/code_replacer.py

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
11
from __future__ import annotations
22

33
import ast
4+
import contextlib
45
from collections import defaultdict
56
from functools import lru_cache
67
from typing import TYPE_CHECKING, Optional, TypeVar
78

9+
import isort
810
import libcst as cst
11+
import libcst.matchers as m
912

1013
from codeflash.cli_cmds.console import logger
11-
from codeflash.code_utils.code_extractor import add_global_assignments, add_needed_imports_from_module
14+
from codeflash.code_utils.code_extractor import (
15+
add_global_assignments,
16+
add_needed_imports_from_module,
17+
get_function_qualified_names,
18+
)
19+
from codeflash.code_utils.config_parser import find_conftest_files
20+
from codeflash.code_utils.line_profile_utils import ImportAdder, add_decorator_to_qualified_function
1221
from codeflash.models.models import FunctionParent
1322

1423
if TYPE_CHECKING:
@@ -33,18 +42,78 @@ def normalize_code(code: str) -> str:
3342
return ast.unparse(normalize_node(ast.parse(code)))
3443

3544

36-
# def modify_autouse_fixture():
37-
# # find fixutre definition in conftetst.py (the one closest to the test)
38-
# # get fixtures present in override-fixtures in pyproject.toml
39-
# # add if marker closest return
40-
# conftest_files = find_conftest_files()
41-
# for cf_file in conftest_files:
42-
# # iterate over all functions in the file
43-
# # if function has autouse fixture, modify function to bypass with custom marker
44-
# pass
45+
class AutouseFixtureModifier(cst.CSTTransformer):
46+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
47+
# Matcher for '@fixture' or '@pytest.fixture'
48+
fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture"))
49+
50+
for decorator in original_node.decorators:
51+
if m.matches(
52+
decorator,
53+
m.Decorator(
54+
decorator=m.Call(
55+
func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))]
56+
)
57+
),
58+
):
59+
# Found a matching fixture with autouse=True
60+
61+
# 1. The original body of the function will become the 'else' block.
62+
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
63+
else_block = cst.Else(body=updated_node.body)
4564

65+
# 2. Create the new 'if' block that will exit the fixture early.
66+
if_test = cst.parse_expression('request.node.get_closest_marker("no_autouse")')
67+
yield_statement = cst.parse_statement("yield")
68+
if_body = cst.IndentedBlock(body=[yield_statement])
4669

47-
# reuse line profiler utils to add decorator and import to test fns
70+
# 3. Construct the full if/else statement.
71+
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)
72+
73+
# 4. Replace the entire function's body with our new single statement.
74+
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
75+
return updated_node
76+
77+
78+
@contextlib.contextmanager
79+
def disable_autouse(test_path: Path) -> None:
80+
file_content = test_path.read_text(encoding="utf-8")
81+
try:
82+
module = cst.parse_module(file_content)
83+
disable_autouse_fixture = AutouseFixtureModifier()
84+
modified_module = module.visit(disable_autouse_fixture)
85+
test_path.write_text(modified_module.code, encoding="utf-8")
86+
finally:
87+
test_path.write_text(file_content, encoding="utf-8")
88+
89+
90+
def modify_autouse_fixture(test_paths: list[Path]) -> None:
91+
# find fixutre definition in conftetst.py (the one closest to the test)
92+
# get fixtures present in override-fixtures in pyproject.toml
93+
# add if marker closest return
94+
conftest_files = find_conftest_files(test_paths)
95+
for cf_file in conftest_files:
96+
# iterate over all functions in the file
97+
# if function has autouse fixture, modify function to bypass with custom marker
98+
disable_autouse(cf_file)
99+
100+
101+
# # reuse line profiler utils to add decorator and import to test fns
102+
def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
103+
for test_path in test_paths:
104+
# read file
105+
file_content = test_path.read_text(encoding="utf-8")
106+
module = cst.parse_module(file_content)
107+
importadder = ImportAdder("import pytest")
108+
modified_module = module.visit(importadder)
109+
modified_module = isort.code(modified_module.code, float_to_top=True)
110+
qualified_fn_names = get_function_qualified_names(test_path)
111+
for fn_name in qualified_fn_names:
112+
modified_module = add_decorator_to_qualified_function(
113+
modified_module, fn_name, "pytest.mark.codeflash_no_autouse"
114+
)
115+
# write the modified module back to the file
116+
test_path.write_text(modified_module.code, encoding="utf-8")
48117

49118

50119
class OptimFunctionCollector(cst.CSTVisitor):

codeflash/optimization/function_optimizer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
from codeflash.benchmarking.utils import process_benchmark_data
2222
from codeflash.cli_cmds.console import code_print, console, logger, progress_bar
2323
from codeflash.code_utils import env_utils
24-
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
24+
from codeflash.code_utils.code_replacer import (
25+
add_custom_marker_to_all_tests,
26+
modify_autouse_fixture,
27+
replace_function_definitions_in_module,
28+
)
2529
from codeflash.code_utils.code_utils import (
2630
ImportErrorPattern,
2731
cleanup_paths,
@@ -742,6 +746,10 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, list[Functi
742746
f"{concolic_coverage_test_files_count} concolic coverage test file"
743747
f"{'s' if concolic_coverage_test_files_count != 1 else ''} for {func_qualname}"
744748
)
749+
logger.debug("disabling all autouse fixtures associated with the test files")
750+
modify_autouse_fixture(list(unique_instrumented_test_files))
751+
logger.debug("add custom marker to all tests")
752+
add_custom_marker_to_all_tests(list(unique_instrumented_test_files))
745753
return unique_instrumented_test_files
746754

747755
def generate_tests_and_optimizations(

0 commit comments

Comments
 (0)