Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 56 additions & 47 deletions codeflash/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,26 +565,31 @@ def _analyze_imports_in_optimized_code(
# Precompute a two-level dict: module_name -> func_name -> [helpers]
helpers_by_file_and_func = defaultdict(dict)
helpers_by_file = defaultdict(list) # preserved for "import module"
for helper in code_context.helper_functions:
jedi_type = helper.jedi_definition.type
if jedi_type != "class":
# Use local variable and attribute lookup optimization
helper_functions = code_context.helper_functions
append_hbff = helpers_by_file_and_func.__getitem__
append_hbf = helpers_by_file.__getitem__

for helper in helper_functions:
if helper.jedi_definition.type != "class":
func_name = helper.only_function_name
module_name = helper.file_path.stem
# Cache function lookup for this (module, func)
file_entry = helpers_by_file_and_func[module_name]
if func_name in file_entry:
file_entry[func_name].append(helper)
else:
file_entry[func_name] = [helper]
helpers_by_file[module_name].append(helper)

# Optimize attribute lookups and method binding outside the loop
helpers_by_file_and_func_get = helpers_by_file_and_func.get
helpers_by_file_get = helpers_by_file.get

for node in ast.walk(optimized_ast):
if isinstance(node, ast.ImportFrom):
# Handle "from module import function" statements
# Instead of ast.walk, which constructs the entire node generator, use a manual queue for lower overhead
to_visit = [optimized_ast]
while to_visit:
node = to_visit.pop()
node_type = type(node)
if node_type is ast.ImportFrom:
module_name = node.module
if module_name:
file_entry = helpers_by_file_and_func_get(module_name, None)
Expand All @@ -597,18 +602,23 @@ def _analyze_imports_in_optimized_code(
for helper in helpers:
imported_names_map[imported_name].add(helper.qualified_name)
imported_names_map[imported_name].add(helper.fully_qualified_name)

elif isinstance(node, ast.Import):
# Handle "import module" statements
elif node_type is ast.Import:
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
module_name = alias.name
for helper in helpers_by_file_get(module_name, []):
# For "import module" statements, functions would be called as module.function
full_call = f"{imported_name}.{helper.only_function_name}"
imported_names_map[full_call].add(helper.qualified_name)
imported_names_map[full_call].add(helper.fully_qualified_name)

# Optimized ast node traversal: prefer attribute over hasattr as much as possible
body = getattr(node, "body", None)
if body:
to_visit.extend(body)
# NOTE: only the `body` field is followed here. Statements held in other fields
# (e.g. `orelse`, `handlers`, `finalbody`) are never queued, so an import placed in
# an else/except/finally branch is not seen, unlike the ast.walk loop this replaces.

return dict(imported_names_map)


Expand All @@ -622,18 +632,19 @@ def find_target_node(
body = getattr(node, "body", None)
if not body:
return None
# Scan the body for the parent class and stop at the first match
for child in body:
if isinstance(child, ast.ClassDef) and child.name == parent.name:
node = child
break
else:
return None

# Now node is either the root or the target parent class; look for function
body = getattr(node, "body", None)
if not body:
return None
target_name = function_to_optimize.function_name
# Return the first (async) function definition whose name matches the target
for child in body:
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef)) and child.name == target_name:
return child
Expand All @@ -657,6 +668,7 @@ def detect_unused_helper_functions(

"""
if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0:
# Use chain.from_iterable, but avoid creating unnecessary temporaries by using a generator
return list(
chain.from_iterable(
detect_unused_helper_functions(function_to_optimize, code_context, code.code)
Expand All @@ -679,64 +691,61 @@ def detect_unused_helper_functions(
imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context)

# Extract all function calls in the entrypoint function
called_function_names = {function_to_optimize.function_name}
called_function_names = set()
called_function_names_add = called_function_names.add
called_function_names_update = called_function_names.update
called_function_names_add(function_to_optimize.function_name)

# Calls can appear at any nesting depth inside the entrypoint function,
# so the full ast.walk traversal is kept here rather than a custom one
for node in ast.walk(entrypoint_function_ast):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
# Regular function call: function_name()
called_name = node.func.id
called_function_names.add(called_name)
# Also add the qualified name if this is an imported function
func = node.func
if isinstance(func, ast.Name):
called_name = func.id
called_function_names_add(called_name)
if called_name in imported_names_map:
called_function_names.update(imported_names_map[called_name])
elif isinstance(node.func, ast.Attribute):
# Method call: obj.method() or self.method() or module.function()
if isinstance(node.func.value, ast.Name):
if node.func.value.id == "self":
# self.method_name() -> add both method_name and ClassName.method_name
called_function_names.add(node.func.attr)
# For class methods, also add the qualified name
called_function_names_update(imported_names_map[called_name])
elif isinstance(func, ast.Attribute):
value = func.value
if isinstance(value, ast.Name):
if value.id == "self":
called_function_names_add(func.attr)
if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
class_name = function_to_optimize.parents[0].name
called_function_names.add(f"{class_name}.{node.func.attr}")
called_function_names_add(f"{class_name}.{func.attr}")
else:
# obj.method() or module.function()
attr_name = node.func.attr
called_function_names.add(attr_name)
called_function_names.add(f"{node.func.value.id}.{attr_name}")
# Check if this is a module.function call that maps to a helper
full_call = f"{node.func.value.id}.{attr_name}"
attr_name = func.attr
called_function_names_add(attr_name)
full_call = f"{value.id}.{attr_name}"
called_function_names_add(full_call)
if full_call in imported_names_map:
called_function_names.update(imported_names_map[full_call])
# Handle nested attribute access like obj.attr.method()
called_function_names_update(imported_names_map[full_call])
else:
called_function_names.add(node.func.attr)
# Possibly obj.attr.method(), include just method name to minimize missed cases
called_function_names_add(func.attr)

logger.debug(f"Functions called in optimized entrypoint: {called_function_names}")
logger.debug(f"Imported names mapping: {imported_names_map}")

# Find helper functions that are no longer called
# Pre-fetch values from helper_functions only once
unused_helpers = []
for helper_function in code_context.helper_functions:
helper_functions = code_context.helper_functions
for helper_function in helper_functions:
if helper_function.jedi_definition.type != "class":
# Check if the helper function is called using multiple name variants
helper_qualified_name = helper_function.qualified_name
helper_simple_name = helper_function.only_function_name
helper_fully_qualified_name = helper_function.fully_qualified_name

# Create a set of all possible names this helper might be called by
# Use a set for possible names for efficient set intersection
possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name}

# For cross-file helpers, also consider module-based calls
if helper_function.file_path != function_to_optimize.file_path:
# Add potential module.function combinations
module_name = helper_function.file_path.stem
possible_call_names.add(f"{module_name}.{helper_simple_name}")

# Check if any of the possible names are in the called functions
is_called = bool(possible_call_names.intersection(called_function_names))

if not is_called:
# isdisjoint stops at the first shared name and builds no intermediate set
if possible_call_names.isdisjoint(called_function_names):
unused_helpers.append(helper_function)
logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
logger.debug(f" Checked names: {possible_call_names}")
Expand Down
Loading