Skip to content

Commit b4ab00b

Browse files
Merge pull request #296 from codeflash-ai/revert-helper-function-is-unused
Revert helper functions definitions when they are not used anymore in the optimized FTO
2 parents 9be93b6 + 5e7733c commit b4ab00b

File tree

4 files changed

+1682
-12
lines changed

4 files changed

+1682
-12
lines changed

codeflash/context/unused_definition_remover.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
from __future__ import annotations
22

3+
import ast
4+
from collections import defaultdict
35
from dataclasses import dataclass, field
6+
from pathlib import Path
47
from typing import Optional
58

69
import libcst as cst
710

11+
from codeflash.cli_cmds.console import logger
12+
from codeflash.code_utils.code_replacer import replace_function_definitions_in_module
13+
from codeflash.models.models import CodeOptimizationContext, FunctionSource
14+
815

916
@dataclass
1017
class UsageInfo:
@@ -483,3 +490,220 @@ def print_definitions(definitions: dict[str, UsageInfo]) -> None:
483490
print(f" Used by qualified function: {info.used_by_qualified_function}")
484491
print(f" Dependencies: {', '.join(sorted(info.dependencies)) if info.dependencies else 'None'}")
485492
print()
493+
494+
495+
def revert_unused_helper_functions(
    project_root, unused_helpers: list[FunctionSource], original_helper_code: dict[Path, str]
) -> None:
    """Revert unused helper functions back to their original definitions.

    Helpers are grouped by the file they live in; for every file that has an entry in
    ``original_helper_code`` the code replacer is asked to swap the listed helpers back to the
    definitions found in the original source. Files without an entry are left untouched.
    Failures are logged per file and never propagated to the caller.

    Args:
        project_root: Forwarded unchanged to ``replace_function_definitions_in_module`` as
            ``project_root_path``.
        unused_helpers: List of unused helper functions to revert
        original_helper_code: Dictionary mapping file paths to their original code

    """
    if not unused_helpers:
        return

    logger.info(f"Reverting {len(unused_helpers)} unused helper function(s) to original definitions")

    # Group unused helpers by file path so each file is rewritten at most once
    unused_helpers_by_file = defaultdict(list)
    for helper in unused_helpers:
        unused_helpers_by_file[helper.file_path].append(helper)

    for file_path, helpers_in_file in unused_helpers_by_file.items():
        if file_path not in original_helper_code:
            # Nothing to revert to for this file
            continue

        helper_names = [helper.qualified_name for helper in helpers_in_file]
        try:
            # The original code plays the role of the "optimized" code here, so that only the
            # listed helper functions are replaced with their original definitions.
            reverted_code = replace_function_definitions_in_module(
                function_names=helper_names,
                optimized_code=original_helper_code[file_path],
                module_abspath=file_path,
                preexisting_objects=set(),  # Empty set since we're reverting
                project_root_path=project_root,
            )
        except Exception as e:
            # Best effort: one failing file must not prevent the remaining files from being reverted
            logger.error(f"Error reverting unused helpers in {file_path}: {e}")
            continue

        if reverted_code:
            logger.debug(f"Reverted unused helpers in {file_path}: {', '.join(helper_names)}")
541+
542+
def _analyze_imports_in_optimized_code(
543+
optimized_ast: ast.AST, code_context: CodeOptimizationContext
544+
) -> dict[str, set[str]]:
545+
"""Analyze import statements in optimized code to map imported names to qualified helper names.
546+
547+
Args:
548+
optimized_ast: The AST of the optimized code
549+
code_context: The code optimization context containing helper functions
550+
551+
Returns:
552+
Dictionary mapping imported names to sets of possible qualified helper names
553+
554+
"""
555+
imported_names_map = defaultdict(set)
556+
557+
# Precompute a two-level dict: module_name -> func_name -> [helpers]
558+
helpers_by_file_and_func = defaultdict(dict)
559+
helpers_by_file = defaultdict(list) # preserved for "import module"
560+
helpers_append = helpers_by_file_and_func.setdefault
561+
for helper in code_context.helper_functions:
562+
jedi_type = helper.jedi_definition.type
563+
if jedi_type != "class":
564+
func_name = helper.only_function_name
565+
module_name = helper.file_path.stem
566+
# Cache function lookup for this (module, func)
567+
file_entry = helpers_by_file_and_func[module_name]
568+
if func_name in file_entry:
569+
file_entry[func_name].append(helper)
570+
else:
571+
file_entry[func_name] = [helper]
572+
helpers_by_file[module_name].append(helper)
573+
574+
# Optimize attribute lookups and method binding outside the loop
575+
helpers_by_file_and_func_get = helpers_by_file_and_func.get
576+
helpers_by_file_get = helpers_by_file.get
577+
578+
for node in ast.walk(optimized_ast):
579+
if isinstance(node, ast.ImportFrom):
580+
# Handle "from module import function" statements
581+
module_name = node.module
582+
if module_name:
583+
file_entry = helpers_by_file_and_func_get(module_name, None)
584+
if file_entry:
585+
for alias in node.names:
586+
imported_name = alias.asname if alias.asname else alias.name
587+
original_name = alias.name
588+
helpers = file_entry.get(original_name, None)
589+
if helpers:
590+
for helper in helpers:
591+
imported_names_map[imported_name].add(helper.qualified_name)
592+
imported_names_map[imported_name].add(helper.fully_qualified_name)
593+
594+
elif isinstance(node, ast.Import):
595+
# Handle "import module" statements
596+
for alias in node.names:
597+
imported_name = alias.asname if alias.asname else alias.name
598+
module_name = alias.name
599+
for helper in helpers_by_file_get(module_name, []):
600+
# For "import module" statements, functions would be called as module.function
601+
full_call = f"{imported_name}.{helper.only_function_name}"
602+
imported_names_map[full_call].add(helper.qualified_name)
603+
imported_names_map[full_call].add(helper.fully_qualified_name)
604+
605+
return dict(imported_names_map)
606+
607+
608+
def detect_unused_helper_functions(
    function_to_optimize, code_context: CodeOptimizationContext, optimized_code: str
) -> list[FunctionSource]:
    """Detect helper functions that are no longer called by the optimized entrypoint function.

    The optimized code is parsed, the entrypoint (sync or async) is located by name, and every
    call expression inside it is collected under several spellings (bare name, ``obj.attr``,
    ``ClassName.method`` for ``self`` calls, and names resolved through import aliases). A
    non-class helper is reported as unused when none of its known names appears in that set.

    Args:
        function_to_optimize: The entrypoint being optimized; ``function_name`` and ``file_path``
            are read, and ``parents`` (if present) is used to qualify ``self.method()`` calls.
        code_context: The code optimization context containing helper functions
        optimized_code: The optimized code to analyze

    Returns:
        List of FunctionSource objects representing unused helper functions. An empty list is
        returned when the entrypoint cannot be found or when any error occurs during analysis.

    """
    try:
        # Parse the optimized code to analyze function calls and imports
        optimized_ast = ast.parse(optimized_code)

        # Find the optimized entrypoint function; async entrypoints are a distinct AST node type
        entrypoint_function_ast = next(
            (
                node
                for node in ast.walk(optimized_ast)
                if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
                and node.name == function_to_optimize.function_name
            ),
            None,
        )
        if entrypoint_function_ast is None:
            logger.debug(f"Could not find entrypoint function {function_to_optimize.function_name} in optimized code")
            return []

        # First, analyze imports to build a mapping of imported names to their original qualified names
        imported_names_map = _analyze_imports_in_optimized_code(optimized_ast, code_context)

        # Extract all function calls in the entrypoint function
        called_function_names = set()
        for node in ast.walk(entrypoint_function_ast):
            if not isinstance(node, ast.Call):
                continue
            callee = node.func

            if isinstance(callee, ast.Name):
                # Regular function call: function_name(), possibly through an import alias
                called_function_names.add(callee.id)
                called_function_names.update(imported_names_map.get(callee.id, ()))
                continue

            if not isinstance(callee, ast.Attribute):
                continue

            # Every attribute call contributes its bare attribute name, including nested
            # access such as obj.attr.method()
            called_function_names.add(callee.attr)
            if not isinstance(callee.value, ast.Name):
                continue

            receiver = callee.value.id
            if receiver == "self":
                # self.method_name() -> also record ClassName.method_name
                parents = getattr(function_to_optimize, "parents", None)
                if parents:
                    called_function_names.add(f"{parents[0].name}.{callee.attr}")
            else:
                # obj.method() or module.function(); the latter may map to a helper via imports
                full_call = f"{receiver}.{callee.attr}"
                called_function_names.add(full_call)
                called_function_names.update(imported_names_map.get(full_call, ()))

        logger.debug(f"Functions called in optimized entrypoint: {called_function_names}")
        logger.debug(f"Imported names mapping: {imported_names_map}")

        # Find helper functions that are no longer called
        unused_helpers = []
        for helper_function in code_context.helper_functions:
            if helper_function.jedi_definition.type == "class":
                continue

            helper_qualified_name = helper_function.qualified_name
            helper_simple_name = helper_function.only_function_name

            # All names this helper might be called by
            possible_call_names = {
                helper_qualified_name,
                helper_simple_name,
                helper_function.fully_qualified_name,
            }

            # For cross-file helpers, also consider module-based calls (module.function)
            if helper_function.file_path != function_to_optimize.file_path:
                possible_call_names.add(f"{helper_function.file_path.stem}.{helper_simple_name}")

            matched_names = possible_call_names.intersection(called_function_names)
            if matched_names:
                logger.debug(f"Helper function {helper_qualified_name} is still called in optimized code")
                logger.debug(f"  Called via: {matched_names}")
            else:
                unused_helpers.append(helper_function)
                logger.debug(f"Helper function {helper_qualified_name} is not called in optimized code")
                logger.debug(f"  Checked names: {possible_call_names}")

        return unused_helpers

    except Exception as e:
        # Detection is best effort: on any failure report nothing as unused so nothing is reverted
        logger.debug(f"Error detecting unused helper functions: {e}")
        return []

codeflash/optimization/function_optimizer.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
4747
from codeflash.code_utils.time_utils import humanize_runtime
4848
from codeflash.context import code_context_extractor
49+
from codeflash.context.unused_definition_remover import detect_unused_helper_functions, revert_unused_helper_functions
4950
from codeflash.either import Failure, Success, is_successful
5051
from codeflash.models.ExperimentMetadata import ExperimentMetadata
5152
from codeflash.models.models import (
@@ -295,7 +296,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
295296
)
296297

297298
self.replace_function_and_helpers_with_optimized_code(
298-
code_context=code_context, optimized_code=best_optimization.candidate.source_code
299+
code_context=code_context,
300+
optimized_code=best_optimization.candidate.source_code,
301+
original_helper_code=original_helper_code,
299302
)
300303

301304
new_code, new_helper_code = self.reformat_code_and_helpers(
@@ -418,7 +421,9 @@ def determine_best_candidate(
418421
code_print(candidate.source_code)
419422
try:
420423
did_update = self.replace_function_and_helpers_with_optimized_code(
421-
code_context=code_context, optimized_code=candidate.source_code
424+
code_context=code_context,
425+
optimized_code=candidate.source_code,
426+
original_helper_code=original_helper_code,
422427
)
423428
if not did_update:
424429
logger.warning(
@@ -619,7 +624,7 @@ def reformat_code_and_helpers(
619624
return new_code, new_helper_code
620625

621626
def replace_function_and_helpers_with_optimized_code(
622-
self, code_context: CodeOptimizationContext, optimized_code: str
627+
self, code_context: CodeOptimizationContext, optimized_code: str, original_helper_code: str
623628
) -> bool:
624629
did_update = False
625630
read_writable_functions_by_file_path = defaultdict(set)
@@ -637,6 +642,12 @@ def replace_function_and_helpers_with_optimized_code(
637642
preexisting_objects=code_context.preexisting_objects,
638643
project_root_path=self.project_root,
639644
)
645+
unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code)
646+
647+
# Revert unused helper functions to their original definitions
648+
if unused_helpers:
649+
revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code)
650+
640651
return did_update
641652

642653
def get_code_optimization_context(self) -> Result[CodeOptimizationContext, str]:

tests/test_code_replacement.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def sorter(arr):
6969
original_helper_code[helper_function_path] = helper_code
7070
func_optimizer.args = Args()
7171
func_optimizer.replace_function_and_helpers_with_optimized_code(
72-
code_context=code_context, optimized_code=optimized_code
72+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
7373
)
7474
final_output = code_path.read_text(encoding="utf-8")
7575
assert "inconsequential_var = '123'" in final_output
@@ -804,7 +804,8 @@ def __init__(self, name):
804804
self.name = name
805805
806806
def main_method(self):
807-
return HelperClass(self.name).helper_method()"""
807+
return HelperClass(self.name).helper_method()
808+
"""
808809
file_path = Path(__file__).resolve()
809810
func_top_optimize = FunctionToOptimize(
810811
function_name="main_method", file_path=file_path, parents=[FunctionParent("MainClass", "ClassDef")]
@@ -1662,6 +1663,7 @@ def new_function2(value):
16621663
)
16631664
assert new_code == original_code
16641665

1666+
16651667
def test_global_reassignment() -> None:
16661668
original_code = """a=1
16671669
print("Hello world")
@@ -1733,7 +1735,7 @@ def new_function2(value):
17331735
original_helper_code[helper_function_path] = helper_code
17341736
func_optimizer.args = Args()
17351737
func_optimizer.replace_function_and_helpers_with_optimized_code(
1736-
code_context=code_context, optimized_code=optimized_code
1738+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
17371739
)
17381740
new_code = code_path.read_text(encoding="utf-8")
17391741
code_path.unlink(missing_ok=True)
@@ -1809,7 +1811,7 @@ def new_function2(value):
18091811
original_helper_code[helper_function_path] = helper_code
18101812
func_optimizer.args = Args()
18111813
func_optimizer.replace_function_and_helpers_with_optimized_code(
1812-
code_context=code_context, optimized_code=optimized_code
1814+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
18131815
)
18141816
new_code = code_path.read_text(encoding="utf-8")
18151817
code_path.unlink(missing_ok=True)
@@ -1886,7 +1888,7 @@ def new_function2(value):
18861888
original_helper_code[helper_function_path] = helper_code
18871889
func_optimizer.args = Args()
18881890
func_optimizer.replace_function_and_helpers_with_optimized_code(
1889-
code_context=code_context, optimized_code=optimized_code
1891+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
18901892
)
18911893
new_code = code_path.read_text(encoding="utf-8")
18921894
code_path.unlink(missing_ok=True)
@@ -1962,7 +1964,7 @@ def new_function2(value):
19621964
original_helper_code[helper_function_path] = helper_code
19631965
func_optimizer.args = Args()
19641966
func_optimizer.replace_function_and_helpers_with_optimized_code(
1965-
code_context=code_context, optimized_code=optimized_code
1967+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
19661968
)
19671969
new_code = code_path.read_text(encoding="utf-8")
19681970
code_path.unlink(missing_ok=True)
@@ -2039,7 +2041,7 @@ def new_function2(value):
20392041
original_helper_code[helper_function_path] = helper_code
20402042
func_optimizer.args = Args()
20412043
func_optimizer.replace_function_and_helpers_with_optimized_code(
2042-
code_context=code_context, optimized_code=optimized_code
2044+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
20432045
)
20442046
new_code = code_path.read_text(encoding="utf-8")
20452047
code_path.unlink(missing_ok=True)
@@ -2127,8 +2129,8 @@ def new_function2(value):
21272129
original_helper_code[helper_function_path] = helper_code
21282130
func_optimizer.args = Args()
21292131
func_optimizer.replace_function_and_helpers_with_optimized_code(
2130-
code_context=code_context, optimized_code=optimized_code
2132+
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code
21312133
)
21322134
new_code = code_path.read_text(encoding="utf-8")
21332135
code_path.unlink(missing_ok=True)
2134-
assert new_code.rstrip() == expected_code.rstrip()
2136+
assert new_code.rstrip() == expected_code.rstrip()

0 commit comments

Comments
 (0)