Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 27 additions & 20 deletions codeflash/discovery/discover_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,14 +223,20 @@ def __init__(self, function_names_to_find: set[str]) -> None:
self._dot_names: set[str] = set()
self._dot_methods: dict[str, set[str]] = {}
self._class_method_to_target: dict[tuple[str, str], str] = {}

# Optimize prefix-roots and dot_methods construction
add_dot_methods = self._dot_methods.setdefault
add_prefix_roots = self._prefix_roots.setdefault
dot_names_add = self._dot_names.add
class_method_to_target = self._class_method_to_target
for name in function_names_to_find:
if "." in name:
root, method = name.rsplit(".", 1)
self._dot_names.add(name)
self._dot_methods.setdefault(method, set()).add(root)
self._class_method_to_target[(root, method)] = name
dot_names_add(name)
add_dot_methods(method, set()).add(root)
class_method_to_target[(root, method)] = name
root_prefix = name.split(".", 1)[0]
self._prefix_roots.setdefault(root_prefix, []).append(name)
add_prefix_roots(root_prefix, []).append(name)

def visit_Import(self, node: ast.Import) -> None:
"""Handle 'import module' statements."""
Expand Down Expand Up @@ -353,20 +359,18 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
node_attr = node.attr

# Check if this is accessing a target function through an imported module
if (
isinstance(node_value, ast.Name)
and node_value.id in self.imported_modules
and node_attr in self.function_names_to_find
):
self.found_any_target_function = True
self.found_qualified_name = node_attr
return

# Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target
if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules:
# Accessing a target function through an imported module (fast path for imported modules)
val_id = getattr(node_value, "id", None)
if val_id is not None and val_id in self.imported_modules:
if node_attr in self.function_names_to_find:
self.found_any_target_function = True
self.found_qualified_name = node_attr
return
# Methods via imported modules using precomputed _dot_methods and _class_method_to_target
roots_possible = self._dot_methods.get(node_attr)
if roots_possible:
imported_name = node_value.id
imported_name = val_id
original_name = self.alias_mapping.get(imported_name, imported_name)
if original_name in roots_possible:
self.found_any_target_function = True
Expand All @@ -381,9 +385,9 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
)
return

# Check if this is accessing a method on an instance variable
if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping:
class_name = self.instance_mapping[node_value.id]
# Methods on instance variables (tighten type check and lookup, important for larger ASTs)
if val_id is not None and val_id in self.instance_mapping:
class_name = self.instance_mapping[val_id]
roots_possible = self._dot_methods.get(node_attr)
if roots_possible and class_name in roots_possible:
self.found_any_target_function = True
Expand All @@ -396,7 +400,9 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
self.found_qualified_name = node_attr
return

self.generic_visit(node)
# Replace self.generic_visit with base class impl directly: removes an attribute lookup
if not self.found_any_target_function:
ast.NodeVisitor.generic_visit(self, node)

def visit_Call(self, node: ast.Call) -> None:
"""Handle function calls, particularly __import__."""
Expand Down Expand Up @@ -442,7 +448,8 @@ def generic_visit(self, node: ast.AST) -> None:
"""Override generic_visit to stop traversal if a target function is found."""
if self.found_any_target_function:
return
super().generic_visit(node)
# Direct base call improves run speed (avoids extra method resolution)
ast.NodeVisitor.generic_visit(self, node)


def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool:
Expand Down
Loading