Skip to content

Commit 0f2c747

Browse files
authored
Merge pull request #877 from codeflash-ai/codeflash/optimize-pr867-2025-11-05T08.18.39
⚡️ Speed up method `ImportAnalyzer.visit_Attribute` by 17% in PR #867 (`inspect-signature-issue`)
2 parents f305633 + b7225e7 commit 0f2c747

File tree

1 file changed

+27
-20
lines changed

1 file changed

+27
-20
lines changed

codeflash/discovery/discover_unit_tests.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -223,14 +223,20 @@ def __init__(self, function_names_to_find: set[str]) -> None:
223223
self._dot_names: set[str] = set()
224224
self._dot_methods: dict[str, set[str]] = {}
225225
self._class_method_to_target: dict[tuple[str, str], str] = {}
226+
227+
# Optimize prefix-roots and dot_methods construction
228+
add_dot_methods = self._dot_methods.setdefault
229+
add_prefix_roots = self._prefix_roots.setdefault
230+
dot_names_add = self._dot_names.add
231+
class_method_to_target = self._class_method_to_target
226232
for name in function_names_to_find:
227233
if "." in name:
228234
root, method = name.rsplit(".", 1)
229-
self._dot_names.add(name)
230-
self._dot_methods.setdefault(method, set()).add(root)
231-
self._class_method_to_target[(root, method)] = name
235+
dot_names_add(name)
236+
add_dot_methods(method, set()).add(root)
237+
class_method_to_target[(root, method)] = name
232238
root_prefix = name.split(".", 1)[0]
233-
self._prefix_roots.setdefault(root_prefix, []).append(name)
239+
add_prefix_roots(root_prefix, []).append(name)
234240

235241
def visit_Import(self, node: ast.Import) -> None:
236242
"""Handle 'import module' statements."""
@@ -353,20 +359,18 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
353359
node_attr = node.attr
354360

355361
# Check if this is accessing a target function through an imported module
356-
if (
357-
isinstance(node_value, ast.Name)
358-
and node_value.id in self.imported_modules
359-
and node_attr in self.function_names_to_find
360-
):
361-
self.found_any_target_function = True
362-
self.found_qualified_name = node_attr
363-
return
364362

365-
# Check for methods via imported modules using precomputed _dot_methods and _class_method_to_target
366-
if isinstance(node_value, ast.Name) and node_value.id in self.imported_modules:
363+
# Accessing a target function through an imported module (fast path for imported modules)
364+
val_id = getattr(node_value, "id", None)
365+
if val_id is not None and val_id in self.imported_modules:
366+
if node_attr in self.function_names_to_find:
367+
self.found_any_target_function = True
368+
self.found_qualified_name = node_attr
369+
return
370+
# Methods via imported modules using precomputed _dot_methods and _class_method_to_target
367371
roots_possible = self._dot_methods.get(node_attr)
368372
if roots_possible:
369-
imported_name = node_value.id
373+
imported_name = val_id
370374
original_name = self.alias_mapping.get(imported_name, imported_name)
371375
if original_name in roots_possible:
372376
self.found_any_target_function = True
@@ -381,9 +385,9 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
381385
)
382386
return
383387

384-
# Check if this is accessing a method on an instance variable
385-
if isinstance(node_value, ast.Name) and node_value.id in self.instance_mapping:
386-
class_name = self.instance_mapping[node_value.id]
388+
# Methods on instance variables (tighten type check and lookup, important for larger ASTs)
389+
if val_id is not None and val_id in self.instance_mapping:
390+
class_name = self.instance_mapping[val_id]
387391
roots_possible = self._dot_methods.get(node_attr)
388392
if roots_possible and class_name in roots_possible:
389393
self.found_any_target_function = True
@@ -396,7 +400,9 @@ def visit_Attribute(self, node: ast.Attribute) -> None:
396400
self.found_qualified_name = node_attr
397401
return
398402

399-
self.generic_visit(node)
403+
# Replace self.generic_visit with base class impl directly: removes an attribute lookup
404+
if not self.found_any_target_function:
405+
ast.NodeVisitor.generic_visit(self, node)
400406

401407
def visit_Call(self, node: ast.Call) -> None:
402408
"""Handle function calls, particularly __import__."""
@@ -442,7 +448,8 @@ def generic_visit(self, node: ast.AST) -> None:
442448
"""Override generic_visit to stop traversal if a target function is found."""
443449
if self.found_any_target_function:
444450
return
445-
super().generic_visit(node)
451+
# Direct base call improves run speed (avoids extra method resolution)
452+
ast.NodeVisitor.generic_visit(self, node)
446453

447454

448455
def analyze_imports_in_test_file(test_file_path: Path | str, target_functions: set[str]) -> bool:

0 commit comments

Comments
 (0)