diff --git a/code_to_optimize/code_directories/circular_deps/constants.py b/code_to_optimize/code_directories/circular_deps/constants.py index dc4b0638e..be8fdac15 100644 --- a/code_to_optimize/code_directories/circular_deps/constants.py +++ b/code_to_optimize/code_directories/circular_deps/constants.py @@ -1,8 +1,2 @@ DEFAULT_API_URL = "https://api.galileo.ai/" DEFAULT_APP_URL = "https://app.galileo.ai/" - - -# function_names: GalileoApiClient.get_console_url -# module_abs_path : /home/mohammed/Work/galileo-python/src/galileo/api_client.py -# preexisting_objects: {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))} -# project_root_path: /home/mohammed/Work/galileo-python/src diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 4ec4f3aec..d921469d1 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -13,7 +13,7 @@ from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name from codeflash.models.ExperimentMetadata import ExperimentMetadata -from codeflash.models.models import AIServiceRefinerRequest, OptimizedCandidate +from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate from codeflash.telemetry.posthog_cf import ph from codeflash.version import __version__ as codeflash_version @@ -136,7 +136,7 @@ def optimize_python_code( # noqa: D417 logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.") return [ OptimizedCandidate( - source_code=opt["source_code"], + source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]), explanation=opt["explanation"], optimization_id=opt["optimization_id"], ) @@ -206,7 +206,7 @@ def optimize_python_code_line_profiler( # noqa: D417 console.rule() return [ OptimizedCandidate( - source_code=opt["source_code"], + source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]), explanation=opt["explanation"], optimization_id=opt["optimization_id"], ) @@ -263,7 +263,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] console.rule() return [ OptimizedCandidate( - source_code=opt["source_code"], + source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]), explanation=opt["explanation"], optimization_id=opt["optimization_id"][:-4] + "refi", ) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 3c73c5919..740e578b6 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -19,7 +19,7 @@ from pathlib import Path from codeflash.discovery.functions_to_optimize import FunctionToOptimize - from codeflash.models.models import CodeOptimizationContext, OptimizedCandidate, ValidCode + from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST) @@ -408,16 +408,17 @@ def replace_functions_and_add_imports( def replace_function_definitions_in_module( function_names: list[str], - optimized_code: str, + optimized_code: CodeStringsMarkdown, module_abspath: Path, preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]], project_root_path: Path, ) -> bool: source_code: str = module_abspath.read_text(encoding="utf8") + code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code) new_code: str = replace_functions_and_add_imports( - add_global_assignments(optimized_code, source_code), + add_global_assignments(code_to_apply, source_code), function_names, - optimized_code, + code_to_apply, module_abspath, preexisting_objects, project_root_path, @@ -428,6 +429,19 @@ def replace_function_definitions_in_module( return True +def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str: + file_to_code_context = optimized_code.file_to_path() + module_optimized_code = file_to_code_context.get(str(relative_path)) + if module_optimized_code is None: + logger.warning( + f"Optimized code not found for {relative_path} In the context\n-------\n{optimized_code}\n-------\n" + "re-check your 'markdown code structure'" + f"existing files are {file_to_code_context.keys()}" + ) + module_optimized_code = "" + return module_optimized_code + + def is_zero_diff(original_code: str, new_code: str) -> bool: return normalize_code(original_code) == normalize_code(new_code) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index cd8c2efb3..db7fa4257 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -104,7 +104,7 @@ def is_diff_line(line: str) -> bool: def format_code( formatter_cmds: list[str], path: Union[str, Path], - optimized_function: str = "", + optimized_code: str = "", check_diff: bool = False, # noqa print_status: bool = True, # noqa exit_on_failure: bool = True, # noqa @@ -121,7 +121,7 @@ def format_code( if check_diff and original_code_lines > 50: # we dont' count the formatting diff for the optimized function as it should be well-formatted - original_code_without_opfunc = original_code.replace(optimized_function, "") + original_code_without_opfunc = original_code.replace(optimized_code, "") original_temp = Path(test_dir_str) / "original_temp.py" original_temp.write_text(original_code_without_opfunc, encoding="utf8") diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index b520f12b7..09c0c564a 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -61,13 +61,14 @@ def get_code_optimization_context( ) # Extract code context for optimization - final_read_writable_code = extract_code_string_context_from_files( + final_read_writable_code = extract_code_markdown_context_from_files( helpers_of_fto_dict, {}, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_WRITABLE, - ).code + ) + read_only_code_markdown = extract_code_markdown_context_from_files( helpers_of_fto_dict, helpers_of_helpers_dict, @@ -84,14 +85,14 @@ def get_code_optimization_context( ) # Handle token limits - final_read_writable_tokens = encoded_tokens_len(final_read_writable_code) + final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.markdown) if final_read_writable_tokens > optim_token_limit: raise ValueError("Read-writable code has exceeded token limit, cannot proceed") # Setup preexisting objects for code replacer preexisting_objects = set( chain( - find_preexisting_objects(final_read_writable_code), + *(find_preexisting_objects(codestring.code) for codestring in final_read_writable_code.code_strings), *(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings), ) ) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 977c72bd3..cf57af031 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -3,16 +3,15 @@ import ast from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pathlib import Path +from itertools import chain +from pathlib import Path from typing import TYPE_CHECKING, Optional import libcst as cst from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_replacer import replace_function_definitions_in_module +from codeflash.models.models import CodeString, CodeStringsMarkdown if TYPE_CHECKING: from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -530,7 +529,11 @@ def revert_unused_helper_functions( helper_names = [helper.qualified_name for helper in helpers_in_file] reverted_code = replace_function_definitions_in_module( function_names=helper_names, - optimized_code=original_code, # Use original code as the "optimized" code to revert + optimized_code=CodeStringsMarkdown( + code_strings=[ + CodeString(code=original_code, file_path=Path(file_path).relative_to(project_root)) + ] + ), # Use original code as the "optimized" code to revert module_abspath=file_path, preexisting_objects=set(), # Empty set since we're reverting project_root_path=project_root, @@ -609,7 +612,9 @@ def _analyze_imports_in_optimized_code( def detect_unused_helper_functions( - function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext, optimized_code: str + function_to_optimize: FunctionToOptimize, + code_context: CodeOptimizationContext, + optimized_code: str | CodeStringsMarkdown, ) -> list[FunctionSource]: """Detect helper functions that are no longer called by the optimized entrypoint function. @@ -622,6 +627,14 @@ def detect_unused_helper_functions( List of FunctionSource objects representing unused helper functions """ + if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0: + return list( + chain.from_iterable( + detect_unused_helper_functions(function_to_optimize, code_context, code.code) + for code in optimized_code.code_strings + ) + ) + try: # Parse the optimized code to analyze function calls and imports optimized_ast = ast.parse(optimized_code) diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 2e5492a0c..4ed0b7a62 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -222,7 +222,7 @@ def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimization generated_test.generated_original_test_source for generated_test in generated_tests_list.generated_tests ] optimizations_dict = { - candidate.optimization_id: {"source_code": candidate.source_code, "explanation": candidate.explanation} + candidate.optimization_id: {"source_code": candidate.source_code.markdown, "explanation": candidate.explanation} for candidate in optimizations_set.control + optimizations_set.experiment } @@ -330,7 +330,7 @@ def perform_function_optimization( # noqa: PLR0911 "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", } - optimized_source = best_optimization.candidate.source_code + optimized_source = best_optimization.candidate.source_code.markdown speedup = original_code_baseline.runtime / best_optimization.runtime server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 91882c66a..1636d6889 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -19,7 +19,7 @@ from typing import Annotated, Optional, cast from jedi.api.classes import Name -from pydantic import AfterValidator, BaseModel, ConfigDict, Field +from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr from pydantic.dataclasses import dataclass from codeflash.cli_cmds.console import console, logger @@ -157,12 +157,51 @@ class CodeString(BaseModel): file_path: Optional[Path] = None +def get_code_block_splitter(file_path: Path) -> str: + return f"# file: {file_path}" + + +markdown_pattern = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL) + + class CodeStringsMarkdown(BaseModel): code_strings: list[CodeString] = [] + _cache: dict = PrivateAttr(default_factory=dict) + + @property + def flat(self) -> str: + """Returns the combined Python module from all code blocks. + + Each block is prefixed by a file path comment to indicate its origin. + This representation is syntactically valid Python code. + + Returns: + str: The concatenated code of all blocks with file path annotations. + + !! Important !!: + Avoid parsing the flat code with multiple files, + parsing may result in unexpected behavior. + + + """ + if self._cache.get("flat") is not None: + return self._cache["flat"] + self._cache["flat"] = "\n".join( + get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings + ) + return self._cache["flat"] @property def markdown(self) -> str: - """Returns the markdown representation of the code, including the file path where possible.""" + """Returns a Markdown-formatted string containing all code blocks. + + Each block is enclosed in a triple-backtick code block with an optional + file path suffix (e.g., ```python:filename.py). + + Returns: + str: Markdown representation of the code blocks. + + """ return "\n".join( [ f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```" @@ -170,10 +209,44 @@ def markdown(self) -> str: ] ) + def file_to_path(self) -> dict[str, str]: + """Return a dictionary mapping file paths to their corresponding code blocks. + + Returns: + dict[str, str]: Mapping from file path (as string) to code. + + """ + if self._cache.get("file_to_path") is not None: + return self._cache["file_to_path"] + self._cache["file_to_path"] = { + str(code_string.file_path): code_string.code for code_string in self.code_strings + } + return self._cache["file_to_path"] + + @staticmethod + def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown: + """Parse a Markdown string into a CodeStringsMarkdown object. + + Extracts code blocks and their associated file paths and constructs a new CodeStringsMarkdown instance. + + Args: + markdown_code (str): The Markdown-formatted string to parse. + + Returns: + CodeStringsMarkdown: Parsed object containing code blocks. + + """ + matches = markdown_pattern.findall(markdown_code) + results = CodeStringsMarkdown() + for file_path, code in matches: + path = file_path.strip() + results.code_strings.append(CodeString(code=code, file_path=Path(path))) + return results + class CodeOptimizationContext(BaseModel): testgen_context_code: str = "" - read_writable_code: str = Field(min_length=1) + read_writable_code: CodeStringsMarkdown read_only_context_code: str = "" hashing_code_context: str = "" hashing_code_context_hash: str = "" @@ -272,7 +345,7 @@ class TestsInFile: @dataclass(frozen=True) class OptimizedCandidate: - source_code: str + source_code: CodeStringsMarkdown explanation: str optimization_id: str diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 9933751b7..6905bf47c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -95,6 +95,7 @@ from codeflash.either import Result from codeflash.models.models import ( BenchmarkKey, + CodeStringsMarkdown, CoverageData, FunctionCalledInTest, FunctionSource, @@ -169,7 +170,10 @@ def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[P helper_code = f.read() original_helper_code[helper_function_path] = helper_code - if has_any_async_functions(code_context.read_writable_code): + async_code = any( + has_any_async_functions(code_string.code) for code_string in code_context.read_writable_code.code_strings + ) + if async_code: return Failure("Codeflash does not support async functions in the code to optimize.") # Random here means that we still attempt optimization with a fractional chance to see if # last time we could not find an optimization, maybe this time we do. @@ -288,7 +292,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() - code_print(code_context.read_writable_code) + code_print(code_context.read_writable_code.flat) test_setup_result = self.generate_and_instrument_tests( # also generates optimizations code_context, should_run_experiment=should_run_experiment @@ -382,7 +386,7 @@ def determine_best_candidate( ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client future_line_profile_results = executor.submit( ai_service_client.optimize_python_code_line_profiler, - source_code=code_context.read_writable_code, + source_code=code_context.read_writable_code.markdown, dependency_code=code_context.read_only_context_code, trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, line_profiler_results=original_code_baseline.line_profile_results["str_out"], @@ -413,7 +417,7 @@ def determine_best_candidate( get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) logger.info(f"Optimization candidate {candidate_index}/{original_len}:") - code_print(candidate.source_code) + code_print(candidate.source_code.flat) try: did_update = self.replace_function_and_helpers_with_optimized_code( code_context=code_context, @@ -577,7 +581,7 @@ def determine_best_candidate( runtimes_list = [] for valid_opt in self.valid_optimizations: diff_lens_list.append( - diff_length(valid_opt.candidate.source_code, code_context.read_writable_code) + diff_length(valid_opt.candidate.source_code.flat, code_context.read_writable_code.flat) ) # char level diff runtimes_list.append(valid_opt.runtime) diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list) @@ -609,10 +613,10 @@ def refine_optimizations( request = [ AIServiceRefinerRequest( optimization_id=opt.candidate.optimization_id, - original_source_code=code_context.read_writable_code, + original_source_code=code_context.read_writable_code.markdown, read_only_dependency_code=code_context.read_only_context_code, original_code_runtime=humanize_runtime(original_code_baseline.runtime), - optimized_source_code=opt.candidate.source_code, + optimized_source_code=opt.candidate.source_code.markdown, optimized_explanation=opt.candidate.explanation, optimized_code_runtime=humanize_runtime(opt.runtime), speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=opt.runtime) * 100)}%", @@ -678,13 +682,22 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, f.write(helper_code) def reformat_code_and_helpers( - self, helper_functions: list[FunctionSource], path: Path, original_code: str, optimized_function: str + self, + helper_functions: list[FunctionSource], + path: Path, + original_code: str, + optimized_context: CodeStringsMarkdown, ) -> tuple[str, dict[Path, str]]: should_sort_imports = not self.args.disable_imports_sorting if should_sort_imports and isort.code(original_code) != original_code: should_sort_imports = False - new_code = format_code(self.args.formatter_cmds, path, optimized_function=optimized_function, check_diff=True) + optimized_code = "" + if optimized_context is not None: + file_to_code_context = optimized_context.file_to_path() + optimized_code = file_to_code_context.get(str(path.relative_to(self.project_root)), "") + + new_code = format_code(self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True) if should_sort_imports: new_code = sort_imports(new_code) @@ -693,7 +706,7 @@ def reformat_code_and_helpers( module_abspath = hp.file_path hp_source_code = hp.source_code formatted_helper_code = format_code( - self.args.formatter_cmds, module_abspath, optimized_function=hp_source_code, check_diff=True + self.args.formatter_cmds, module_abspath, optimized_code=hp_source_code, check_diff=True ) if should_sort_imports: formatted_helper_code = sort_imports(formatted_helper_code) @@ -702,7 +715,7 @@ def reformat_code_and_helpers( return new_code, new_helper_code def replace_function_and_helpers_with_optimized_code( - self, code_context: CodeOptimizationContext, optimized_code: str, original_helper_code: str + self, code_context: CodeOptimizationContext, optimized_code: CodeStringsMarkdown, original_helper_code: str ) -> bool: did_update = False read_writable_functions_by_file_path = defaultdict(set) @@ -846,7 +859,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio def generate_tests_and_optimizations( self, testgen_context_code: str, - read_writable_code: str, + read_writable_code: CodeStringsMarkdown, read_only_context_code: str, helper_functions: list[FunctionSource], generated_test_paths: list[Path], @@ -867,7 +880,7 @@ def generate_tests_and_optimizations( ) future_optimization_candidates = executor.submit( self.aiservice_client.optimize_python_code, - read_writable_code, + read_writable_code.markdown, read_only_context_code, self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id, N_CANDIDATES, @@ -886,7 +899,7 @@ def generate_tests_and_optimizations( if run_experiment: future_candidates_exp = executor.submit( self.local_aiservice_client.optimize_python_code, - read_writable_code, + read_writable_code.markdown, read_only_context_code, self.function_trace_id[:-4] + "EXP1", N_CANDIDATES, @@ -1048,7 +1061,7 @@ def find_and_process_best_optimization( if best_optimization: logger.info("Best candidate:") - code_print(best_optimization.candidate.source_code) + code_print(best_optimization.candidate.source_code.flat) console.print( Panel( best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue" @@ -1082,7 +1095,7 @@ def find_and_process_best_optimization( code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code, - optimized_function=best_optimization.candidate.source_code, + optimized_context=best_optimization.candidate.source_code, ) original_code_combined = original_helper_code.copy() @@ -1154,10 +1167,10 @@ def process_review( optimized_runtimes_all=optimized_runtime_by_test, ) new_explanation_raw_str = self.aiservice_client.get_new_explanation( - source_code=code_context.read_writable_code, + source_code=code_context.read_writable_code.flat, dependency_code=code_context.read_only_context_code, trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, - optimized_code=best_optimization.candidate.source_code, + optimized_code=best_optimization.candidate.source_code.flat, original_line_profiler_results=original_code_baseline.line_profile_results["str_out"], optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"], original_code_runtime=humanize_runtime(original_code_baseline.runtime), diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 25200cb9c..3a7de5d1c 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -88,7 +88,8 @@ def test_code_replacement10() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{file_path.relative_to(file_path.parent)} from __future__ import annotations class HelperClass: @@ -106,6 +107,7 @@ def __init__(self, name): def main_method(self): self.name = HelperClass.NestedClass("test").nested_method() return HelperClass(self.name).helper_method() +``` """ expected_read_only_context = """ """ @@ -125,7 +127,7 @@ def main_method(self): ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -145,7 +147,8 @@ def test_class_method_dependencies() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{file_path.relative_to(file_path.parent)} from __future__ import annotations from collections import defaultdict @@ -173,7 +176,7 @@ def topologicalSort(self): # Print contents of stack return stack - +``` """ expected_read_only_context = "" @@ -198,7 +201,7 @@ def topologicalSort(self): ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -224,22 +227,23 @@ def test_bubble_sort_helper() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py import math -from bubble_sort_with_math import sorter def sorter(arr): arr.sort() x = math.sqrt(2) print(x) return arr - - +``` +```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py +from bubble_sort_with_math import sorter def sort_from_another_file(arr): sorted_arr = sorter(arr) return sorted_arr - +``` """ expected_read_only_context = "" @@ -257,8 +261,7 @@ def sort_from_another_file(arr): return sorted_arr ``` """ - - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -455,7 +458,8 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): def __init__(self) -> None: ... @@ -550,7 +554,8 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: kwargs=kwargs, lifespan=self.__duration__, ) - """ +``` +""" expected_read_only_context = f''' ```python:{file_path.relative_to(opt.args.project_root)} _P = ParamSpec("_P") @@ -644,7 +649,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__) ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -696,7 +701,8 @@ def helper_method(self): read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} class MyClass: def __init__(self): self.x = 1 @@ -709,7 +715,8 @@ def __init__(self): self.x = 1 def helper_method(self): return self.x - """ +``` +""" expected_read_only_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: @@ -736,7 +743,7 @@ def helper_method(self): ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -793,7 +800,8 @@ def helper_method(self): read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context # In this scenario, the read-only code context is too long, so the read-only docstrings are removed. - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} class MyClass: def __init__(self): self.x = 1 @@ -807,7 +815,8 @@ def __init__(self): self.x = 1 def helper_method(self): return self.x - """ +``` +""" expected_read_only_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: @@ -831,7 +840,7 @@ def helper_method(self): return self.x ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -888,7 +897,8 @@ def helper_method(self): read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely. - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} class MyClass: def __init__(self): self.x = 1 @@ -902,7 +912,8 @@ def __init__(self): self.x = 1 def helper_method(self): return self.x - """ +``` +""" expected_read_only_context = "" expected_hashing_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} @@ -917,7 +928,7 @@ def helper_method(self): return self.x ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1041,11 +1052,9 @@ def test_repo_helper() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{path_to_utils.relative_to(project_root)} import math -import requests -from globals import API_URL -from utils import DataProcessor class DataProcessor: @@ -1061,8 +1070,11 @@ def process_data(self, raw_data: str) -> str: def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str: \"\"\"Add a prefix to the processed data.\"\"\" return prefix + data - - +``` +```python:{path_to_file.relative_to(project_root)} +import requests +from globals import API_URL +from utils import DataProcessor def fetch_and_process_data(): # Use the global variable for the request @@ -1077,8 +1089,8 @@ def fetch_and_process_data(): processed = processor.add_prefix(processed) return processed - - """ +``` +""" expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} class DataProcessor: @@ -1112,7 +1124,7 @@ def fetch_and_process_data(): return processed ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1133,12 +1145,10 @@ def test_repo_helper_of_helper() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{path_to_utils.relative_to(project_root)} import math from transform_utils import DataTransformer -import requests -from globals import API_URL -from utils import DataProcessor class DataProcessor: @@ -1154,8 +1164,11 @@ def process_data(self, raw_data: str) -> str: def transform_data(self, data: str) -> str: \"\"\"Transform the processed data\"\"\" return DataTransformer().transform(data) - - +``` +```python:{path_to_file.relative_to(project_root)} +import requests +from globals import API_URL +from utils import DataProcessor def fetch_and_transform_data(): # Use the global variable for the request @@ -1169,8 +1182,8 @@ def fetch_and_transform_data(): transformed = processor.transform_data(processed) return transformed - - """ +``` +""" expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} class DataProcessor: @@ -1210,8 +1223,7 @@ def fetch_and_transform_data(): return transformed ``` """ - - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1231,18 +1243,18 @@ def test_repo_helper_of_helper_same_class() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ -import math -from transform_utils import DataTransformer - + expected_read_write_context = f""" +```python:{path_to_transform_utils.relative_to(project_root)} class DataTransformer: def __init__(self): self.data = None def transform_using_own_method(self, data): return self.transform(data) - - +``` +```python:{path_to_utils.relative_to(project_root)} +import math +from transform_utils import DataTransformer class DataProcessor: @@ -1254,7 +1266,7 @@ def __init__(self, default_prefix: str = "PREFIX_"): def transform_data_own_method(self, data: str) -> str: \"\"\"Transform the processed data using own method\"\"\" return DataTransformer().transform_using_own_method(data) - +``` """ expected_read_only_context = f""" ```python:{path_to_transform_utils.relative_to(project_root)} @@ -1291,7 +1303,7 @@ def transform_data_own_method(self, data: str) -> str: ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1311,18 +1323,18 @@ def test_repo_helper_of_helper_same_file() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ -import math -from transform_utils import DataTransformer - + expected_read_write_context = f""" +```python:{path_to_transform_utils.relative_to(project_root)} class DataTransformer: def __init__(self): self.data = None def transform_using_same_file_function(self, data): return update_data(data) - - +``` +```python:{path_to_utils.relative_to(project_root)} +import math +from transform_utils import DataTransformer class DataProcessor: @@ -1334,7 +1346,8 @@ def __init__(self, default_prefix: str = "PREFIX_"): def transform_data_same_file_function(self, data: str) -> str: \"\"\"Transform the processed data using a function from the same file\"\"\" return DataTransformer().transform_using_same_file_function(data) - """ +``` +""" expected_read_only_context = f""" ```python:{path_to_transform_utils.relative_to(project_root)} def update_data(data): @@ -1366,7 +1379,7 @@ def transform_data_same_file_function(self, data: str) -> str: ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1385,7 +1398,8 @@ def test_repo_helper_all_same_file() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{path_to_transform_utils.relative_to(project_root)} class DataTransformer: def __init__(self): self.data = None @@ -1400,7 +1414,8 @@ def transform_data_all_same_file(self, data): def update_data(data): return data + " updated" - """ +``` +""" expected_read_only_context = f""" ```python:{path_to_transform_utils.relative_to(project_root)} class DataTransformer: @@ -1427,7 +1442,7 @@ def update_data(data): ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1447,10 +1462,10 @@ def test_repo_helper_circular_dependency() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{path_to_utils.relative_to(project_root)} import math from transform_utils import DataTransformer -from code_to_optimize.code_directories.retriever.utils import DataProcessor class DataProcessor: @@ -1462,8 +1477,9 @@ def __init__(self, default_prefix: str = "PREFIX_"): def circular_dependency(self, data: str) -> str: \"\"\"Test circular dependency\"\"\" return DataTransformer().circular_dependency(data) - - +``` +```python:{path_to_transform_utils.relative_to(project_root)} +from code_to_optimize.code_directories.retriever.utils import DataProcessor class DataTransformer: def __init__(self): @@ -1471,9 +1487,8 @@ def __init__(self): def circular_dependency(self, data): return DataProcessor().circular_dependency(data) - - - """ +``` +""" expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} class DataProcessor: @@ -1502,7 +1517,7 @@ def circular_dependency(self, data): ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1545,13 +1560,15 @@ def outside_method(): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{file_path.relative_to(opt.args.project_root)} class MyClass: def __init__(self): self.x = 1 self.y = outside_method() def target_method(self): return self.x + self.y +``` """ expected_read_only_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} @@ -1567,7 +1584,7 @@ def target_method(self): return self.x + self.y ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1624,11 +1641,11 @@ def function_to_optimize(): return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() ``` """ - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{path_to_main.relative_to(project_root)} import requests from globals import API_URL from utils import DataProcessor -import code_to_optimize.code_directories.retriever.main def fetch_and_transform_data(): # Use the global variable for the request @@ -1642,13 +1659,15 @@ def fetch_and_transform_data(): transformed = processor.transform_data(processed) return transformed - - +``` +```python:{path_to_fto.relative_to(project_root)} +import code_to_optimize.code_directories.retriever.main def function_to_optimize(): return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() +``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1807,7 +1826,8 @@ def get_system_details(): read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context # The expected contexts - expected_read_write_context = """ + expected_read_write_context = f""" +```python:{main_file_path.relative_to(opt.args.project_root)} import utility_module class Calculator: @@ -1834,6 +1854,7 @@ def calculate(self, operation, x, y): return self.subtract(x, y) else: return None +``` """ expected_read_only_context = """ ```python:utility_module.py @@ -1891,7 +1912,7 @@ def calculate(self, operation, x, y): ``` """ # Verify the contexts match the expected values - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -2049,11 +2070,10 @@ def get_system_details(): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code # The expected contexts - expected_read_write_context = """ + expected_read_write_context = f""" +```python:utility_module.py # Function that will be used in the main code -import utility_module - def select_precision(precision, fallback_precision): if precision is None: return fallback_precision or DEFAULT_PRECISION @@ -2075,8 +2095,9 @@ def select_precision(precision, fallback_precision): return precision.lower() else: return DEFAULT_PRECISION - - +``` +```python:{main_file_path.relative_to(opt.args.project_root)} +import utility_module class Calculator: def __init__(self, precision="high", fallback_precision=None, mode="standard"): @@ -2088,6 +2109,7 @@ def __init__(self, precision="high", fallback_precision=None, mode="standard"): self.backend = utility_module.CALCULATION_BACKEND self.system = utility_module.SYSTEM_TYPE self.default_precision = utility_module.DEFAULT_PRECISION +``` """ expected_read_only_context = """ ```python:utility_module.py @@ -2102,7 +2124,7 @@ def __init__(self, precision="high", fallback_precision=None, mode="standard"): CALCULATION_BACKEND = "python" ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() @@ -2438,8 +2460,8 @@ def simple_method(self): assert "return 42" in code_content - -def test_replace_functions_and_add_imports(): +# This shouldn't happen as we are now using a scoped optimization context, but keep it just in case +def test_circular_deps(): path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps" file_abs_path = path_to_root / "api_client.py" optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8") diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 7272163d3..d77d6a43e 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -13,7 +13,7 @@ replace_functions_in_file, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeOptimizationContext, FunctionParent +from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -41,11 +41,16 @@ class Args: def test_code_replacement_global_statements(): - optimized_code = """import numpy as np + project_root = Path(__file__).parent.parent.resolve() + code_path = (project_root / "code_to_optimize/bubble_sort_optimized.py").resolve() + optimized_code = f"""```python:{code_path.relative_to(project_root)} +import numpy as np + inconsequential_var = '123' def sorter(arr): - return arr.sort()""" - code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_optimized.py").resolve() + return arr.sort() +``` +""" original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").read_text( encoding="utf-8" ) @@ -70,7 +75,7 @@ def sorter(arr): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) final_output = code_path.read_text(encoding="utf-8") assert "inconsequential_var = '123'" in final_output @@ -118,6 +123,7 @@ def totally_new_function(value): function_name: str = "NewClass.new_function" preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) + print(f"Preexisting objects: {preexisting_objects}") new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=[function_name], @@ -1666,6 +1672,9 @@ def new_function2(value): def test_global_reassignment() -> None: + root_dir = Path(__file__).parent.parent.resolve() + code_path = (root_dir / "code_to_optimize/global_var_original.py").resolve() + original_code = """a=1 print("Hello world") def some_fn(): @@ -1678,7 +1687,9 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = """import numpy as np + optimized_code = f"""```python:{code_path.relative_to(root_dir)} +import numpy as np + def some_fn(): a=np.zeros(10) print("did something") @@ -1691,7 +1702,8 @@ def new_function2(value): return cst.ensure_type(value, str) a=2 print("Hello world") - """ +``` +""" expected_code = """import numpy as np print("Hello world") @@ -1713,7 +1725,6 @@ def __call__(self, value): return "I am still old" def new_function2(value): return cst.ensure_type(value, str)""" - code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() code_path.write_text(original_code, encoding="utf-8") tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/") project_root_path = (Path(__file__).parent / "..").resolve() @@ -1735,7 +1746,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -1753,7 +1764,8 @@ def new_function2(value): return cst.ensure_type(value, str) a=1 """ - optimized_code = """a=2 + optimized_code = f"""```python:{code_path.relative_to(root_dir)} +a=2 import numpy as np def some_fn(): a=np.zeros(10) @@ -1766,7 +1778,8 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) print("Hello world") - """ +``` +""" expected_code = """import numpy as np print("Hello world") @@ -1811,7 +1824,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -1829,7 +1842,8 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = """import numpy as np + optimized_code = f"""```python:{code_path.relative_to(root_dir)} +import numpy as np a=2 def some_fn(): a=np.zeros(10) @@ -1843,7 +1857,8 @@ def new_function2(value): return cst.ensure_type(value, str) a=3 print("Hello world") - """ +``` +""" expected_code = """import numpy as np print("Hello world") @@ -1888,7 +1903,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -1906,7 +1921,8 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = """a=2 + optimized_code = f"""```python:{code_path.relative_to(root_dir)} +a=2 import numpy as np def some_fn(): a=np.zeros(10) @@ -1919,7 +1935,8 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) print("Hello world") - """ +``` +""" expected_code = """import numpy as np print("Hello world") @@ -1964,7 +1981,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -1982,7 +1999,8 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = """import numpy as np + optimized_code = f"""```python:{code_path.relative_to(root_dir)} +import numpy as np a=2 def some_fn(): a=np.zeros(10) @@ -1996,7 +2014,8 @@ def new_function2(value): return cst.ensure_type(value, str) a=3 print("Hello world") - """ +``` +""" expected_code = """import numpy as np print("Hello world") @@ -2041,7 +2060,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -2062,7 +2081,8 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = """import numpy as np + optimized_code = f"""```python:{code_path.relative_to(root_dir)} +import numpy as np if 1<2: a=2 else: @@ -2079,6 +2099,7 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) print("Hello world") +``` """ expected_code = """import numpy as np @@ -2129,7 +2150,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index fbd7d0b9d..5afc4630e 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -10,6 +10,7 @@ from codeflash.code_utils.formatter import format_code, sort_imports from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import CodeString, CodeStringsMarkdown from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -263,7 +264,12 @@ def _run_formatting_test(source_code: str, should_content_change: bool, expected helper_functions=[], path=target_path, original_code=optimizer.function_to_optimize_source_code, - optimized_function=optimized_function, + optimized_context=CodeStringsMarkdown(code_strings=[ + CodeString( + code=optimized_function, + file_path=target_path.relative_to(test_dir) + ) + ]), ) content = target_path.read_text(encoding="utf8") @@ -796,7 +802,7 @@ def _is_valid(self, item): return isinstance(item, dict) and "id" in item ''' - optimization_function = """ def process(self,data): + optimization_function = """def process(self,data): '''Single quote docstring with formatting issues.''' return{'result':[item for item in data if self._is_valid(item)]}""" _run_formatting_test(source_code, True, optimized_function=optimization_function, expected=expected) \ No newline at end of file diff --git a/tests/test_multi_file_code_replacement.py b/tests/test_multi_file_code_replacement.py new file mode 100644 index 000000000..05a9c01c0 --- /dev/null +++ b/tests/test_multi_file_code_replacement.py @@ -0,0 +1,166 @@ +from pathlib import Path +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown +from codeflash.optimization.function_optimizer import FunctionOptimizer +from codeflash.verification.verification_utils import TestConfig + + +class Args: + disable_imports_sorting = True + formatter_cmds = ["disabled"] + +def test_multi_file_replcement01() -> None: + root_dir = Path(__file__).parent.parent.resolve() + helper_file = (root_dir / "code_to_optimize/temp_helper.py").resolve() + + helper_file.write_text("""import re +from collections.abc import Sequence + +from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent + +def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int: + if not content: + return 0 + + if isinstance(content, str): + return len(_TOKEN_SPLIT_RE.split(content.strip())) + + tokens = 0 + for part in content: + if isinstance(part, str): + tokens += len(_TOKEN_SPLIT_RE.split(part.strip())) + elif isinstance(part, BinaryContent): + tokens += len(part.data) + # TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl. + + return tokens + + +_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+') +""", encoding="utf-8") + + main_file = (root_dir / "code_to_optimize/temp_main.py").resolve() + + original_main = """from temp_helper import _estimate_string_tokens +from pydantic_ai_slim.pydantic_ai.usage import Usage + +def _get_string_usage(text: str) -> Usage: + response_tokens = _estimate_string_tokens(text) + return Usage(response_tokens=response_tokens, total_tokens=response_tokens) +""" + main_file.write_text(original_main, encoding="utf-8") + + optimized_code = f"""```python:{helper_file.relative_to(root_dir)} +import re +from collections.abc import Sequence + +from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent + +_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+') +_translate_table = {{ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}} + +def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int: + if not content: + return 0 + + if isinstance(content, str): + # Fast path using translate and split instead of regex when separat + s = content.strip() + if s: + s = s.translate(_translate_table) + # Split on whitespace (default). This handles multiple consecut + return len(s.split()) + return 0 + + tokens = 0 + for part in content: + if isinstance(part, str): + s = part.strip() + if s: + s = s.translate(_translate_table) + tokens += len(s.split()) + elif isinstance(part, BinaryContent): + tokens += len(part.data) + + return tokens +``` +```python:{main_file.relative_to(root_dir)} +from temp_helper import _estimate_string_tokens +from pydantic_ai_slim.pydantic_ai.usage import Usage + +def _get_string_usage(text: str) -> Usage: + response_tokens = _estimate_string_tokens(text) + return Usage(response_tokens=response_tokens, total_tokens=response_tokens) +``` +""" + + + + func = FunctionToOptimize(function_name="_get_string_usage", parents=[], file_path=main_file) + test_config = TestConfig( + tests_root=root_dir / "tests/pytest", + tests_project_rootdir=root_dir, + project_root_path=root_dir, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + + + + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + + func_optimizer.args = Args() + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code + ) + new_code = main_file.read_text(encoding="utf-8") + new_helper_code = helper_file.read_text(encoding="utf-8") + + helper_file.unlink(missing_ok=True) + main_file.unlink(missing_ok=True) + + expected_helper = """import re +from collections.abc import Sequence + +from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent + +def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int: + if not content: + return 0 + + if isinstance(content, str): + # Fast path using translate and split instead of regex when separat + s = content.strip() + if s: + s = s.translate(_translate_table) + # Split on whitespace (default). This handles multiple consecut + return len(s.split()) + return 0 + + tokens = 0 + for part in content: + if isinstance(part, str): + s = part.strip() + if s: + s = s.translate(_translate_table) + tokens += len(s.split()) + elif isinstance(part, BinaryContent): + tokens += len(part.data) + + return tokens + + +_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+') + +_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'} +""" + + assert new_code.rstrip() == original_main.rstrip() # No Change + assert new_helper_code.rstrip() == expected_helper.rstrip() \ No newline at end of file diff --git a/tests/test_unused_helper_revert.py b/tests/test_unused_helper_revert.py index 5642cd618..30f291e62 100644 --- a/tests/test_unused_helper_revert.py +++ b/tests/test_unused_helper_revert.py @@ -6,6 +6,7 @@ import pytest from codeflash.context.unused_definition_remover import detect_unused_helper_functions from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import CodeStringsMarkdown from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -56,6 +57,7 @@ def test_detect_unused_helper_functions(temp_project): # Optimized version that only calls one helper optimized_code = """ +```python:main.py def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" result1 = helper_function_1(n) @@ -68,6 +70,7 @@ def helper_function_1(x): def helper_function_2(x): \"\"\"Second helper function - MODIFIED VERSION should be reverted.\"\"\" return x * 4 # This change should be reverted to original x * 3 +``` """ # Create FunctionToOptimize instance @@ -89,7 +92,7 @@ def helper_function_2(x): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) # Should detect helper_function_2 as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -100,6 +103,7 @@ def helper_function_2(x): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # First modify the optimized code to include a MODIFIED unused helper optimized_code_with_modified_helper = """ +```python:main.py def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" result1 = helper_function_1(n) @@ -112,15 +116,15 @@ def helper_function_1(x): def helper_function_2(x): \"\"\"Second helper function - MODIFIED VERSION should be reverted.\"\"\" return x * 7 # This should be reverted to x * 3 +``` """ original_helper_code = {main_file: main_file.read_text()} # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, optimized_code_with_modified_helper, original_helper_code + code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code_with_modified_helper), original_helper_code ) - # Check final file content final_content = main_file.read_text() @@ -138,7 +142,7 @@ def helper_function_2(x): original_helper_code = {main_file: main_file.read_text()} # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check final file content final_content = main_file.read_text() @@ -160,6 +164,7 @@ def test_revert_unused_helper_functions(temp_project): # Optimized version that only calls one helper and modifies the unused one optimized_code = """ +```python:main.py def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" result1 = helper_function_1(n) @@ -172,6 +177,7 @@ def helper_function_1(x): def helper_function_2(x): \"\"\"Modified helper function - should be reverted.\"\"\" return x * 4 # This change should be reverted +``` """ # Create FunctionToOptimize instance @@ -200,7 +206,7 @@ def helper_function_2(x): # 1. Apply the optimization # 2. Detect unused helpers # 3. Revert unused helpers to original definitions - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check final file content final_content = main_file.read_text() @@ -222,6 +228,7 @@ def test_no_unused_helpers_no_revert(temp_project): # Optimized version that still calls both helpers optimized_code = """ +```python:main.py def entrypoint_function(n): \"\"\"Optimized function that still calls both helpers.\"\"\" result1 = helper_function_1(n) @@ -235,6 +242,7 @@ def helper_function_1(x): def helper_function_2(x): \"\"\"Second helper function - optimized.\"\"\" return x * 3 +``` """ # Create FunctionToOptimize instance @@ -259,11 +267,11 @@ def helper_function_2(x): original_helper_code = {main_file: main_file.read_text()} # Test detection - should find no unused helpers - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) assert len(unused_helpers) == 0, "No helpers should be detected as unused" # Apply optimization - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check final file content - should contain the optimized versions final_content = main_file.read_text() @@ -304,12 +312,14 @@ def helper_function_2(x): # Optimized version that only calls one helper optimized_code = """ +```python:main.py from helpers import helper_function_1 def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" result1 = helper_function_1(n) return result1 + n * 3 # Inlined helper_function_2 +``` """ # Create test config @@ -340,7 +350,7 @@ def entrypoint_function(n): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) # Should detect helper_function_2 as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -383,8 +393,7 @@ def helper_function_2(x): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) - + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() assert "result1 + n * 3" in main_content, "Entrypoint function should be optimized" @@ -432,7 +441,7 @@ def helper_function_2(x): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() @@ -479,6 +488,7 @@ def helper_method_2(self, x): # Optimized version that only calls one helper method optimized_code = """ +```python:main.py class Calculator: def entrypoint_method(self, n): \"\"\"Optimized method that only calls one helper.\"\"\" @@ -492,6 +502,7 @@ def helper_method_1(self, x): def helper_method_2(self, x): \"\"\"Second helper method - should be reverted.\"\"\" return x * 4 +``` """ # Create test config @@ -527,7 +538,7 @@ def helper_method_2(self, x): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) # Should detect Calculator.helper_method_2 as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -538,6 +549,7 @@ def helper_method_2(self, x): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # Update optimized code to include a MODIFIED unused helper optimized_code_with_modified_helper = """ +```python:main.py class Calculator: def entrypoint_method(self, n): \"\"\"Optimized method that only calls one helper.\"\"\" @@ -551,13 +563,14 @@ def helper_method_1(self, x): def helper_method_2(self, x): \"\"\"Second helper method - MODIFIED VERSION should be reverted.\"\"\" return x * 8 # This should be reverted to x * 3 +``` """ original_helper_code = {main_file: main_file.read_text()} # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, optimized_code_with_modified_helper, original_helper_code + code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code_with_modified_helper), original_helper_code ) # Check final file content @@ -576,7 +589,7 @@ def helper_method_2(self, x): # Test reversion original_helper_code = {main_file: main_file.read_text()} - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check final file content final_content = main_file.read_text() @@ -620,6 +633,7 @@ def process_data(self, n): # Optimized version that only calls one external helper optimized_code = """ +```python:main.py def external_helper_1(x): \"\"\"External helper function.\"\"\" return x * 2 @@ -633,6 +647,7 @@ def process_data(self, n): \"\"\"Optimized method that only calls one helper.\"\"\" result1 = external_helper_1(n) return result1 + n * 3 # Inlined external_helper_2 +``` """ # Create test config @@ -668,7 +683,7 @@ def process_data(self, n): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) # Should detect external_helper_2 as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -679,6 +694,7 @@ def process_data(self, n): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # Update optimized code to include a MODIFIED unused helper optimized_code_with_modified_helper = """ +```python:main.py def external_helper_1(x): \"\"\"External helper function.\"\"\" return x * 2 @@ -692,13 +708,14 @@ def process_data(self, n): \"\"\"Optimized method that only calls one helper.\"\"\" result1 = external_helper_1(n) return result1 + n * 3 # Inlined external_helper_2 +``` """ original_helper_code = {main_file: main_file.read_text()} # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, optimized_code_with_modified_helper, original_helper_code + code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code_with_modified_helper), original_helper_code ) # Check final file content @@ -717,6 +734,7 @@ def process_data(self, n): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # Update optimized code to include a MODIFIED unused helper optimized_code_with_modified_helper = """ +```python:main.py def external_helper_1(x): \"\"\"External helper function.\"\"\" return x * 2 @@ -730,13 +748,14 @@ def process_data(self, n): \"\"\"Optimized method that only calls one helper.\"\"\" result1 = external_helper_1(n) return result1 + n * 3 # Inlined external_helper_2 +``` """ original_helper_code = {main_file: main_file.read_text()} # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, optimized_code_with_modified_helper, original_helper_code + code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code_with_modified_helper), original_helper_code ) # Check final file content @@ -787,6 +806,7 @@ def local_helper(self, x): # Optimized version that inlines one helper optimized_code = """ +```python:main.py def global_helper_1(x): return x * 2 @@ -802,6 +822,7 @@ def compute(self, n): def local_helper(self, x): return x + 1 +``` """ # Create test config @@ -868,7 +889,7 @@ def local_helper(self, x): ] }, )(), - optimized_code, + CodeStringsMarkdown.parse_markdown_code(optimized_code), ) # Should detect global_helper_2 as unused @@ -955,6 +976,7 @@ def clean_data(x): # Optimized version that only uses some functions optimized_code = """ +```python:main.py import utils from math_helpers import add @@ -965,6 +987,7 @@ def entrypoint_function(n): # Inlined multiply: result3 = n * 2 # Inlined process_data: result4 = n ** 2 return result1 + result2 + (n * 2) + (n ** 2) +``` """ # Create test config @@ -995,7 +1018,7 @@ def entrypoint_function(n): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) # Should detect multiply, process_data as unused (at minimum) unused_names = {uh.qualified_name for uh in unused_helpers} @@ -1055,7 +1078,7 @@ def subtract(x, y): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() @@ -1116,6 +1139,7 @@ def divide_numbers(x, y): # Optimized version that only uses add_numbers optimized_code = """ +```python:main.py import calculator def entrypoint_function(n): @@ -1123,6 +1147,7 @@ def entrypoint_function(n): result1 = calculator.add_numbers(n, 10) # Inlined: result2 = n * 5 return result1 + (n * 5) +``` """ # Create test config @@ -1153,7 +1178,7 @@ def entrypoint_function(n): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) # Should detect multiply_numbers and divide_numbers as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -1204,7 +1229,7 @@ def divide_numbers(x, y): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() @@ -1263,7 +1288,7 @@ def divide_numbers(x, y): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() @@ -1318,6 +1343,7 @@ def calculate_class(cls, n): # Optimized static method that inlines one utility optimized_static_code = """ +```python:main.py def utility_function_1(x): return x * 2 @@ -1337,6 +1363,7 @@ def calculate_class(cls, n): result1 = utility_function_1(n) result2 = utility_function_2(n) return result1 - result2 +``` """ # Create test config @@ -1373,7 +1400,7 @@ def calculate_class(cls, n): # Test unused helper detection for static method unused_helpers = detect_unused_helper_functions( - optimizer.function_to_optimize, code_context, optimized_static_code + optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_static_code) ) # Should detect utility_function_2 as unused @@ -1385,6 +1412,7 @@ def calculate_class(cls, n): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # Update optimized code to include a MODIFIED unused helper optimized_static_code_with_modified_helper = """ +```python:main.py def utility_function_1(x): return x * 2 @@ -1404,13 +1432,14 @@ def calculate_class(cls, n): result1 = utility_function_1(n) result2 = utility_function_2(n) return result1 - result2 +``` """ original_helper_code = {main_file: main_file.read_text()} # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, optimized_static_code_with_modified_helper, original_helper_code + code_context, CodeStringsMarkdown.parse_markdown_code(optimized_static_code_with_modified_helper), original_helper_code ) # Check final file content