Skip to content

Commit 14409f7

Browse files
committed
Revert helper functions definitions when they are not used anymore in the optimized FTO
1 parent 19dcbfb commit 14409f7

File tree

4 files changed

+1657
-11
lines changed

4 files changed

+1657
-11
lines changed

codeflash/context/unused_definition_remover.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
from __future__ import annotations
22

3+
import ast
4+
from collections import defaultdict
35
from dataclasses import dataclass, field
6+
from pathlib import Path
47
from typing import Optional
58

69
import libcst as cst
710

11+
from codeflash.cli_cmds.console import logger
12+
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
13+
from codeflash.models.models import CodeOptimizationContext, FunctionSource
14+
815

916
@dataclass
1017
class UsageInfo:
@@ -480,3 +487,210 @@ def print_definitions(definitions: dict[str, UsageInfo]) -> None:
480487
print(f" Used by qualified function: {info.used_by_qualified_function}")
481488
print(f" Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}")
482489
print()
490+
491+
492+
def revert_unused_helper_functions(
493+
project_root, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
494+
) -> None:
495+
"""Revert unused helper functions back to their original definitions.
496+
497+
Args:
498+
unused_helpers: List of unused helper functions to revert
499+
original_helper_code: Dictionary mapping file paths to their original code
500+
501+
"""
502+
if not unused_helpers:
503+
return
504+
505+
logger.info(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions")
506+
507+
# Group unused helpers by file path
508+
unused_helpers_by_file = defaultdict(list)
509+
for helper in unused_helpers:
510+
unused_helpers_by_file[helper.file_path].append(helper)
511+
512+
# For each file, revert the unused helper functions to their original definitions
513+
for file_path, helpers_in_file in unused_helpers_by_file.items():
514+
if file_path in original_helper_code:
515+
try:
516+
# Read current file content
517+
current_code = file_path.read_text(encoding="utf8")
518+
519+
# Get original code for this file
520+
original_code = original_helper_code[file_path]
521+
522+
# Use the code replacer to selectively revert only the unused helper functions
523+
helper_names = [helper.qualified_name for helper in helpers_in_file]
524+
reverted_code = replace_function_definitions_in_module(
525+
function_names=helper_names,
526+
optimized_code=original_code, # Use original code as the "optimized" code to revert
527+
module_abspath=file_path,
528+
preexisting_objects=set(), # Empty set since we're reverting
529+
project_root_path=project_root,
530+
)
531+
532+
if reverted_code:
533+
logger.debug(f"Reverted unused helpers in {file_path}: {', '.join(helper_names)}")
534+
535+
except Exception as e:
536+
logger.error(f"Error reverting unused helpers in {file_path}: {e}")
537+
538+
539+
def _analyze_imports_in_optimized_code(
540+
optimized_ast: ast.AST, code_context: CodeOptimizationContext
541+
) -> dict[str, set[str]]:
542+
"""Analyze import statements in optimized code to map imported names to qualified helper names.
543+
544+
Args:
545+
optimized_ast: The AST of the optimized code
546+
code_context: The code optimization context containing helper functions
547+
548+
Returns:
549+
Dictionary mapping imported names to sets of possible qualified helper names
550+
551+
"""
552+
imported_names_map = defaultdict(set)
553+
554+
# Create a lookup of helper functions by their simple names and file paths
555+
helpers_by_name = defaultdict(list)
556+
helpers_by_file = defaultdict(list)
557+
558+
for helper in code_context.helper_functions:
559+
if helper.jedi_definition.type != "class":
560+
helpers_by_name[helper.only_function_name].append(helper)
561+
module_name = helper.file_path.stem
562+
helpers_by_file[module_name].append(helper)
563+
564+
# Analyze import statements in the optimized code
565+
for node in ast.walk(optimized_ast):
566+
if isinstance(node, ast.ImportFrom):
567+
# Handle "from module import function" statements
568+
if node.module:
569+
module_name = node.module
570+
for alias in node.names:
571+
imported_name = alias.asname if alias.asname else alias.name
572+
original_name = alias.name
573+
574+
# Find helpers that match this import
575+
for helper in helpers_by_file.get(module_name, []):
576+
if helper.only_function_name == original_name:
577+
imported_names_map[imported_name].add(helper.qualified_name)
578+
imported_names_map[imported_name].add(helper.fully_qualified_name)
579+
580+
elif isinstance(node, ast.Import):
581+
# Handle "import module" statements
582+
for alias in node.names:
583+
imported_name = alias.asname if alias.asname else alias.name
584+
module_name = alias.name
585+
586+
# For "import module" statements, functions would be called as module.function
587+
for helper in helpers_by_file.get(module_name, []):
588+
full_call = f"{imported_name}.{helper.only_function_name}"
589+
imported_names_map[full_call].add(helper.qualified_name)
590+
imported_names_map[full_call].add(helper.fully_qualified_name)
591+
592+
return dict(imported_names_map)
593+
594+
595+
def detect_unused_helper_functions(
596+
function_to_optimize, code_context: CodeOptimizationContext, optimized_code: str
597+
) -> list[FunctionSource]:
598+
"""Detect helper functions that are no longer called by the optimized entrypoint function.
599+
600+
Args:
601+
code_context: The code optimization context containing helper functions
602+
optimized_code: The optimized code to analyze
603+
604+
Returns:
605+
List of FunctionSource objects representing unused helper functions
606+
607+
"""
608+
try:
609+
# Parse the optimized code to analyze function calls and imports
610+
optimized_ast = ast.parse(optimized_code)
611+
612+
# Find the optimized entrypoint function
613+
entrypoint_function_ast = None
614+
for node in ast.walk(optimized_ast):
615+
if isinstance(node, ast.FunctionDef) and node.name == function_to_optimize.function_name:
616+
entrypoint_function_ast = node
617+
break
618+
619+
if not entrypoint_function_ast:
620+
logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")
621+
return []
622+
623+
# First, analyze imports to build a mapping of imported names to their original qualified names
624+
imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context)
625+
626+
# Extract all function calls in the entrypoint function
627+
called_function_names = set()
628+
for node in ast.walk(entrypoint_function_ast):
629+
if isinstance(node, ast.Call):
630+
if isinstance(node.func, ast.Name):
631+
# Regular function call: function_name()
632+
called_name = node.func.id
633+
called_function_names.add(called_name)
634+
# Also add the qualified name if this is an imported function
635+
if called_name in imported_names_map:
636+
called_function_names.update(imported_names_map[called_name])
637+
elif isinstance(node.func, ast.Attribute):
638+
# Method call: obj.method() or self.method() or module.function()
639+
if isinstance(node.func.value, ast.Name):
640+
if node.func.value.id == "self":
641+
# self.method_name() -> add both method_name and ClassName.method_name
642+
called_function_names.add(node.func.attr)
643+
# For class methods, also add the qualified name
644+
if hasattr(function_to_optimize, "parents") and function_to_optimize.parents:
645+
class_name = function_to_optimize.parents[0].name
646+
called_function_names.add(f"{class_name}.{node.func.attr}")
647+
else:
648+
# obj.method() or module.function()
649+
attr_name = node.func.attr
650+
called_function_names.add(attr_name)
651+
called_function_names.add(f"{node.func.value.id}.{attr_name}")
652+
# Check if this is a module.function call that maps to a helper
653+
full_call = f"{node.func.value.id}.{attr_name}"
654+
if full_call in imported_names_map:
655+
called_function_names.update(imported_names_map[full_call])
656+
# Handle nested attribute access like obj.attr.method()
657+
else:
658+
called_function_names.add(node.func.attr)
659+
660+
logger.debug(f"Functions called in optimized entrypoint: {called_function_names}")
661+
logger.debug(f"Imported names mapping: {imported_names_map}")
662+
663+
# Find helper functions that are no longer called
664+
unused_helpers = []
665+
for helper_function in code_context.helper_functions:
666+
if helper_function.jedi_definition.type != "class":
667+
# Check if the helper function is called using multiple name variants
668+
helper_qualified_name = helper_function.qualified_name
669+
helper_simple_name = helper_function.only_function_name
670+
helper_fully_qualified_name = helper_function.fully_qualified_name
671+
672+
# Create a set of all possible names this helper might be called by
673+
possible_call_names = {helper_qualified_name, helper_simple_name, helper_fully_qualified_name}
674+
675+
# For cross-file helpers, also consider module-based calls
676+
if helper_function.file_path != function_to_optimize.file_path:
677+
# Add potential module.function combinations
678+
module_name = helper_function.file_path.stem
679+
possible_call_names.add(f"{module_name}.{helper_simple_name}")
680+
681+
# Check if any of the possible names are in the called functions
682+
is_called = bool(possible_call_names.intersection(called_function_names))
683+
684+
if not is_called:
685+
unused_helpers.append(helper_function)
686+
logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
687+
logger.debug(f" Checked names: {possible_call_names}")
688+
else:
689+
logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
690+
logger.debug(f" Called via: {possible_call_names.intersection(called_function_names)}")
691+
692+
return unused_helpers
693+
694+
except Exception as e:
695+
logger.debug(f"Error detecting unused helper functions: {e}")
696+
return []

codeflash/optimization/function_optimizer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
4444
from codeflash.code_utils.time_utils import humanize_runtime
4545
from codeflash.context import code_context_extractor
46+
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
4647
from codeflash.either import Failure, Success, is_successful
4748
from codeflash.models.ExperimentMetadata import ExperimentMetadata
4849
from codeflash.models.models import (
@@ -298,7 +299,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
298299
self.log_successful_optimization(explanation, generated_tests, exp_type)
299300

300301
self.replace_function_and_helpers_with_optimized_code(
301-
code_context=code_context, optimized_code=best_optimization.candidate.source_code
302+
code_context=code_context,
303+
optimized_code=best_optimization.candidate.source_code,
304+
original_helper_code=original_helper_code,
302305
)
303306

304307
new_code, new_helper_code = self.reformat_code_and_helpers(
@@ -612,7 +615,7 @@ def reformat_code_and_helpers(
612615
return new_code, new_helper_code
613616

614617
def replace_function_and_helpers_with_optimized_code(
615-
self, code_context: CodeOptimizationContext, optimized_code: str
618+
self, code_context: CodeOptimizationContext, optimized_code: str, original_helper_code: str
616619
) -> bool:
617620
did_update = False
618621
read_writable_functions_by_file_path = defaultdict(set)
@@ -630,6 +633,12 @@ def replace_function_and_helpers_with_optimized_code(
630633
preexisting_objects=code_context.preexisting_objects,
631634
project_root_path=self.project_root,
632635
)
636+
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
637+
638+
# Revert unused helper functions to their original definitions
639+
if unused_helpers:
640+
revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code)
641+
633642
return did_update
634643

635644
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:

tests/test_code_replacement.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def sorter(arr):
6969
original_helper_code[helper_function_path] = helper_code
7070
func_optimizer.args = Args()
7171
func_optimizer.replace_function_and_helpers_with_optimized_code(
72-
code_context=code_context, optimized_code=optimized_code
72+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
7373
)
7474
final_output = code_path.read_text(encoding="utf-8")
7575
assert "inconsequential_var = '123'" in final_output
@@ -804,7 +804,8 @@ def __init__(self, name):
804804
self.name = name
805805
806806
def main_method(self):
807-
return HelperClass(self.name).helper_method()"""
807+
return HelperClass(self.name).helper_method()
808+
"""
808809
file_path = Path(__file__).resolve()
809810
func_top_optimize = FunctionToOptimize(
810811
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
@@ -1662,6 +1663,7 @@ def new_function2(value):
16621663
)
16631664
assert new_code == original_code
16641665

1666+
16651667
def test_global_reassignment() -> None:
16661668
original_code = """a=1
16671669
print("Hello world")
@@ -1733,7 +1735,7 @@ def new_function2(value):
17331735
original_helper_code[helper_function_path] = helper_code
17341736
func_optimizer.args = Args()
17351737
func_optimizer.replace_function_and_helpers_with_optimized_code(
1736-
code_context=code_context, optimized_code=optimized_code
1738+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
17371739
)
17381740
new_code = code_path.read_text(encoding="utf-8")
17391741
code_path.unlink(missing_ok=True)
@@ -1809,7 +1811,7 @@ def new_function2(value):
18091811
original_helper_code[helper_function_path] = helper_code
18101812
func_optimizer.args = Args()
18111813
func_optimizer.replace_function_and_helpers_with_optimized_code(
1812-
code_context=code_context, optimized_code=optimized_code
1814+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
18131815
)
18141816
new_code = code_path.read_text(encoding="utf-8")
18151817
code_path.unlink(missing_ok=True)
@@ -1886,7 +1888,7 @@ def new_function2(value):
18861888
original_helper_code[helper_function_path] = helper_code
18871889
func_optimizer.args = Args()
18881890
func_optimizer.replace_function_and_helpers_with_optimized_code(
1889-
code_context=code_context, optimized_code=optimized_code
1891+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
18901892
)
18911893
new_code = code_path.read_text(encoding="utf-8")
18921894
code_path.unlink(missing_ok=True)
@@ -1962,7 +1964,7 @@ def new_function2(value):
19621964
original_helper_code[helper_function_path] = helper_code
19631965
func_optimizer.args = Args()
19641966
func_optimizer.replace_function_and_helpers_with_optimized_code(
1965-
code_context=code_context, optimized_code=optimized_code
1967+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
19661968
)
19671969
new_code = code_path.read_text(encoding="utf-8")
19681970
code_path.unlink(missing_ok=True)
@@ -2039,7 +2041,7 @@ def new_function2(value):
20392041
original_helper_code[helper_function_path] = helper_code
20402042
func_optimizer.args = Args()
20412043
func_optimizer.replace_function_and_helpers_with_optimized_code(
2042-
code_context=code_context, optimized_code=optimized_code
2044+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
20432045
)
20442046
new_code = code_path.read_text(encoding="utf-8")
20452047
code_path.unlink(missing_ok=True)
@@ -2127,8 +2129,8 @@ def new_function2(value):
21272129
original_helper_code[helper_function_path] = helper_code
21282130
func_optimizer.args = Args()
21292131
func_optimizer.replace_function_and_helpers_with_optimized_code(
2130-
code_context=code_context, optimized_code=optimized_code
2132+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
21312133
)
21322134
new_code = code_path.read_text(encoding="utf-8")
21332135
code_path.unlink(missing_ok=True)
2134-
assert new_code.rstrip() == expected_code.rstrip()
2136+
assert new_code.rstrip() == expected_code.rstrip()

0 commit comments

Comments
 (0)