11# ruff: noqa: SLF001
22from __future__ import annotations
33
4+ import ast
45import hashlib
56import os
67import pickle
1213from pathlib import Path
1314from typing import TYPE_CHECKING , Callable , Optional
1415
16+ if TYPE_CHECKING :
17+ from codeflash .discovery .functions_to_optimize import FunctionToOptimize
18+
1519import pytest
1620from pydantic .dataclasses import dataclass
1721from rich .panel import Panel
@@ -137,20 +141,175 @@ def close(self) -> None:
137141 self .connection .close ()
138142
139143
class ImportAnalyzer(ast.NodeVisitor):
    """AST visitor that records a test file's imports and references to target functions.

    After ``visit`` has been called on a parsed module, the results are available as:

    * ``imported_names``: names bound by import statements (the alias when ``as`` is
      used), plus ``"*"`` when a star import is present.
    * ``imported_modules``: module names seen in ``import``/``from ... import``
      statements and in literal ``__import__``/``importlib.import_module`` calls.
      For ``import pkg as p`` both ``"pkg"`` and ``"p"`` are recorded.
    * ``found_target_functions``: the subset of ``function_names_to_find`` that the
      file imports or references by name or attribute.
    """

    def __init__(self, function_names_to_find: set[str]) -> None:
        self.function_names_to_find = function_names_to_find
        self.imported_names: set[str] = set()
        self.imported_modules: set[str] = set()
        self.found_target_functions: set[str] = set()

    def visit_Import(self, node: ast.Import) -> None:
        """Handle 'import module' and 'import module as alias' statements."""
        for alias in node.names:
            bound_name = alias.asname or alias.name
            # Always record the real module path: with ``import pkg as p`` the alias
            # alone would hide ``pkg`` from any module-name based matching.
            self.imported_modules.add(alias.name)
            # The alias is still recorded so earlier alias-based matches keep working.
            self.imported_modules.add(bound_name)
            self.imported_names.add(bound_name)
        self.generic_visit(node)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Handle 'from module import name' statements."""
        if node.module:
            self.imported_modules.add(node.module)

        for alias in node.names:
            if alias.name == "*":
                # Star imports - we can't know what's imported, so be conservative
                self.imported_names.add("*")
                continue

            self.imported_names.add(alias.asname or alias.name)

            # The original (un-aliased) name is what identifies the function.
            if alias.name in self.function_names_to_find:
                self.found_target_functions.add(alias.name)

            # Targets may also be given as module-qualified names ("pkg.mod.func").
            if node.module:
                qualified_name = f"{node.module}.{alias.name}"
                if qualified_name in self.function_names_to_find:
                    self.found_target_functions.add(qualified_name)
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        """Handle dynamic imports like importlib.import_module() or __import__()."""
        func = node.func
        is_dunder_import = isinstance(func, ast.Name) and func.id == "__import__"
        is_import_module = (
            isinstance(func, ast.Attribute)
            and isinstance(func.value, ast.Name)
            and func.value.id == "importlib"
            and func.attr == "import_module"
        )
        if (is_dunder_import or is_import_module) and node.args:
            first_arg = node.args[0]
            # Only literal string arguments can be resolved statically.
            if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
                self.imported_modules.add(first_arg.value)
        self.generic_visit(node)

    def visit_Name(self, node: ast.Name) -> None:
        """Check if any name usage matches our target functions."""
        if node.id in self.function_names_to_find:
            self.found_target_functions.add(node.id)
        self.generic_visit(node)

    def visit_Attribute(self, node: ast.Attribute) -> None:
        """Handle module.function_name patterns."""
        if node.attr in self.function_names_to_find:
            self.found_target_functions.add(node.attr)
        self.generic_visit(node)
206+
207+
def analyze_imports_in_test_file(test_file_path: Path, target_functions: set[str]) -> tuple[bool, set[str]]:
    """Analyze imports in a test file to determine if it might test any target functions.

    The checks run from most to least specific: names the file imports or
    references directly, star imports, imported names that equal a target, and
    finally a module-name heuristic.

    Args:
        test_file_path: Path to the test file
        target_functions: Set of function names we're looking for

    Returns:
        Tuple of (should_process_with_jedi, found_function_names). When the file
        cannot be read or parsed, or contains a star import, the result is
        ``(True, set())`` so that the file is still processed.

    """
    try:
        content = test_file_path.read_text(encoding="utf-8")
        tree = ast.parse(content, filename=str(test_file_path))
    except (SyntaxError, ValueError, OSError) as e:
        # ValueError covers UnicodeDecodeError (its subclass) and the error raised by
        # older ast.parse versions for source containing null bytes.
        logger.debug(f"Failed to analyze imports in {test_file_path} : {e} ")
        # If we can't parse the file, be conservative and process it
        return True, set()

    analyzer = ImportAnalyzer(target_functions)
    analyzer.visit(tree)

    # If we found direct function matches, definitely process
    if analyzer.found_target_functions:
        return True, analyzer.found_target_functions

    # If there are star imports, we need to be conservative
    if "*" in analyzer.imported_names:
        return True, set()

    # Imported names (including aliases) that equal a target name
    name_matches = analyzer.imported_names & target_functions
    if name_matches:
        return True, name_matches

    # Last resort: a heuristic on the names of the imported modules.
    potential_matches: set[str] = set()
    for func_name in target_functions:
        dotted_parts = func_name.split(".")
        # Only qualified names ("mycode.target_function") carry a top-level package.
        root_package = dotted_parts[0] if len(dotted_parts) > 1 else None
        # Drop empty fragments produced by leading/trailing/double underscores:
        # "" is a substring of every module name and would match everything.
        word_parts = [part for part in func_name.split("_") if part] if "_" in func_name else []

        for module in analyzer.imported_modules:
            if module == root_package or any(part in module for part in word_parts):
                potential_matches.add(func_name)
                break

    return bool(potential_matches), potential_matches
263+
264+
def filter_test_files_by_imports(
    file_to_test_map: dict[Path, list[TestsInFile]], target_functions: set[str]
) -> tuple[dict[Path, list[TestsInFile]], dict[Path, set[str]]]:
    """Drop test files whose imports show no relation to the target functions.

    Each file is passed to ``analyze_imports_in_test_file``; only files it marks
    for processing are kept, which reduces the work handed to Jedi.

    Args:
        file_to_test_map: Original mapping of test files to test functions
        target_functions: Set of function names we're optimizing

    Returns:
        Tuple of (filtered_file_map, import_analysis_results). The analysis
        results map every analyzed file to the function names found for it.
        With no target functions the input mapping is returned as-is together
        with an empty results dict.

    """
    # Nothing to filter against: keep every file.
    if not target_functions:
        return file_to_test_map, {}

    kept_files = {}
    analysis_by_file = {}

    for test_file, tests_in_file in file_to_test_map.items():
        keep_file, matched_names = analyze_imports_in_test_file(test_file, target_functions)
        analysis_by_file[test_file] = matched_names

        if not keep_file:
            logger.debug(f"Skipping {test_file} - no relevant imports found")
            continue
        kept_files[test_file] = tests_in_file

    logger.info(f"Import filter: Processing {len(kept_files)} /{len(file_to_test_map)} test files")
    return kept_files, analysis_by_file
297+
298+
def discover_unit_tests(
    cfg: TestConfig,
    discover_only_these_tests: list[Path] | None = None,
    functions_to_optimize: list[FunctionToOptimize] | None = None,
) -> dict[str, list[FunctionCalledInTest]]:
    """Discover tests using the strategy that matches the configured test framework.

    Args:
        cfg: Test configuration; ``cfg.test_framework`` selects the strategy.
        discover_only_these_tests: Optional list forwarded unchanged to the strategy.
        functions_to_optimize: Optional list forwarded unchanged to the strategy.

    Returns:
        Whatever the selected discovery strategy returns.

    Raises:
        ValueError: If ``cfg.test_framework`` is neither "pytest" nor "unittest".

    """
    framework_strategies: dict[str, Callable] = {
        "pytest": discover_tests_pytest,
        "unittest": discover_tests_unittest,
    }
    strategy = framework_strategies.get(cfg.test_framework)
    if strategy is None:
        error_message = f"Unsupported test framework: {cfg.test_framework} "
        raise ValueError(error_message)

    return strategy(cfg, discover_only_these_tests, functions_to_optimize)
150309
151310
152311def discover_tests_pytest (
153- cfg : TestConfig , discover_only_these_tests : list [Path ] | None = None
312+ cfg : TestConfig , discover_only_these_tests : list [Path ] | None = None , functions_to_optimize : list [ FunctionToOptimize ] | None = None
154313) -> dict [Path , list [FunctionCalledInTest ]]:
155314 tests_root = cfg .tests_root
156315 project_root = cfg .project_root_path
@@ -220,11 +379,11 @@ def discover_tests_pytest(
220379 continue
221380 file_to_test_map [test_obj .test_file ].append (test_obj )
222381 # Within these test files, find the project functions they are referring to and return their names/locations
223- return process_test_files (file_to_test_map , cfg )
382+ return process_test_files (file_to_test_map , cfg , functions_to_optimize )
224383
225384
226385def discover_tests_unittest (
227- cfg : TestConfig , discover_only_these_tests : list [str ] | None = None
386+ cfg : TestConfig , discover_only_these_tests : list [str ] | None = None , functions_to_optimize : list [ FunctionToOptimize ] | None = None
228387) -> dict [Path , list [FunctionCalledInTest ]]:
229388 tests_root : Path = cfg .tests_root
230389 loader : unittest .TestLoader = unittest .TestLoader ()
@@ -277,7 +436,7 @@ def get_test_details(_test: unittest.TestCase) -> TestsInFile | None:
277436 details = get_test_details (test )
278437 if details is not None :
279438 file_to_test_map [str (details .test_file )].append (details )
280- return process_test_files (file_to_test_map , cfg )
439+ return process_test_files (file_to_test_map , cfg , functions_to_optimize )
281440
282441
283442def discover_parameters_unittest (function_name : str ) -> tuple [bool , str , str | None ]:
@@ -289,13 +448,29 @@ def discover_parameters_unittest(function_name: str) -> tuple[bool, str, str | N
289448
290449
291450def process_test_files (
292- file_to_test_map : dict [Path , list [TestsInFile ]], cfg : TestConfig
451+ file_to_test_map : dict [Path , list [TestsInFile ]], cfg : TestConfig , functions_to_optimize : list [ FunctionToOptimize ] | None = None
293452) -> dict [str , list [FunctionCalledInTest ]]:
294453 import jedi
295454
296455 project_root_path = cfg .project_root_path
297456 test_framework = cfg .test_framework
298457
458+ # Apply import filter if functions to optimize are provided
459+ if functions_to_optimize :
460+ # Extract target function names from FunctionToOptimize objects
461+ # Include both qualified names and simple function names for better matching
462+ target_function_names = set ()
463+ for func in functions_to_optimize :
464+ target_function_names .add (func .qualified_name_with_modules_from_root (project_root_path ))
465+ target_function_names .add (func .function_name ) # Add simple name too
466+ # Also add qualified name without module
467+ if func .parents :
468+ target_function_names .add (f"{ func .parents [0 ].name } .{ func .function_name } " )
469+
470+ logger .debug (f"Target functions for import filtering: { target_function_names } " )
471+ file_to_test_map , import_results = filter_test_files_by_imports (file_to_test_map , target_function_names )
472+ logger .debug (f"Import analysis results: { len (import_results )} files analyzed" )
473+
299474 function_to_test_map = defaultdict (set )
300475 jedi_project = jedi .Project (path = project_root_path )
301476 goto_cache = {}
0 commit comments