55from functools import lru_cache
66from typing import TYPE_CHECKING , Optional , TypeVar
77
8+ import isort
89import libcst as cst
10+ import libcst .matchers as m
911
1012from codeflash .cli_cmds .console import logger
1113from codeflash .code_utils .code_extractor import add_global_assignments , add_needed_imports_from_module
14+ from codeflash .code_utils .config_parser import find_conftest_files
15+ from codeflash .code_utils .line_profile_utils import ImportAdder
1216from codeflash .models .models import FunctionParent
1317
1418if TYPE_CHECKING :
@@ -33,6 +37,142 @@ def normalize_code(code: str) -> str:
3337 return ast .unparse (normalize_node (ast .parse (code )))
3438
3539
40+ class PytestMarkAdder (cst .CSTTransformer ):
41+ """Transformer that adds pytest marks to test functions."""
42+
43+ def __init__ (self , mark_name : str ) -> None :
44+ super ().__init__ ()
45+ self .mark_name = mark_name
46+ self .has_pytest_import = False
47+
48+ def visit_Module (self , node : cst .Module ) -> None :
49+ """Check if pytest is already imported."""
50+ for statement in node .body :
51+ if isinstance (statement , cst .SimpleStatementLine ):
52+ for stmt in statement .body :
53+ if isinstance (stmt , cst .Import ):
54+ for import_alias in stmt .names :
55+ if isinstance (import_alias , cst .ImportAlias ) and import_alias .name .value == "pytest" :
56+ self .has_pytest_import = True
57+
58+ def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module : # noqa: ARG002
59+ """Add pytest import if not present."""
60+ if not self .has_pytest_import :
61+ # Create import statement
62+ import_stmt = cst .SimpleStatementLine (body = [cst .Import (names = [cst .ImportAlias (name = cst .Name ("pytest" ))])])
63+ # Add import at the beginning
64+ updated_node = updated_node .with_changes (body = [import_stmt , * updated_node .body ])
65+ return updated_node
66+
67+ def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef : # noqa: ARG002
68+ """Add pytest mark to test functions."""
69+ # Check if the mark already exists
70+ for decorator in updated_node .decorators :
71+ if self ._is_pytest_mark (decorator .decorator , self .mark_name ):
72+ return updated_node
73+
74+ # Create the pytest mark decorator
75+ mark_decorator = self ._create_pytest_mark ()
76+
77+ # Add the decorator
78+ new_decorators = [* list (updated_node .decorators ), mark_decorator ]
79+ return updated_node .with_changes (decorators = new_decorators )
80+
81+ def _is_pytest_mark (self , decorator : cst .BaseExpression , mark_name : str ) -> bool :
82+ """Check if a decorator is a specific pytest mark."""
83+ if isinstance (decorator , cst .Attribute ):
84+ if (
85+ isinstance (decorator .value , cst .Attribute )
86+ and isinstance (decorator .value .value , cst .Name )
87+ and decorator .value .value .value == "pytest"
88+ and decorator .value .attr .value == "mark"
89+ and decorator .attr .value == mark_name
90+ ):
91+ return True
92+ elif isinstance (decorator , cst .Call ) and isinstance (decorator .func , cst .Attribute ):
93+ return self ._is_pytest_mark (decorator .func , mark_name )
94+ return False
95+
96+ def _create_pytest_mark (self ) -> cst .Decorator :
97+ """Create a pytest mark decorator."""
98+ # Base: pytest.mark.{mark_name}
99+ mark_attr = cst .Attribute (
100+ value = cst .Attribute (value = cst .Name ("pytest" ), attr = cst .Name ("mark" )), attr = cst .Name (self .mark_name )
101+ )
102+ decorator = mark_attr
103+ return cst .Decorator (decorator = decorator )
104+
105+
106+ class AutouseFixtureModifier (cst .CSTTransformer ):
107+ def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef :
108+ # Matcher for '@fixture' or '@pytest.fixture'
109+ fixture_decorator_func = m .Name ("fixture" ) | m .Attribute (value = m .Name ("pytest" ), attr = m .Name ("fixture" ))
110+
111+ for decorator in original_node .decorators :
112+ if m .matches (
113+ decorator ,
114+ m .Decorator (
115+ decorator = m .Call (
116+ func = fixture_decorator_func , args = [m .Arg (value = m .Name ("True" ), keyword = m .Name ("autouse" ))]
117+ )
118+ ),
119+ ):
120+ # Found a matching fixture with autouse=True
121+
122+ # 1. The original body of the function will become the 'else' block.
123+ # updated_node.body is an IndentedBlock, which is what cst.Else expects.
124+ else_block = cst .Else (body = updated_node .body )
125+
126+ # 2. Create the new 'if' block that will exit the fixture early.
127+ if_test = cst .parse_expression ('request.node.get_closest_marker("codeflash_no_autouse")' )
128+ yield_statement = cst .parse_statement ("yield" )
129+ if_body = cst .IndentedBlock (body = [yield_statement ])
130+
131+ # 3. Construct the full if/else statement.
132+ new_if_statement = cst .If (test = if_test , body = if_body , orelse = else_block )
133+
134+ # 4. Replace the entire function's body with our new single statement.
135+ return updated_node .with_changes (body = cst .IndentedBlock (body = [new_if_statement ]))
136+ return updated_node
137+
138+
139+ def disable_autouse (test_path : Path ) -> str :
140+ file_content = test_path .read_text (encoding = "utf-8" )
141+ module = cst .parse_module (file_content )
142+ disable_autouse_fixture = AutouseFixtureModifier ()
143+ modified_module = module .visit (disable_autouse_fixture )
144+ test_path .write_text (modified_module .code , encoding = "utf-8" )
145+ return file_content
146+
147+
148+ def modify_autouse_fixture (test_paths : list [Path ]) -> dict [Path , list [str ]]:
149+ # find fixutre definition in conftetst.py (the one closest to the test)
150+ # get fixtures present in override-fixtures in pyproject.toml
151+ # add if marker closest return
152+ file_content_map = {}
153+ conftest_files = find_conftest_files (test_paths )
154+ for cf_file in conftest_files :
155+ # iterate over all functions in the file
156+ # if function has autouse fixture, modify function to bypass with custom marker
157+ original_content = disable_autouse (cf_file )
158+ file_content_map [cf_file ] = original_content
159+ return file_content_map
160+
161+
162+ # # reuse line profiler utils to add decorator and import to test fns
163+ def add_custom_marker_to_all_tests (test_paths : list [Path ]) -> None :
164+ for test_path in test_paths :
165+ # read file
166+ file_content = test_path .read_text (encoding = "utf-8" )
167+ module = cst .parse_module (file_content )
168+ importadder = ImportAdder ("import pytest" )
169+ modified_module = module .visit (importadder )
170+ modified_module = cst .parse_module (isort .code (modified_module .code , float_to_top = True ))
171+ pytest_mark_adder = PytestMarkAdder ("codeflash_no_autouse" )
172+ modified_module = modified_module .visit (pytest_mark_adder )
173+ test_path .write_text (modified_module .code , encoding = "utf-8" )
174+
175+
36176class OptimFunctionCollector (cst .CSTVisitor ):
37177 METADATA_DEPENDENCIES = (cst .metadata .ParentNodeProvider ,)
38178
0 commit comments