def find_target_node(
    root: ast.AST, function_to_optimize: FunctionToOptimize
) -> Optional[ast.FunctionDef | ast.AsyncFunctionDef]:
    """Locate the definition of ``function_to_optimize`` inside ``root``.

    The search follows ``function_to_optimize.parents`` from the outermost
    scope inwards, so a method is only matched inside its own class even when
    other classes in the same module define a method with the same name.

    Only *direct* children of each scope are inspected (no recursive walk).

    Args:
        root: Parsed tree to search, typically the ``ast.Module`` of the
            optimized code.
        function_to_optimize: Describes the target; ``parents`` (each with a
            ``name`` and a ``type``) and ``function_name`` are read.

    Returns:
        The first matching ``ast.FunctionDef`` / ``ast.AsyncFunctionDef`` in
        the innermost parent scope, or ``None`` when any parent scope or the
        function itself cannot be found.
    """
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    # Map the parent's recorded node type to the AST classes that may match it.
    # Unknown or missing types fall back to ClassDef, the historical behavior.
    scope_types = {
        "ClassDef": (ast.ClassDef,),
        "FunctionDef": function_types,
        "AsyncFunctionDef": function_types,
    }

    def direct_children(scope: ast.AST) -> list:
        # Some nodes (e.g. ast.Expression, ast.Lambda) expose a single
        # expression as ``body``; treat anything that is not a list as empty.
        body = getattr(scope, "body", None)
        return body if isinstance(body, list) else []

    node = root
    for parent in function_to_optimize.parents:
        wanted = scope_types.get(getattr(parent, "type", None), (ast.ClassDef,))
        for child in direct_children(node):
            if isinstance(child, wanted) and child.name == parent.name:
                node = child
                break
        else:
            # A parent scope is missing, so the target cannot exist here.
            return None

    # ``node`` is now the root (no parents) or the innermost parent scope.
    target_name = function_to_optimize.function_name
    for child in direct_children(node):
        if isinstance(child, function_types) and child.name == target_name:
            return child
    return None
def test_unused_helper_detection_with_duplicated_function_name_in_different_classes():
    """Test detection when two classes define a method with the same name.

    Both ``LspMessage`` and ``LspMarkdownMessage`` define ``serialize``; the
    function under optimization is ``LspMarkdownMessage.serialize``, which
    calls both helpers. No helper should be reported as unused.
    """
    temp_dir = Path(tempfile.mkdtemp())

    try:
        # Main file: the base class and the subclass both define serialize();
        # only the subclass version calls the two imported helpers.
        # NOTE(review): the fixture literals in this test are not raw strings,
        # so `\1` is stored as the control character U+0001 rather than a
        # regex backreference; presumably harmless because the test does not
        # appear to execute the helpers — confirm.
        main_file = temp_dir / "main.py"
        main_file.write_text("""from __future__ import annotations
import json
from helpers import replace_quotes_with_backticks, simplify_worktree_paths
from dataclasses import asdict, dataclass

@dataclass
class LspMessage:

    def serialize(self) -> str:
        data = self._loop_through(asdict(self))
        # Important: keep type as the first key, for making it easy and fast for the client to know if this is a lsp message before parsing it
        ordered = {"type": self.type(), **data}
        return (
            message_delimiter
            + json.dumps(ordered)
            + message_delimiter
        )


@dataclass
class LspMarkdownMessage(LspMessage):

    def serialize(self) -> str:
        self.markdown = simplify_worktree_paths(self.markdown)
        self.markdown = replace_quotes_with_backticks(self.markdown)
        return super().serialize()
""")

        # Helpers file
        helpers_file = temp_dir / "helpers.py"
        helpers_file.write_text("""def simplify_worktree_paths(msg: str, highlight: bool = True) -> str:  # noqa: FBT001, FBT002
    path_in_msg = worktree_path_regex.search(msg)
    if path_in_msg:
        last_part_of_path = path_in_msg.group(0).split("/")[-1]
        if highlight:
            last_part_of_path = f"`{last_part_of_path}`"
        return msg.replace(path_in_msg.group(0), last_part_of_path)
    return msg


def replace_quotes_with_backticks(text: str) -> str:
    # double-quoted strings
    text = _double_quote_pat.sub(r"`\1`", text)
    # single-quoted strings
    return _single_quote_pat.sub(r"`\1`", text)
""")

        # Optimized version: both classes still define serialize(), and
        # LspMarkdownMessage.serialize() still calls both helpers.
        optimized_code = """
```python:main.py
from __future__ import annotations

import json
from dataclasses import asdict, dataclass

from codeflash.lsp.helpers import (replace_quotes_with_backticks,
                                   simplify_worktree_paths)


@dataclass
class LspMessage:

    def serialize(self) -> str:
        # Use local variable to minimize lookup costs and avoid unnecessary dictionary unpacking
        data = self._loop_through(asdict(self))
        msg_type = self.type()
        ordered = {'type': msg_type}
        ordered.update(data)
        return (
            message_delimiter
            + json.dumps(ordered)
            + message_delimiter  # \u241F is the message delimiter becuase it can be more than one message sent over the same message, so we need something to separate each message
        )

@dataclass
class LspMarkdownMessage(LspMessage):

    def serialize(self) -> str:
        # Side effect required, must preserve for behavioral correctness
        self.markdown = simplify_worktree_paths(self.markdown)
        self.markdown = replace_quotes_with_backticks(self.markdown)
        return super().serialize()
```
```python:helpers.py
def simplify_worktree_paths(msg: str, highlight: bool = True) -> str:  # noqa: FBT001, FBT002
    m = worktree_path_regex.search(msg)
    if m:
        # More efficient way to get last path part
        last_part_of_path = m.group(0).rpartition('/')[-1]
        if highlight:
            last_part_of_path = f"`{last_part_of_path}`"
        return msg.replace(m.group(0), last_part_of_path)
    return msg

def replace_quotes_with_backticks(text: str) -> str:
    # Efficient string substitution, reduces intermediate string allocations
    return _single_quote_pat.sub(
        r"`\1`",
        _double_quote_pat.sub(r"`\1`", text),
    )
```
"""

        # Create test config
        test_cfg = TestConfig(
            tests_root=temp_dir / "tests",
            tests_project_rootdir=temp_dir,
            project_root_path=temp_dir,
            test_framework="pytest",
            pytest_cmd="pytest",
        )

        # Create FunctionToOptimize instance; the parent pins the target to
        # LspMarkdownMessage.serialize rather than LspMessage.serialize.
        function_to_optimize = FunctionToOptimize(
            file_path=main_file, function_name="serialize", qualified_name="serialize", parents=[
                FunctionParent(name="LspMarkdownMessage", type="ClassDef"),
            ]
        )

        optimizer = FunctionOptimizer(
            function_to_optimize=function_to_optimize,
            test_cfg=test_cfg,
            function_to_optimize_source_code=main_file.read_text(),
        )

        ctx_result = optimizer.get_code_optimization_context()
        assert ctx_result.is_successful(), f"Failed to get context: {ctx_result.failure()}"

        code_context = ctx_result.unwrap()

        unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code))

        # Both helpers are still called by the optimized target method.
        unused_names = {uh.qualified_name for uh in unused_helpers}
        assert len(unused_names) == 0  # no unused helpers

    finally:
        # Cleanup: best-effort removal of the temporary project directory
        import shutil

        shutil.rmtree(temp_dir, ignore_errors=True)