11from __future__ import annotations
22
33import ast
4- import contextlib
54from collections import defaultdict
65from functools import lru_cache
76from typing import TYPE_CHECKING , Optional , TypeVar
1110import libcst .matchers as m
1211
1312from codeflash .cli_cmds .console import logger
14- from codeflash .code_utils .code_extractor import (
15- add_global_assignments ,
16- add_needed_imports_from_module ,
17- get_function_qualified_names ,
18- )
13+ from codeflash .code_utils .code_extractor import add_global_assignments , add_needed_imports_from_module
1914from codeflash .code_utils .config_parser import find_conftest_files
20- from codeflash .code_utils .line_profile_utils import ImportAdder , add_decorator_to_qualified_function
15+ from codeflash .code_utils .line_profile_utils import ImportAdder
2116from codeflash .models .models import FunctionParent
2217
2318if TYPE_CHECKING :
@@ -42,6 +37,74 @@ def normalize_code(code: str) -> str:
4237 return ast .unparse (normalize_node (ast .parse (code )))
4338
4439
40+ class PytestMarkAdder (cst .CSTTransformer ):
41+ """Transformer that adds pytest marks to test functions."""
42+
43+ def __init__ (self , mark_name : str ) -> None :
44+ super ().__init__ ()
45+ self .mark_name = mark_name
46+ self .has_pytest_import = False
47+
48+ def visit_Module (self , node : cst .Module ) -> None :
49+ """Check if pytest is already imported."""
50+ for statement in node .body :
51+ if isinstance (statement , cst .SimpleStatementLine ):
52+ for stmt in statement .body :
53+ if isinstance (stmt , cst .Import ):
54+ for import_alias in stmt .names :
55+ if isinstance (import_alias , cst .ImportAlias ) and import_alias .name .value == "pytest" :
56+ self .has_pytest_import = True
57+ elif isinstance (stmt , cst .ImportFrom ) and stmt .module and stmt .module .value == "pytest" :
58+ self .has_pytest_import = True
59+
60+ def leave_Module (self , original_node : cst .Module , updated_node : cst .Module ) -> cst .Module : # noqa: ARG002
61+ """Add pytest import if not present."""
62+ if not self .has_pytest_import :
63+ # Create import statement
64+ import_stmt = cst .SimpleStatementLine (body = [cst .Import (names = [cst .ImportAlias (name = cst .Name ("pytest" ))])])
65+ # Add import at the beginning
66+ updated_node = updated_node .with_changes (body = [import_stmt , * updated_node .body ])
67+ return updated_node
68+
69+ def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef : # noqa: ARG002
70+ """Add pytest mark to test functions."""
71+ # Check if the mark already exists
72+ for decorator in updated_node .decorators :
73+ if self ._is_pytest_mark (decorator .decorator , self .mark_name ):
74+ return updated_node
75+
76+ # Create the pytest mark decorator
77+ mark_decorator = self ._create_pytest_mark ()
78+
79+ # Add the decorator
80+ new_decorators = [* list (updated_node .decorators ), mark_decorator ]
81+ return updated_node .with_changes (decorators = new_decorators )
82+
83+ def _is_pytest_mark (self , decorator : cst .BaseExpression , mark_name : str ) -> bool :
84+ """Check if a decorator is a specific pytest mark."""
85+ if isinstance (decorator , cst .Attribute ):
86+ if (
87+ isinstance (decorator .value , cst .Attribute )
88+ and isinstance (decorator .value .value , cst .Name )
89+ and decorator .value .value .value == "pytest"
90+ and decorator .value .attr .value == "mark"
91+ and decorator .attr .value == mark_name
92+ ):
93+ return True
94+ elif isinstance (decorator , cst .Call ) and isinstance (decorator .func , cst .Attribute ):
95+ return self ._is_pytest_mark (decorator .func , mark_name )
96+ return False
97+
98+ def _create_pytest_mark (self ) -> cst .Decorator :
99+ """Create a pytest mark decorator."""
100+ # Base: pytest.mark.{mark_name}
101+ mark_attr = cst .Attribute (
102+ value = cst .Attribute (value = cst .Name ("pytest" ), attr = cst .Name ("mark" )), attr = cst .Name (self .mark_name )
103+ )
104+ decorator = mark_attr
105+ return cst .Decorator (decorator = decorator )
106+
107+
45108class AutouseFixtureModifier (cst .CSTTransformer ):
46109 def leave_FunctionDef (self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef ) -> cst .FunctionDef :
47110 # Matcher for '@fixture' or '@pytest.fixture'
@@ -63,7 +126,7 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
63126 else_block = cst .Else (body = updated_node .body )
64127
65128 # 2. Create the new 'if' block that will exit the fixture early.
66- if_test = cst .parse_expression ('request.node.get_closest_marker("no_autouse ")' )
129+ if_test = cst .parse_expression ('request.node.get_closest_marker("codeflash_no_autouse ")' )
67130 yield_statement = cst .parse_statement ("yield" )
68131 if_body = cst .IndentedBlock (body = [yield_statement ])
69132
@@ -75,7 +138,6 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
75138 return updated_node
76139
77140
78- @contextlib .contextmanager
79141def disable_autouse (test_path : Path ) -> str :
80142 file_content = test_path .read_text (encoding = "utf-8" )
81143 module = cst .parse_module (file_content )
@@ -107,13 +169,9 @@ def add_custom_marker_to_all_tests(test_paths: list[Path]) -> None:
107169 module = cst .parse_module (file_content )
108170 importadder = ImportAdder ("import pytest" )
109171 modified_module = module .visit (importadder )
110- modified_module = isort .code (modified_module .code , float_to_top = True )
111- qualified_fn_names = get_function_qualified_names (test_path )
112- for fn_name in qualified_fn_names :
113- modified_module = add_decorator_to_qualified_function (
114- modified_module , fn_name , "pytest.mark.codeflash_no_autouse"
115- )
116- # write the modified module back to the file
172+ modified_module = cst .parse_module (isort .code (modified_module .code , float_to_top = True ))
173+ pytest_mark_adder = PytestMarkAdder ("codeflash_no_autouse" )
174+ modified_module = modified_module .visit (pytest_mark_adder )
117175 test_path .write_text (modified_module .code , encoding = "utf-8" )
118176
119177
0 commit comments