Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
228 changes: 132 additions & 96 deletions codeflash/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.models.models import CodeString, CodeStringsMarkdown
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import CodeOptimizationContext, CodeString, CodeStringsMarkdown, FunctionSource

if TYPE_CHECKING:
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
Expand Down Expand Up @@ -561,52 +562,65 @@ def _analyze_imports_in_optimized_code(
"""
imported_names_map = defaultdict(set)

# Precompute a two-level dict: module_name -> func_name -> [helpers]
helpers_by_file_and_func = defaultdict(dict)
helpers_by_file = defaultdict(list) # preserved for "import module"
for helper in code_context.helper_functions:
jedi_type = helper.jedi_definition.type
if jedi_type != "class":
func_name = helper.only_function_name
module_name = helper.file_path.stem
# Cache function lookup for this (module, func)
file_entry = helpers_by_file_and_func[module_name]
if func_name in file_entry:
file_entry[func_name].append(helper)
else:
file_entry[func_name] = [helper]
helpers_by_file[module_name].append(helper)
helpers_by_file = defaultdict(list)
# Local variable bindings for inner loop speed-up
helper_functions = code_context.helper_functions
append_hbf = helpers_by_file.__getitem__
# Precompute helper info as lists to reduce attribute lookups
for helper in helper_functions:
jedi_def = helper.jedi_definition
if jedi_def.type == "class":
continue
func_name = helper.only_function_name
module_name = helper.file_path.stem
file_entry = helpers_by_file_and_func[module_name]
file_entry.setdefault(func_name, []).append(helper)
append_hbf(module_name).append(helper)

# Optimize attribute lookups and method binding outside the loop
helpers_by_file_and_func_get = helpers_by_file_and_func.get
helpers_by_file_get = helpers_by_file.get

# Cache node class checks
ImportFrom = ast.ImportFrom
Import = ast.Import

# AST walk is the main hot loop
for node in ast.walk(optimized_ast):
if isinstance(node, ast.ImportFrom):
# We avoid isinstance lookup for every node attribute; only check for relevant node types
node_type = type(node)
if node_type is ImportFrom:
# Handle "from module import function" statements
module_name = node.module
if module_name:
file_entry = helpers_by_file_and_func_get(module_name, None)
if file_entry:
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
original_name = alias.name
helpers = file_entry.get(original_name, None)
if helpers:
for helper in helpers:
imported_names_map[imported_name].add(helper.qualified_name)
imported_names_map[imported_name].add(helper.fully_qualified_name)

elif isinstance(node, ast.Import):
if not module_name:
continue
file_entry = helpers_by_file_and_func_get(module_name)
if not file_entry:
continue
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
original_name = alias.name
helpers = file_entry.get(original_name)
if helpers:
# Invariant: 1 or more helpers per name, no setdefault needed
s = imported_names_map[imported_name]
for helper in helpers:
s.add(helper.qualified_name)
s.add(helper.fully_qualified_name)
elif node_type is Import:
# Handle "import module" statements
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
module_name = alias.name
for helper in helpers_by_file_get(module_name, []):
helpers = helpers_by_file_get(module_name)
if not helpers:
continue
for helper in helpers:
# For "import module" statements, functions would be called as module.function
full_call = f"{imported_name}.{helper.only_function_name}"
imported_names_map[full_call].add(helper.qualified_name)
imported_names_map[full_call].add(helper.fully_qualified_name)
s = imported_names_map[full_call]
s.add(helper.qualified_name)
s.add(helper.fully_qualified_name)

return dict(imported_names_map)

Expand All @@ -627,97 +641,119 @@ def detect_unused_helper_functions(
List of FunctionSource objects representing unused helper functions

"""
# Fast return for markdown multi-code search (flatten result by chaining)
if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0:
return list(
chain.from_iterable(
detect_unused_helper_functions(function_to_optimize, code_context, code.code)
for code in optimized_code.code_strings
)
)

try:
# Parse the optimized code to analyze function calls and imports
optimized_ast = ast.parse(optimized_code)

# Find the optimized entrypoint function
# Find the optimized entrypoint function efficiently by scanning top-level nodes first
entrypoint_function_ast = None
for node in ast.walk(optimized_ast):
if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name:
fn_name = function_to_optimize.function_name
for node in ast.iter_child_nodes(optimized_ast):
if isinstance(node, ast.FunctionDef) and node.name == fn_name:
entrypoint_function_ast = node
break
# If not found at top-level, fallback to full AST walk (rare)
if not entrypoint_function_ast:
for node in ast.walk(optimized_ast):
if isinstance(node, ast.FunctionDef) and node.name == fn_name:
entrypoint_function_ast = node
break

if not entrypoint_function_ast:
logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")
logger.debug(f"Could not find entrypoint function {fn_name} in optimized code")
return []

# First, analyze imports to build a mapping of imported names to their original qualified names
# Pre-analyze and cache all needed values
imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context)

# Extract all function calls in the entrypoint function
called_function_names = set()
# AST walk for all calls inside entrypoint, batched local var reuse for speed
entry_parents = getattr(function_to_optimize, "parents", None)
# Hot attributes for helper handling
Name = ast.Name
Attribute = ast.Attribute
Call = ast.Call

for node in ast.walk(entrypoint_function_ast):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
# Regular function call: function_name()
called_name = node.func.id
called_function_names.add(called_name)
# Also add the qualified name if this is an imported function
if called_name in imported_names_map:
called_function_names.update(imported_names_map[called_name])
elif isinstance(node.func, ast.Attribute):
# Method call: obj.method() or self.method() or module.function()
if isinstance(node.func.value, ast.Name):
if node.func.value.id == "self":
# self.method_name() -> add both method_name and ClassName.method_name
called_function_names.add(node.func.attr)
# For class methods, also add the qualified name
if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
class_name = function_to_optimize.parents[0].name
called_function_names.add(f"{class_name}.{node.func.attr}")
else:
# obj.method() or module.function()
attr_name = node.func.attr
called_function_names.add(attr_name)
called_function_names.add(f"{node.func.value.id}.{attr_name}")
# Check if this is a module.function call that maps to a helper
full_call = f"{node.func.value.id}.{attr_name}"
if full_call in imported_names_map:
called_function_names.update(imported_names_map[full_call])
# Handle nested attribute access like obj.attr.method()
# Skip everything but calls
if type(node) is not Call:
continue
func = node.func
func_type = type(func)
if func_type is Name:
# function_name()
called_name = func.id
called_function_names.add(called_name)
imported = imported_names_map.get(called_name)
if imported:
called_function_names.update(imported)
elif func_type is Attribute:
value = func.value
if isinstance(value, Name):
val_id = value.id
attr = func.attr
if val_id == "self":
# self.method_name()
called_function_names.add(attr)
# For class methods, also add the qualified name
if entry_parents:
class_name = entry_parents[0].name
called_function_names.add(f"{class_name}.{attr}")
else:
called_function_names.add(node.func.attr)
# obj.method() or module.function()
called_function_names.add(attr)
full_call = f"{val_id}.{attr}"
called_function_names.add(full_call)
imported = imported_names_map.get(full_call)
if imported:
called_function_names.update(imported)
else:
# obj.attr.method() (nested); just add the attr name for best effort
called_function_names.add(func.attr)

logger.debug(f"Functions called in optimized entrypoint: {called_function_names}")
logger.debug(f"Imported names mapping: {imported_names_map}")

# Find helper functions that are no longer called
unused_helpers = []
for helper_function in code_context.helper_functions:
if helper_function.jedi_definition.type != "class":
# Check if the helper function is called using multiple name variants
helper_qualified_name = helper_function.qualified_name
helper_simple_name = helper_function.only_function_name
helper_fully_qualified_name = helper_function.fully_qualified_name

# Create a set of all possible names this helper might be called by
possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name}

# For cross-file helpers, also consider module-based calls
if helper_function.file_path != function_to_optimize.file_path:
# Add potential module.function combinations
module_name = helper_function.file_path.stem
possible_call_names.add(f"{module_name}.{helper_simple_name}")

# Check if any of the possible names are in the called functions
is_called = bool(possible_call_names.intersection(called_function_names))

if not is_called:
unused_helpers.append(helper_function)
logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
logger.debug(f" Checked names: {possible_call_names}")
else:
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}")
helper_functions = code_context.helper_functions
entrypoint_file = function_to_optimize.file_path
# Make called_function_names a set for fast lookup; .intersection() is slow for large sets
called_fn_names = called_function_names
# For every helper, check if any known name is present in the set
for helper_function in helper_functions:
if helper_function.jedi_definition.type == "class":
continue

helper_qualified_name = helper_function.qualified_name
helper_simple_name = helper_function.only_function_name
helper_fully_qualified_name = helper_function.fully_qualified_name

possible_call_names = [helper_qualified_name, helper_simple_name, helper_fully_qualified_name]
# For cross-file helpers, also consider module-based calls
if helper_function.file_path != entrypoint_file:
module_name = helper_function.file_path.stem
possible_call_names.append(f"{module_name}.{helper_simple_name}")
# Short-circuit as soon as any call name is found, avoid .intersection overhead
is_called = False
for n in possible_call_names:
if n in called_fn_names:
is_called = True
break
if not is_called:
unused_helpers.append(helper_function)
logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
logger.debug(f" Checked names: {set(possible_call_names)}")
else:
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
logger.debug(f" Called via: {set(possible_call_names) & called_fn_names}")

ret_val = unused_helpers

Expand Down
Loading