Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 40 additions & 23 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ def __init__(self, function_names_to_find: set[str]) -> None:
self.has_dynamic_imports: bool = False
self.wildcard_modules: set[str] = set()

# Precompute function_names for prefix search
# For prefix match, store mapping from prefix-root to candidates for O(1) matching
self._exact_names = function_names_to_find
self._prefix_roots = {}
for name in function_names_to_find:
if "." in name:
root = name.split(".", 1)[0]
self._prefix_roots.setdefault(root, []).append(name)

def visit_Import(self, node: ast.Import) -> None:
"""Handle 'import module' statements."""
if self.found_any_target_function:
Expand Down Expand Up @@ -181,40 +190,48 @@ def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if self.found_any_target_function:
return

if not node.module:
mod = node.module
if not mod:
return

fnames = self._exact_names
proots = self._prefix_roots

for alias in node.names:
if alias.name == "*":
self.wildcard_modules.add(node.module)
else:
imported_name = alias.asname if alias.asname else alias.name
self.imported_modules.add(imported_name)
aname = alias.name
if aname == "*":
self.wildcard_modules.add(mod)
continue

# Check for dynamic import functions
if node.module == "importlib" and alias.name == "import_module":
self.has_dynamic_imports = True
imported_name = alias.asname if alias.asname else aname
self.imported_modules.add(imported_name)

# Fast check for dynamic import
if mod == "importlib" and aname == "import_module":
self.has_dynamic_imports = True

qualified_name = f"{node.module}.{alias.name}"
potential_matches = {alias.name, qualified_name}
qname = f"{mod}.{aname}"

if any(name in self.function_names_to_find for name in potential_matches):
# Fast exact match check
if aname in fnames:
self.found_any_target_function = True
self.found_qualified_name = next(
name for name in potential_matches if name in self.function_names_to_find
)
self.found_qualified_name = aname
return

qualified_prefix = qualified_name + "."
if any(target_func.startswith(qualified_prefix) for target_func in self.function_names_to_find):
if qname in fnames:
self.found_any_target_function = True
self.found_qualified_name = next(
target_func
for target_func in self.function_names_to_find
if target_func.startswith(qualified_prefix)
)
self.found_qualified_name = qname
return

# Fast prefix match: only for relevant roots
prefix = qname + "."
# Only bother if one of the targets startswith the prefix-root
candidates = proots.get(qname, ())
for target_func in candidates:
if target_func.startswith(prefix):
self.found_any_target_function = True
self.found_qualified_name = target_func
return

def visit_Attribute(self, node: ast.Attribute) -> None:
"""Handle attribute access like module.function_name."""
if self.found_any_target_function:
Expand Down
Loading