Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions codeflash/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from __future__ import annotations

import ast
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import libcst as cst

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
from codeflash.models.models import CodeOptimizationContext, FunctionSource


@dataclass
class UsageInfo:
Expand Down Expand Up @@ -480,3 +487,210 @@ def print_definitions(definitions: dict[str, UsageInfo]) -> None:
print(f" Used by qualified function: {info.used_by_qualified_function}")
print(f" Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}")
print()


def revert_unused_helper_functions(
project_root, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
) -> None:
"""Revert unused helper functions back to their original definitions.

Args:
unused_helpers: List of unused helper functions to revert
original_helper_code: Dictionary mapping file paths to their original code

"""
if not unused_helpers:
return

logger.info(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions")

Comment on lines +508 to +509
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is the info level on purpose?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think its good for now, i want to know where i am reverting so that i can clean out any bugs. I plan to remove it later

# Group unused helpers by file path
unused_helpers_by_file = defaultdict(list)
for helper in unused_helpers:
unused_helpers_by_file[helper.file_path].append(helper)

# For each file, revert the unused helper functions to their original definitions
for file_path, helpers_in_file in unused_helpers_by_file.items():
if file_path in original_helper_code:
try:
# Read current file content
current_code = file_path.read_text(encoding="utf8")

# Get original code for this file
original_code = original_helper_code[file_path]

# Use the code replacer to selectively revert only the unused helper functions
helper_names = [helper.qualified_name for helper in helpers_in_file]
reverted_code = replace_function_definitions_in_module(
function_names=helper_names,
optimized_code=original_code, # Use original code as the "optimized" code to revert
module_abspath=file_path,
preexisting_objects=set(), # Empty set since we're reverting
project_root_path=project_root,
)

if reverted_code:
logger.debug(f"Reverted unused helpers in {file_path}: {', '.join(helper_names)}")

except Exception as e:
logger.error(f"Error reverting unused helpers in {file_path}: {e}")


def _analyze_imports_in_optimized_code(
optimized_ast: ast.AST, code_context: CodeOptimizationContext
) -> dict[str, set[str]]:
"""Analyze import statements in optimized code to map imported names to qualified helper names.

Args:
optimized_ast: The AST of the optimized code
code_context: The code optimization context containing helper functions

Returns:
Dictionary mapping imported names to sets of possible qualified helper names

"""
imported_names_map = defaultdict(set)

# Create a lookup of helper functions by their simple names and file paths
helpers_by_name = defaultdict(list)
helpers_by_file = defaultdict(list)

for helper in code_context.helper_functions:
if helper.jedi_definition.type != "class":
helpers_by_name[helper.only_function_name].append(helper)
module_name = helper.file_path.stem
helpers_by_file[module_name].append(helper)

# Analyze import statements in the optimized code
for node in ast.walk(optimized_ast):
if isinstance(node, ast.ImportFrom):
# Handle "from module import function" statements
if node.module:
module_name = node.module
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
original_name = alias.name

# Find helpers that match this import
for helper in helpers_by_file.get(module_name, []):
if helper.only_function_name == original_name:
imported_names_map[imported_name].add(helper.qualified_name)
imported_names_map[imported_name].add(helper.fully_qualified_name)

elif isinstance(node, ast.Import):
# Handle "import module" statements
for alias in node.names:
imported_name = alias.asname if alias.asname else alias.name
module_name = alias.name

# For "import module" statements, functions would be called as module.function
for helper in helpers_by_file.get(module_name, []):
full_call = f"{imported_name}.{helper.only_function_name}"
imported_names_map[full_call].add(helper.qualified_name)
imported_names_map[full_call].add(helper.fully_qualified_name)

return dict(imported_names_map)


def detect_unused_helper_functions(
function_to_optimize, code_context: CodeOptimizationContext, optimized_code: str
) -> list[FunctionSource]:
"""Detect helper functions that are no longer called by the optimized entrypoint function.

Args:
code_context: The code optimization context containing helper functions
optimized_code: The optimized code to analyze

Returns:
List of FunctionSource objects representing unused helper functions

"""
try:
# Parse the optimized code to analyze function calls and imports
optimized_ast = ast.parse(optimized_code)

# Find the optimized entrypoint function
entrypoint_function_ast = None
for node in ast.walk(optimized_ast):
if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name:
entrypoint_function_ast = node
break

if not entrypoint_function_ast:
logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")
return []

# First, analyze imports to build a mapping of imported names to their original qualified names
imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context)

# Extract all function calls in the entrypoint function
called_function_names = set()
for node in ast.walk(entrypoint_function_ast):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
# Regular function call: function_name()
called_name = node.func.id
called_function_names.add(called_name)
# Also add the qualified name if this is an imported function
if called_name in imported_names_map:
called_function_names.update(imported_names_map[called_name])
elif isinstance(node.func, ast.Attribute):
# Method call: obj.method() or self.method() or module.function()
if isinstance(node.func.value, ast.Name):
if node.func.value.id == "self":
# self.method_name() -> add both method_name and ClassName.method_name
called_function_names.add(node.func.attr)
# For class methods, also add the qualified name
if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
class_name = function_to_optimize.parents[0].name
called_function_names.add(f"{class_name}.{node.func.attr}")
else:
# obj.method() or module.function()
attr_name = node.func.attr
called_function_names.add(attr_name)
called_function_names.add(f"{node.func.value.id}.{attr_name}")
# Check if this is a module.function call that maps to a helper
full_call = f"{node.func.value.id}.{attr_name}"
if full_call in imported_names_map:
called_function_names.update(imported_names_map[full_call])
# Handle nested attribute access like obj.attr.method()
else:
called_function_names.add(node.func.attr)

logger.debug(f"Functions called in optimized entrypoint: {called_function_names}")
logger.debug(f"Imported names mapping: {imported_names_map}")

# Find helper functions that are no longer called
unused_helpers = []
for helper_function in code_context.helper_functions:
if helper_function.jedi_definition.type != "class":
# Check if the helper function is called using multiple name variants
helper_qualified_name = helper_function.qualified_name
helper_simple_name = helper_function.only_function_name
helper_fully_qualified_name = helper_function.fully_qualified_name

# Create a set of all possible names this helper might be called by
possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name}

# For cross-file helpers, also consider module-based calls
if helper_function.file_path != function_to_optimize.file_path:
# Add potential module.function combinations
module_name = helper_function.file_path.stem
possible_call_names.add(f"{module_name}.{helper_simple_name}")

# Check if any of the possible names are in the called functions
is_called = bool(possible_call_names.intersection(called_function_names))

if not is_called:
unused_helpers.append(helper_function)
logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
logger.debug(f" Checked names: {possible_call_names}")
else:
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}")

return unused_helpers

except Exception as e:
logger.debug(f"Error detecting unused helper functions: {e}")
return []
15 changes: 12 additions & 3 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
from codeflash.code_utils.time_utils import humanize_runtime
from codeflash.context import code_context_extractor
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
from codeflash.either import Failure, Success, is_successful
from codeflash.models.ExperimentMetadata import ExperimentMetadata
from codeflash.models.models import (
Expand Down Expand Up @@ -298,7 +299,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
self.log_successful_optimization(explanation, generated_tests, exp_type)

self.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=best_optimization.candidate.source_code
code_context=code_context,
optimized_code=best_optimization.candidate.source_code,
original_helper_code=original_helper_code,
)

new_code, new_helper_code = self.reformat_code_and_helpers(
Expand Down Expand Up @@ -411,7 +414,7 @@ def determine_best_candidate(
code_print(candidate.source_code)
try:
did_update = self.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=candidate.source_code
code_context=code_context, optimized_code=candidate.source_code, original_helper_code=original_helper_code,
)
if not did_update:
logger.warning(
Expand Down Expand Up @@ -612,7 +615,7 @@ def reformat_code_and_helpers(
return new_code, new_helper_code

def replace_function_and_helpers_with_optimized_code(
self, code_context: CodeOptimizationContext, optimized_code: str
self, code_context: CodeOptimizationContext, optimized_code: str, original_helper_code: str
) -> bool:
did_update = False
read_writable_functions_by_file_path = defaultdict(set)
Expand All @@ -630,6 +633,12 @@ def replace_function_and_helpers_with_optimized_code(
preexisting_objects=code_context.preexisting_objects,
project_root_path=self.project_root,
)
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)

# Revert unused helper functions to their original definitions
if unused_helpers:
revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code)

return did_update

def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:
Expand Down
20 changes: 11 additions & 9 deletions tests/test_code_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def sorter(arr):
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
final_output = code_path.read_text(encoding="utf-8")
assert "inconsequential_var = '123'" in final_output
Expand Down Expand Up @@ -804,7 +804,8 @@ def __init__(self, name):
self.name = name

def main_method(self):
return HelperClass(self.name).helper_method()"""
return HelperClass(self.name).helper_method()
"""
file_path = Path(__file__).resolve()
func_top_optimize = FunctionToOptimize(
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
Expand Down Expand Up @@ -1662,6 +1663,7 @@ def new_function2(value):
)
assert new_code == original_code


def test_global_reassignment() -> None:
original_code = """a=1
print("Hello world")
Expand Down Expand Up @@ -1733,7 +1735,7 @@ def new_function2(value):
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -1809,7 +1811,7 @@ def new_function2(value):
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -1886,7 +1888,7 @@ def new_function2(value):
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -1962,7 +1964,7 @@ def new_function2(value):
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -2039,7 +2041,7 @@ def new_function2(value):
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
Expand Down Expand Up @@ -2127,8 +2129,8 @@ def new_function2(value):
original_helper_code[helper_function_path] = helper_code
func_optimizer.args = Args()
func_optimizer.replace_function_and_helpers_with_optimized_code(
code_context=code_context, optimized_code=optimized_code
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
)
new_code = code_path.read_text(encoding="utf-8")
code_path.unlink(missing_ok=True)
assert new_code.rstrip() == expected_code.rstrip()
assert new_code.rstrip() == expected_code.rstrip()
Loading
Loading