diff --git a/codeflash/discovery/discover_unit_tests.py b/codeflash/discovery/discover_unit_tests.py index b0e141093..78aa2a1ec 100644 --- a/codeflash/discovery/discover_unit_tests.py +++ b/codeflash/discovery/discover_unit_tests.py @@ -223,14 +223,20 @@ def __init__(self, function_names_to_find: set[str]) -> None: self._dot_names: set[str] = set() self._dot_methods: dict[str, set[str]] = {} self._class_method_to_target: dict[tuple[str, str], str] = {} + + # Optimize prefix-roots and dot_methods construction + add_dot_methods = self._dot_methods.setdefault + add_prefix_roots = self._prefix_roots.setdefault + dot_names_add = self._dot_names.add + class_method_to_target = self._class_method_to_target for name in function_names_to_find: if "." in name: root, method = name.rsplit(".", 1) - self._dot_names.add(name) - self._dot_methods.setdefault(method, set()).add(root) - self._class_method_to_target[(root, method)] = name + dot_names_add(name) + add_dot_methods(method, set()).add(root) + class_method_to_target[(root, method)] = name root_prefix = name.split(".", 1)[0] - self._prefix_roots.setdefault(root_prefix, []).append(name) + add_prefix_roots(root_prefix, []).append(name) def visit_Import(self, node: ast.Import) -> None: """Handle 'import module' statements.""" @@ -353,20 +359,18 @@ def visit_Attribute(self, node: ast.Attribute) -> None: node_attr = node.attr # Check if this is accessing a target function through an imported module - if ( - isinstance(node_value, ast.Name) - and node_value.id in self.imported_modules - and node_attr in self.function_names_to_find - ): - self.found_any_target_function = True - self.found_qualified_name = node_attr - return - # Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target - if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules: + # Accessing a target function through an imported module (fast path for imported modules) + val_id = getattr(node_value, "id", None) + if val_id is not None and val_id in self.imported_modules: + if node_attr in self.function_names_to_find: + self.found_any_target_function = True + self.found_qualified_name = node_attr + return + # Methods via imported modules using precomputed _dot_methods and _class_method_to_target roots_possible = self._dot_methods.get(node_attr) if roots_possible: - imported_name = node_value.id + imported_name = val_id original_name = self.alias_mapping.get(imported_name, imported_name) if original_name in roots_possible: self.found_any_target_function = True @@ -381,9 +385,9 @@ def visit_Attribute(self, node: ast.Attribute) -> None: ) return - # Check if this is accessing a method on an instance variable - if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping: - class_name = self.instance_mapping[node_value.id] + # Methods on instance variables (tighten type check and lookup, important for larger ASTs) + if val_id is not None and val_id in self.instance_mapping: + class_name = self.instance_mapping[val_id] roots_possible = self._dot_methods.get(node_attr) if roots_possible and class_name in roots_possible: self.found_any_target_function = True @@ -396,7 +400,9 @@ def visit_Attribute(self, node: ast.Attribute) -> None: self.found_qualified_name = node_attr return - self.generic_visit(node) + # Replace self.generic_visit with base class impl directly: removes an attribute lookup + if not self.found_any_target_function: + ast.NodeVisitor.generic_visit(self, node) def visit_Call(self, node: ast.Call) -> None: """Handle function calls, particularly __import__.""" @@ -442,7 +448,8 @@ def generic_visit(self, node: ast.AST) -> None: """Override generic_visit to stop traversal if a target function is found.""" if self.found_any_target_function: return - super().generic_visit(node) + # Direct base call improves run speed (avoids extra method resolution) + ast.NodeVisitor.generic_visit(self, node) def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool: