Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 88 additions & 88 deletions codeflash/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,53 +551,53 @@ def _analyze_imports_in_optimized_code(
"""
imported_names_map = defaultdict(set)

# Precompute a two-level dict: module_name -> func_name -> [helpers]
helpers_by_file_and_func = defaultdict(dict)
helpers_by_file = defaultdict(list) # preserved for "import module"
helpers_append = helpers_by_file_and_func.setdefault
# Prepare one-pass lookup: module_name -> func_name -> [helpers], and module_name -> [helpers]
helpers_by_file_and_func = {}
helpers_by_file = {}
for helper in code_context.helper_functions:
jedi_type = helper.jedi_definition.type
if jedi_type != "class":
func_name = helper.only_function_name
module_name = helper.file_path.stem
# Cache function lookup for this (module, func)
file_entry = helpers_by_file_and_func[module_name]
if func_name in file_entry:
file_entry[func_name].append(helper)
else:
file_entry[func_name] = [helper]
helpers_by_file[module_name].append(helper)

# Optimize attribute lookups and method binding outside the loop
if jedi_type == "class":
continue
func_name = helper.only_function_name
module_name = helper.file_path.stem
file_entry = helpers_by_file_and_func.setdefault(module_name, {})
file_entry.setdefault(func_name, []).append(helper)
helpers_by_file.setdefault(module_name, []).append(helper)

# Optimize lookups: create shortcut functions
helpers_by_file_and_func_get = helpers_by_file_and_func.get
helpers_by_file_get = helpers_by_file.get

for node in ast.walk(optimized_ast):
# Collect Import/ImportFrom statements in one pass over the module's direct children (built as a list; imports nested inside functions or classes are not visited)
nodes = [n for n in ast.iter_child_nodes(optimized_ast) if isinstance(n, (ast.Import, ast.ImportFrom))]
for node in nodes:
if isinstance(node, ast.ImportFrom):
# Handle "from module import function" statements
module_name = node.module
if module_name:
file_entry = helpers_by_file_and_func_get(module_name, None)
file_entry = helpers_by_file_and_func_get(module_name)
if file_entry:
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
imported_name = alias.asname or alias.name
original_name = alias.name
helpers = file_entry.get(original_name, None)
helpers = file_entry.get(original_name)
if helpers:
# Only add each possible helper name once
imported_set = imported_names_map[imported_name]
for helper in helpers:
imported_names_map[imported_name].add(helper.qualified_name)
imported_names_map[imported_name].add(helper.fully_qualified_name)

imported_set.add(helper.qualified_name)
imported_set.add(helper.fully_qualified_name)
elif isinstance(node, ast.Import):
# Handle "import module" statements
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
imported_name = alias.asname or alias.name
module_name = alias.name
for helper in helpers_by_file_get(module_name, []):
# For "import module" statements, functions would be called as module.function
full_call = f"{imported_name}.{helper.only_function_name}"
imported_names_map[full_call].add(helper.qualified_name)
imported_names_map[full_call].add(helper.fully_qualified_name)
helpers_list = helpers_by_file_get(module_name)
if helpers_list:
for helper in helpers_list:
# "import module": functions called as module.function
full_call = f"{imported_name}.{helper.only_function_name}"
callset = imported_names_map[full_call]
callset.add(helper.qualified_name)
callset.add(helper.fully_qualified_name)

return dict(imported_names_map)

Expand All @@ -616,88 +616,88 @@ def detect_unused_helper_functions(

"""
try:
# Parse the optimized code to analyze function calls and imports
optimized_ast = ast.parse(optimized_code)

# Find the optimized entrypoint function
entrypoint_function_ast = None
for node in ast.walk(optimized_ast):
if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name:
entrypoint_function_ast = node
break
# Find the optimized entrypoint function early (using generator for early break)
entrypoint_function_name = function_to_optimize.function_name
entrypoint_function_ast = next(
(
node
for node in ast.walk(optimized_ast)
if isinstance(node, ast.FunctionDef) and node.name == entrypoint_function_name
),
None,
)

if not entrypoint_function_ast:
logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")
logger.debug(f"Could not find entrypoint function {entrypoint_function_name} in optimized code")
return []

# First, analyze imports to build a mapping of imported names to their original qualified names
imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context)

# Extract all function calls in the entrypoint function
# Extract all called function names in entrypoint AST, collecting variants in one pass
called_function_names = set()
parents = getattr(function_to_optimize, "parents", None)
class_name = parents[0].name if parents else None

for node in ast.walk(entrypoint_function_ast):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
# Regular function call: function_name()
called_name = node.func.id
func = node.func
if isinstance(func, ast.Name):
called_name = func.id
called_function_names.add(called_name)
# Also add the qualified name if this is an imported function
if called_name in imported_names_map:
called_function_names.update(imported_names_map[called_name])
elif isinstance(node.func, ast.Attribute):
# Method call: obj.method() or self.method() or module.function()
if isinstance(node.func.value, ast.Name):
if node.func.value.id == "self":
# self.method_name() -> add both method_name and ClassName.method_name
called_function_names.add(node.func.attr)
# For class methods, also add the qualified name
if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
class_name = function_to_optimize.parents[0].name
called_function_names.add(f"{class_name}.{node.func.attr}")
elif isinstance(func, ast.Attribute):
val = func.value
attr_name = func.attr
# Method call: self.method() or module.function() or obj.method()
if isinstance(val, ast.Name):
val_id = val.id
if val_id == "self":
called_function_names.add(attr_name)
if class_name:
called_function_names.add(f"{class_name}.{attr_name}")
else:
# obj.method() or module.function()
attr_name = node.func.attr
called_function_names.add(attr_name)
called_function_names.add(f"{node.func.value.id}.{attr_name}")
# Check if this is a module.function call that maps to a helper
full_call = f"{node.func.value.id}.{attr_name}"
full_call = f"{val_id}.{attr_name}"
called_function_names.add(full_call)
if full_call in imported_names_map:
called_function_names.update(imported_names_map[full_call])
# Handle nested attribute access like obj.attr.method()
else:
called_function_names.add(node.func.attr)
# obj.attr.method()
called_function_names.add(attr_name)

logger.debug(f"Functions called in optimized entrypoint: {called_function_names}")
logger.debug(f"Imported names mapping: {imported_names_map}")

# Find helper functions that are no longer called
# Precompute entrypoint's file_path for fast comparison
entrypoint_file_path = function_to_optimize.file_path

# Prefetch attributes to reduce lookup cost inside loop
unused_helpers = []
for helper_function in code_context.helper_functions:
if helper_function.jedi_definition.type != "class":
# Check if the helper function is called using multiple name variants
helper_qualified_name = helper_function.qualified_name
helper_simple_name = helper_function.only_function_name
helper_fully_qualified_name = helper_function.fully_qualified_name

# Create a set of all possible names this helper might be called by
possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name}

# For cross-file helpers, also consider module-based calls
if helper_function.file_path != function_to_optimize.file_path:
# Add potential module.function combinations
module_name = helper_function.file_path.stem
possible_call_names.add(f"{module_name}.{helper_simple_name}")

# Check if any of the possible names are in the called functions
is_called = bool(possible_call_names.intersection(called_function_names))

if not is_called:
unused_helpers.append(helper_function)
logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
logger.debug(f" Checked names: {possible_call_names}")
else:
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}")
jedi_type = helper_function.jedi_definition.type
if jedi_type == "class":
continue

helper_qualified_name = helper_function.qualified_name
helper_simple_name = helper_function.only_function_name
helper_fully_qualified_name = helper_function.fully_qualified_name
possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name}

# For cross-file helpers, add module.function variant
if helper_function.file_path != entrypoint_file_path:
module_name = helper_function.file_path.stem
possible_call_names.add(f"{module_name}.{helper_simple_name}")

if not possible_call_names & called_function_names:
unused_helpers.append(helper_function)
logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
logger.debug(f" Checked names: {possible_call_names}")
else:
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
logger.debug(f" Called via: {possible_call_names & called_function_names}")

return unused_helpers

Expand Down
Loading