diff --git a/codeflash/code_utils/static_analysis.py b/codeflash/code_utils/static_analysis.py index 0151e29e7..b8b87cdfb 100644 --- a/codeflash/code_utils/static_analysis.py +++ b/codeflash/code_utils/static_analysis.py @@ -116,12 +116,24 @@ def analyze_imported_modules( def get_first_top_level_object_def_ast( object_name: str, object_type: type[ObjectDefT], node: ast.AST ) -> ObjectDefT | None: - for child in ast.iter_child_nodes(node): - if isinstance(child, object_type) and child.name == object_name: - return child - if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + # Use a local variable for allowed skip types to avoid repeating tuple allocation + skip_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef) + + # Use a list and manual iteration for better cache locality and reduced Python call overhead + children = list(ast.iter_child_nodes(node)) + for child in children: + # Shortcut: direct identity + string comparison at top level + if isinstance(child, object_type): + # hasattr check not needed, guaranteed by ast node type + if child.name == object_name: + return child + # Don't descend into this object's children continue - if descendant := get_first_top_level_object_def_ast(object_name, object_type, child): + # Only descend into child nodes that aren't functions, classes + if isinstance(child, skip_types): + continue + descendant = get_first_top_level_object_def_ast(object_name, object_type, child) + if descendant is not None: return descendant return None @@ -130,17 +142,19 @@ def get_first_top_level_function_or_method_ast( function_name: str, parents: list[FunctionParent], node: ast.AST ) -> ast.FunctionDef | ast.AsyncFunctionDef | None: if not parents: + # Try FunctionDef first, then AsyncFunctionDef only if needed. This prevents unnecessary tree walks. result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node) if result is not None: return result return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, node) - if parents[0].type == "ClassDef" and ( - class_node := get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node) - ): - result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node) - if result is not None: - return result - return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, class_node) + # Only check ClassDef if required + if parents[0].type == "ClassDef": + class_node = get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node) + if class_node is not None: + result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node) + if result is not None: + return result + return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, class_node) return None