diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index ad12b44c4..0e5578c57 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -151,6 +151,15 @@ def __init__(self, function_names_to_find: set[str]) -> None: self.has_dynamic_imports: bool = False self.wildcard_modules: set[str] = set() + # Precompute function_names for prefix search + # For prefix match, store mapping from prefix-root to candidates for O(1) matching + self._exact_names = function_names_to_find + self._prefix_roots = {} + for name in function_names_to_find: + if "." in name: + root = name.split(".", 1)[0] + self._prefix_roots.setdefault(root, []).append(name) + def visit_Import(self, node: ast.Import) -> None: """Handle 'import module' statements.""" if self.found_any_target_function: @@ -181,40 +190,48 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None: if self.found_any_target_function: return - if not node.module: + mod = node.module + if not mod: return + fnames = self._exact_names + proots = self._prefix_roots + for alias in node.names: - if alias.name == "*": - self.wildcard_modules.add(node.module) - else: - imported_name = alias.asname if alias.asname else alias.name - self.imported_modules.add(imported_name) + aname = alias.name + if aname == "*": + self.wildcard_modules.add(mod) + continue - # Check for dynamic import functions - if node.module == "importlib" and alias.name == "import_module": - self.has_dynamic_imports = True + imported_name = alias.asname if alias.asname else aname + self.imported_modules.add(imported_name) + + # Fast check for dynamic import + if mod == "importlib" and aname == "import_module": + self.has_dynamic_imports = True - qualified_name = f"{node.module}.{alias.name}" - potential_matches = {alias.name, qualified_name} + qname = f"{mod}.{aname}" - if any(name in self.function_names_to_find for name in potential_matches): + # Fast exact match check + if aname in fnames: self.found_any_target_function = True - self.found_qualified_name = next( - name for name in potential_matches if name in self.function_names_to_find - ) + self.found_qualified_name = aname return - - qualified_prefix = qualified_name + "." - if any(target_func.startswith(qualified_prefix) for target_func in self.function_names_to_find): + if qname in fnames: self.found_any_target_function = True - self.found_qualified_name = next( - target_func - for target_func in self.function_names_to_find - if target_func.startswith(qualified_prefix) - ) + self.found_qualified_name = qname return + # Fast prefix match: only for relevant roots + prefix = qname + "." + # Only bother if one of the targets startswith the prefix-root + candidates = proots.get(qname, ()) + for target_func in candidates: + if target_func.startswith(prefix): + self.found_any_target_function = True + self.found_qualified_name = target_func + return + def visit_Attribute(self, node: ast.Attribute) -> None: """Handle attribute access like module.function_name.""" if self.found_any_target_function: