Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions codeflash/code_utils/code_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,47 +935,51 @@ def _is_target_function_call(self, node: ast.Call) -> bool:
if not call_name:
return False

# Check if it matches directly
# Fast path: direct match
if call_name == self.target_function_name:
return True

# Check if it's just the base name matching
# Fast path: base name match, possibly imported, possibly local
if call_name == self.target_base_name:
# Could be imported with a different name, check imports
if call_name in self.imports:
imported_path = self.imports[call_name]
if imported_path == self.target_function_name or imported_path.endswith(
f".{self.target_function_name}"
):
imp = self.imports.get(call_name)
if imp is not None:
if imp == self.target_function_name or imp.endswith(f".{self.target_function_name}"):
return True
# Could also be a direct call if we're in the same file
return True

# Check for qualified calls with imports
# Fast path: check for qualified call using imports
call_parts = call_name.split(".")
if call_parts[0] in self.imports:
# Resolve the full path using imports
base_import = self.imports[call_parts[0]]
full_path = f"{base_import}.{'.'.join(call_parts[1:])}" if len(call_parts) > 1 else base_import

base_part = call_parts[0]
imp = self.imports.get(base_part)
if imp is not None:
# Compose the full import path once
if len(call_parts) > 1:
full_path = f"{imp}.{'.'.join(call_parts[1:])}"
else:
full_path = imp
if full_path == self.target_function_name or full_path.endswith(f".{self.target_function_name}"):
return True

return False

def _get_call_name(self, func_node) -> Optional[str]: # noqa : ANN001
def _get_call_name(self, func_node) -> Optional[str]:
"""Extract the name being called from a function node."""
# Fast path short-circuit for ast.Name nodes
if isinstance(func_node, ast.Name):
return func_node.id
# Build attribute chain with a single walk and no string op in loop
if isinstance(func_node, ast.Attribute):
# Use list-prepend and join for reversed chain (faster than += string)
parts = []
current = func_node
while isinstance(current, ast.Attribute):
parts.append(current.attr)
current = current.value
if isinstance(current, ast.Name):
parts.append(current.id)
return ".".join(reversed(parts))
# Avoid extra function call by reversing while joining
return ".".join(parts[::-1]) # slightly faster than reversed()
return None

def _extract_source_code(self, node: ast.FunctionDef) -> str:
Expand Down
Loading