Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions codeflash/code_utils/static_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,24 @@ def analyze_imported_modules(
def get_first_top_level_object_def_ast(
object_name: str, object_type: type[ObjectDefT], node: ast.AST
) -> ObjectDefT | None:
for child in ast.iter_child_nodes(node):
if isinstance(child, object_type) and child.name == object_name:
return child
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
# Use a local variable for allowed skip types to avoid repeating tuple allocation
skip_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)

# Use a list and manual iteration for better cache locality and reduced Python call overhead
children = list(ast.iter_child_nodes(node))
for child in children:
# Shortcut: direct identity + string comparison at top level
if isinstance(child, object_type):
# hasattr check not needed, guaranteed by ast node type
if child.name == object_name:
return child
# Don't descend into this object's children
continue
if descendant := get_first_top_level_object_def_ast(object_name, object_type, child):
# Only descend into child nodes that aren't functions, classes
if isinstance(child, skip_types):
continue
descendant = get_first_top_level_object_def_ast(object_name, object_type, child)
if descendant is not None:
return descendant
return None

Expand All @@ -130,17 +142,19 @@ def get_first_top_level_function_or_method_ast(
function_name: str, parents: list[FunctionParent], node: ast.AST
) -> ast.FunctionDef | ast.AsyncFunctionDef | None:
if not parents:
# Try FunctionDef first, then AsyncFunctionDef only if needed. This prevents unnecessary tree walks.
result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, node)
if result is not None:
return result
return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, node)
if parents[0].type == "ClassDef" and (
class_node := get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node)
):
result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
if result is not None:
return result
return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, class_node)
# Only check ClassDef if required
if parents[0].type == "ClassDef":
class_node = get_first_top_level_object_def_ast(parents[0].name, ast.ClassDef, node)
if class_node is not None:
result = get_first_top_level_object_def_ast(function_name, ast.FunctionDef, class_node)
if result is not None:
return result
return get_first_top_level_object_def_ast(function_name, ast.AsyncFunctionDef, class_node)
return None


Expand Down
Loading