@@ -212,6 +212,8 @@ def __init__(self, function_names_to_find: set[str]) -> None:
212212 self .wildcard_modules : set [str ] = set ()
213213 # Track aliases: alias_name -> original_name
214214 self .alias_mapping : dict [str , str ] = {}
215+ # Track instances: variable_name -> class_name
216+ self .instance_mapping : dict [str , str ] = {}
215217
216218 # Precompute function_names for prefix search
217219 # For prefix match, store mapping from prefix-root to candidates for O(1) matching
@@ -247,6 +249,24 @@ def visit_Import(self, node: ast.Import) -> None:
247249 self .found_qualified_name = target_func
248250 return
249251
252+ def visit_Assign (self , node : ast .Assign ) -> None :
253+ """Track variable assignments, especially class instantiations."""
254+ if self .found_any_target_function :
255+ return
256+
257+ # Check if the assignment is a class instantiation
258+ if isinstance (node .value , ast .Call ) and isinstance (node .value .func , ast .Name ):
259+ class_name = node .value .func .id
260+ if class_name in self .imported_modules :
261+ # Track all target variables as instances of the imported class
262+ for target in node .targets :
263+ if isinstance (target , ast .Name ):
264+ # Map the variable to the actual class name (handling aliases)
265+ original_class = self .alias_mapping .get (class_name , class_name )
266+ self .instance_mapping [target .id ] = original_class
267+
268+ self .generic_visit (node )
269+
250270 def visit_ImportFrom (self , node : ast .ImportFrom ) -> None :
251271 """Handle 'from module import name' statements."""
252272 if self .found_any_target_function :
@@ -296,17 +316,6 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
296316 self .found_qualified_name = target_func
297317 return
298318
299- # Check if any target function is a method of the imported class/module
300- # e.g., importing Graph and looking for Graph.topologicalSort
301- # TODO will pick up all tests which have the same class name (could be coming from a different file)
302- for target_func in fnames :
303- if "." in target_func :
304- class_name , method_name = target_func .split ("." , 1 )
305- if aname == class_name :
306- self .found_any_target_function = True
307- self .found_qualified_name = target_func
308- return
309-
310319 def visit_Attribute (self , node : ast .Attribute ) -> None :
311320 """Handle attribute access like module.function_name."""
312321 if self .found_any_target_function :
@@ -334,6 +343,17 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
334343 self .found_qualified_name = target_func
335344 return
336345
346+ # Check if this is accessing a method on an instance variable
347+ if isinstance (node .value , ast .Name ) and node .value .id in self .instance_mapping :
348+ class_name = self .instance_mapping [node .value .id ]
349+ for target_func in self .function_names_to_find :
350+ if "." in target_func :
351+ target_class , method_name = target_func .rsplit ("." , 1 )
352+ if node .attr == method_name and class_name == target_class :
353+ self .found_any_target_function = True
354+ self .found_qualified_name = target_func
355+ return
356+
337357 # Check if this is accessing a target function through a dynamically imported module
338358 # Only if we've detected dynamic imports are being used
339359 if self .has_dynamic_imports and node .attr in self .function_names_to_find :
0 commit comments