diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index 273455d0b..bf0668b83 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -219,10 +219,18 @@ def __init__(self, function_names_to_find: set[str]) -> None: # For prefix match, store mapping from prefix-root to candidates for O(1) matching self._exact_names = function_names_to_find self._prefix_roots: dict[str, list[str]] = {} + # Precompute sets for faster lookup during visit_Attribute() + self._dot_names: set[str] = set() + self._dot_methods: dict[str, set[str]] = {} + self._class_method_to_target: dict[tuple[str, str], str] = {} for name in function_names_to_find: if "." in name: - root = name.split(".", 1)[0] - self._prefix_roots.setdefault(root, []).append(name) + root, method = name.rsplit(".", 1) + self._dot_names.add(name) + self._dot_methods.setdefault(method, set()).add(root) + self._class_method_to_target[(root, method)] = name + root_prefix = name.split(".", 1)[0] + self._prefix_roots.setdefault(root_prefix, []).append(name) def visit_Import(self, node: ast.Import) -> None: """Handle 'import module' statements.""" @@ -321,44 +329,45 @@ def visit_Attribute(self, node: ast.Attribute) -> None: if self.found_any_target_function: return + # Check if this is accessing a target function through an imported module + + node_value = node.value + node_attr = node.attr + # Check if this is accessing a target function through an imported module if ( - isinstance(node.value, ast.Name) - and node.value.id in self.imported_modules - and node.attr in self.function_names_to_find + isinstance(node_value, ast.Name) + and node_value.id in self.imported_modules + and node_attr in self.function_names_to_find ): self.found_any_target_function = True - self.found_qualified_name = node.attr + self.found_qualified_name = node_attr return - if isinstance(node.value, ast.Name) and node.value.id in self.imported_modules: - for target_func in self.function_names_to_find: - if "." in target_func: - class_name, method_name = target_func.rsplit(".", 1) - if node.attr == method_name: - imported_name = node.value.id - original_name = self.alias_mapping.get(imported_name, imported_name) - if original_name == class_name: - self.found_any_target_function = True - self.found_qualified_name = target_func - return + # Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target + if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules: + roots_possible = self._dot_methods.get(node_attr) + if roots_possible: + imported_name = node_value.id + original_name = self.alias_mapping.get(imported_name, imported_name) + if original_name in roots_possible: + self.found_any_target_function = True + self.found_qualified_name = self._class_method_to_target[(original_name, node_attr)] + return # Check if this is accessing a method on an instance variable - if isinstance(node.value, ast.Name) and node.value.id in self.instance_mapping: - class_name = self.instance_mapping[node.value.id] - for target_func in self.function_names_to_find: - if "." in target_func: - target_class, method_name = target_func.rsplit(".", 1) - if node.attr == method_name and class_name == target_class: - self.found_any_target_function = True - self.found_qualified_name = target_func - return - - # Check if this is accessing a target function through a dynamically imported module - # Only if we've detected dynamic imports are being used - if self.has_dynamic_imports and node.attr in self.function_names_to_find: + if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping: + class_name = self.instance_mapping[node_value.id] + roots_possible = self._dot_methods.get(node_attr) + if roots_possible and class_name in roots_possible: + self.found_any_target_function = True + self.found_qualified_name = self._class_method_to_target[(class_name, node_attr)] + return + + # Check for dynamic import match + if self.has_dynamic_imports and node_attr in self.function_names_to_find: self.found_any_target_function = True - self.found_qualified_name = node.attr + self.found_qualified_name = node_attr return self.generic_visit(node)