@@ -141,144 +141,175 @@ def close(self) -> None:
141141
142142
143143class ImportAnalyzer (ast .NodeVisitor ):
144- """AST-based analyzer to find all imports in a test file."""
144+ """AST-based analyzer to check if any qualified names from function_names_to_find are imported or used in a test file."""
145145
146146 def __init__ (self , function_names_to_find : set [str ]) -> None :
147147 self .function_names_to_find = function_names_to_find
148- self .imported_names : set [str ] = set ()
148+ self .found_any_target_function : bool = False
149+ self .found_qualified_name = None
149150 self .imported_modules : set [str ] = set ()
150- self .found_target_functions : set [ str ] = set ()
151- self .qualified_names_called : set [str ] = set ()
151+ self .has_dynamic_imports : bool = False
152+ self .wildcard_modules : set [str ] = set ()
152153
153154 def visit_Import (self , node : ast .Import ) -> None :
154155 """Handle 'import module' statements."""
156+ if self .found_any_target_function :
157+ return
158+
155159 for alias in node .names :
156160 module_name = alias .asname if alias .asname else alias .name
157161 self .imported_modules .add (module_name )
158- self .imported_names .add (module_name )
159- self .generic_visit (node )
162+
163+ # Check for dynamic import modules
164+ if alias .name == "importlib" :
165+ self .has_dynamic_imports = True
166+
167+ # Check if module itself is a target qualified name
168+ if module_name in self .function_names_to_find :
169+ self .found_any_target_function = True
170+ self .found_qualified_name = module_name
171+ return
172+ # Check if any target qualified name starts with this module
173+ for target_func in self .function_names_to_find :
174+ if target_func .startswith (f"{ module_name } ." ):
175+ self .found_any_target_function = True
176+ self .found_qualified_name = target_func
177+ return
160178
161179 def visit_ImportFrom (self , node : ast .ImportFrom ) -> None :
162180 """Handle 'from module import name' statements."""
163- if node .module :
164- self .imported_modules .add (node .module )
181+ if self .found_any_target_function :
182+ return
183+
184+ if not node .module :
185+ return
165186
166187 for alias in node .names :
167188 if alias .name == "*" :
168- continue
169- imported_name = alias .asname if alias .asname else alias .name
170- self .imported_names .add (imported_name )
171- if alias .name in self .function_names_to_find :
172- self .found_target_functions .add (alias .name )
173- # Check for qualified name matches
174- if node .module :
189+ self .wildcard_modules .add (node .module )
190+ else :
191+ imported_name = alias .asname if alias .asname else alias .name
192+ self .imported_modules .add (imported_name )
193+
194+ # Check for dynamic import functions
195+ if node .module == "importlib" and alias .name == "import_module" :
196+ self .has_dynamic_imports = True
197+
198+ # Check if imported name is a target qualified name
199+ if alias .name in self .function_names_to_find :
200+ self .found_any_target_function = True
201+ self .found_qualified_name = alias .name
202+ return
203+ # Check if module.name forms a target qualified name
175204 qualified_name = f"{ node .module } .{ alias .name } "
176205 if qualified_name in self .function_names_to_find :
177- self .found_target_functions .add (qualified_name )
178- self .generic_visit (node )
206+ self .found_any_target_function = True
207+ self .found_qualified_name = qualified_name
208+ return
209+
210+ def visit_Attribute (self , node : ast .Attribute ) -> None :
211+ """Handle attribute access like module.function_name."""
212+ if self .found_any_target_function :
213+ return
179214
180- def visit_Call (self , node : ast .Call ) -> None :
181- """Handle dynamic imports like importlib.import_module() or __import__()."""
215+ # Check if this is accessing a target function through an imported module
182216 if (
183- isinstance (node .func , ast .Name )
184- and node .func .id == "__import__"
185- and node .args
186- and isinstance (node .args [0 ], ast .Constant )
187- and isinstance (node .args [0 ].value , str )
217+ isinstance (node .value , ast .Name )
218+ and node .value .id in self .imported_modules
219+ and node .attr in self .function_names_to_find
188220 ):
189- # __import__("module_name")
190- self .imported_modules .add (node .args [0 ].value )
191- elif (
192- isinstance (node .func , ast .Attribute )
193- and isinstance (node .func .value , ast .Name )
194- and node .func .value .id == "importlib"
195- and node .func .attr == "import_module"
196- and node .args
197- and isinstance (node .args [0 ], ast .Constant )
198- and isinstance (node .args [0 ].value , str )
199- ):
200- # importlib.import_module("module_name")
201- self .imported_modules .add (node .args [0 ].value )
202- self .generic_visit (node )
221+ self .found_any_target_function = True
222+ self .found_qualified_name = node .attr
223+ return
203224
204- def visit_Name (self , node : ast .Name ) -> None :
205- """Check if any name usage matches our target functions."""
206- if node .id in self .function_names_to_find :
207- self .found_target_functions .add (node .id )
208- self .generic_visit (node )
225+ # Check if this is accessing a target function through a dynamically imported module
226+ # Only if we've detected dynamic imports are being used
227+ if self .has_dynamic_imports and node .attr in self .function_names_to_find :
228+ self .found_any_target_function = True
229+ self .found_qualified_name = node .attr
230+ return
209231
210- def visit_Attribute (self , node : ast .Attribute ) -> None :
211- """Handle module.function_name patterns."""
212- if node .attr in self .function_names_to_find :
213- self .found_target_functions .add (node .attr )
214- if isinstance (node .value , ast .Name ):
215- qualified_name = f"{ node .value .id } .{ node .attr } "
216- self .qualified_names_called .add (qualified_name )
217232 self .generic_visit (node )
218233
234+ def visit_Name (self , node : ast .Name ) -> None :
235+ """Handle direct name usage like target_function()."""
236+ if self .found_any_target_function :
237+ return
219238
220- def analyze_imports_in_test_file (test_file_path : Path | str , target_functions : set [str ]) -> tuple [bool , set [str ]]:
221- """Analyze imports in a test file to determine if it might test any target functions.
239+ # Check for __import__ usage
240+ if node .id == "__import__" :
241+ self .has_dynamic_imports = True
222242
223- Args:
224- test_file_path: Path to the test file
225- target_functions: Set of function names we're looking for
243+ if node .id in self .function_names_to_find :
244+ self .found_any_target_function = True
245+ self .found_qualified_name = node .id
246+ return
247+
248+ # Check if this name could come from a wildcard import
249+ for wildcard_module in self .wildcard_modules :
250+ for target_func in self .function_names_to_find :
251+ # Check if target_func is from this wildcard module and name matches
252+ if target_func .startswith (f"{ wildcard_module } ." ) and target_func .endswith (f".{ node .id } " ):
253+ self .found_any_target_function = True
254+ self .found_qualified_name = target_func
255+ return
226256
227- Returns:
228- Tuple of (should_process_with_jedi, found_function_names)
257+ self .generic_visit (node )
229258
230- """
231- if isinstance (test_file_path , str ):
232- test_file_path = Path (test_file_path )
259+ def generic_visit (self , node : ast .AST ) -> None :
260+ """Override generic_visit to stop traversal if a target function is found."""
261+ if self .found_any_target_function :
262+ return
263+ super ().generic_visit (node )
233264
234- try :
235- with test_file_path .open ("r" , encoding = "utf-8" ) as f :
236- content = f .read ()
237265
238- tree = ast .parse (content , filename = str (test_file_path ))
266+ def analyze_imports_in_test_file (test_file_path : Path | str , target_functions : set [str ]) -> bool :
267+ """Analyze a test file to see if it imports any of the target functions."""
268+ try :
269+ with Path (test_file_path ).open ("r" , encoding = "utf-8" ) as f :
270+ source_code = f .read ()
271+ tree = ast .parse (source_code , filename = str (test_file_path ))
239272 analyzer = ImportAnalyzer (target_functions )
240273 analyzer .visit (tree )
241-
242- if analyzer .found_target_functions :
243- return True , analyzer .found_target_functions
244-
245- return False , set () # noqa: TRY300
246-
247- except (SyntaxError , UnicodeDecodeError , OSError ) as e :
274+ except (SyntaxError , FileNotFoundError ) as e :
248275 logger .debug (f"Failed to analyze imports in { test_file_path } : { e } " )
249- return True , set ()
276+ return True
277+ else :
278+ if analyzer .found_any_target_function :
279+ logger .debug (f"Test file { test_file_path } imports target function: { analyzer .found_qualified_name } " )
280+ return True
281+ logger .debug (f"Test file { test_file_path } does not import any target functions." )
282+ return False
250283
251284
252285def filter_test_files_by_imports (
253286 file_to_test_map : dict [Path , list [TestsInFile ]], target_functions : set [str ]
254- ) -> tuple [ dict [Path , list [TestsInFile ]], dict [ Path , set [ str ] ]]:
287+ ) -> dict [Path , list [TestsInFile ]]:
255288 """Filter test files based on import analysis to reduce Jedi processing.
256289
257290 Args:
258291 file_to_test_map: Original mapping of test files to test functions
259292 target_functions: Set of function names we're optimizing
260293
261294 Returns:
262- Tuple of (filtered_file_map, import_analysis_results)
295+ Filtered mapping of test files to test functions
263296
264297 """
265298 if not target_functions :
266- return file_to_test_map , {}
299+ return file_to_test_map
267300
268- filtered_map = {}
269- import_results = {}
301+ logger .debug (f"Target functions for import filtering: { target_functions } " )
270302
303+ filtered_map = {}
271304 for test_file , test_functions in file_to_test_map .items ():
272- should_process , found_functions = analyze_imports_in_test_file (test_file , target_functions )
273- import_results [test_file ] = found_functions
274-
305+ should_process = analyze_imports_in_test_file (test_file , target_functions )
275306 if should_process :
276307 filtered_map [test_file ] = test_functions
277- else :
278- logger .debug (f"Skipping { test_file } - no relevant imports found" )
279308
280- logger .debug (f"Import filter: Processing { len (filtered_map )} /{ len (file_to_test_map )} test files" )
281- return filtered_map , import_results
309+ logger .debug (
310+ f"analyzed { len (file_to_test_map )} test files for imports, filtered down to { len (filtered_map )} relevant files"
311+ )
312+ return filtered_map
282313
283314
284315def discover_unit_tests (
@@ -296,7 +327,6 @@ def discover_unit_tests(
296327 functions_to_optimize = None
297328 if file_to_funcs_to_optimize :
298329 functions_to_optimize = [func for funcs_list in file_to_funcs_to_optimize .values () for func in funcs_list ]
299-
300330 function_to_tests , num_discovered_tests = strategy (cfg , discover_only_these_tests , functions_to_optimize )
301331 return function_to_tests , num_discovered_tests
302332
@@ -455,12 +485,8 @@ def process_test_files(
455485 test_framework = cfg .test_framework
456486
457487 if functions_to_optimize :
458- target_function_names = set ()
459- for func in functions_to_optimize :
460- target_function_names .add (func .qualified_name )
461- logger .debug (f"Target functions for import filtering: { target_function_names } " )
462- file_to_test_map , import_results = filter_test_files_by_imports (file_to_test_map , target_function_names )
463- logger .debug (f"Import analysis results: { len (import_results )} files analyzed" )
488+ target_function_names = {func .qualified_name for func in functions_to_optimize }
489+ file_to_test_map = filter_test_files_by_imports (file_to_test_map , target_function_names )
464490
465491 function_to_test_map = defaultdict (set )
466492 num_discovered_tests = 0
0 commit comments