From 216eb7e7942afed735824f92b550dd797080a778 Mon Sep 17 00:00:00 2001 From: mohammed Date: Wed, 16 Jul 2025 18:48:31 +0300 Subject: [PATCH 01/25] start using markdown representation for read writable context --- codeflash/context/code_context_extractor.py | 11 ++++++----- codeflash/models/models.py | 18 ++++++++++++++++-- tests/test_code_context_extractor.py | 7 ++++--- 3 files changed, 26 insertions(+), 10 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index b520f12b7..c86b8650c 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -61,13 +61,14 @@ def get_code_optimization_context( ) # Extract code context for optimization - final_read_writable_code = extract_code_string_context_from_files( + final_read_writable_code = extract_code_markdown_context_from_files( helpers_of_fto_dict, - {}, + helpers_of_helpers_dict, project_root_path, remove_docstrings=False, code_context_type=CodeContextType.READ_WRITABLE, - ).code + ) + read_only_code_markdown = extract_code_markdown_context_from_files( helpers_of_fto_dict, helpers_of_helpers_dict, @@ -84,14 +85,14 @@ def get_code_optimization_context( ) # Handle token limits - final_read_writable_tokens = encoded_tokens_len(final_read_writable_code) + final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.__str__) if final_read_writable_tokens > optim_token_limit: raise ValueError("Read-writable code has exceeded token limit, cannot proceed") # Setup preexisting objects for code replacer preexisting_objects = set( chain( - find_preexisting_objects(final_read_writable_code), + find_preexisting_objects(final_read_writable_code.__str__), *(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings), ) ) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index e96d12423..569606c35 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -19,7 +19,7 @@ from typing import Annotated, Optional, cast from jedi.api.classes import Name -from pydantic import AfterValidator, BaseModel, ConfigDict, Field +from pydantic import AfterValidator, BaseModel, ConfigDict from pydantic.dataclasses import dataclass from codeflash.cli_cmds.console import console, logger @@ -139,8 +139,22 @@ class CodeString(BaseModel): file_path: Optional[Path] = None +def get_code_block_splitter(file_path: Path) -> str: + return f"# codeflash-splitter__{file_path}" + + class CodeStringsMarkdown(BaseModel): code_strings: list[CodeString] = [] + cached_code: str | None = None + + @property + def __str__(self) -> str: + if self.cached_code is not None: + return self.cached_code + self.cached_code = "\n\n".join( + get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings + ) + return self.cached_code @property def markdown(self) -> str: @@ -155,7 +169,7 @@ def markdown(self) -> str: class CodeOptimizationContext(BaseModel): testgen_context_code: str = "" - read_writable_code: str = Field(min_length=1) + read_writable_code: CodeStringsMarkdown read_only_context_code: str = "" hashing_code_context: str = "" hashing_code_context_hash: str = "" diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 25200cb9c..9b2706245 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -9,7 +9,7 @@ import pytest from codeflash.context.code_context_extractor import 
get_code_optimization_context from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionParent +from codeflash.models.models import FunctionParent, get_code_block_splitter from codeflash.optimization.optimizer import Optimizer from codeflash.code_utils.code_replacer import replace_functions_and_add_imports from codeflash.code_utils.code_extractor import add_global_assignments @@ -88,7 +88,8 @@ def test_code_replacement10() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter(file_path.relative_to(file_path.parent))} from __future__ import annotations class HelperClass: @@ -125,7 +126,7 @@ def main_method(self): ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() From 083a983434710cc7d92ad069b1c090559b5aa409 Mon Sep 17 00:00:00 2001 From: mohammed Date: Thu, 24 Jul 2025 19:11:00 +0300 Subject: [PATCH 02/25] render the code markdown to the console --- codeflash/optimization/function_optimizer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index ef7b215bc..260228a85 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -93,6 +93,7 @@ from codeflash.either import Result from codeflash.models.models import ( BenchmarkKey, + CodeStringsMarkdown, CoverageData, FunctionCalledInTest, FunctionSource, @@ -164,7 +165,7 @@ def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[P helper_code = f.read() original_helper_code[helper_function_path] = helper_code - if has_any_async_functions(code_context.read_writable_code): + if has_any_async_functions(code_context.read_writable_code.__str__): return Failure("Codeflash does not support async functions in the code to optimize.") # Random here means that we still attempt optimization with a fractional chance to see if # last time we could not find an optimization, maybe this time we do. 
@@ -283,7 +284,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() - code_print(code_context.read_writable_code) + code_print(code_context.read_writable_code.__str__) test_setup_result = self.generate_and_instrument_tests( # also generates optimizations code_context, should_run_experiment=should_run_experiment @@ -756,7 +757,7 @@ def instrument_existing_tests(self, function_to_all_tests: dict[str, set[Functio def generate_tests_and_optimizations( self, testgen_context_code: str, - read_writable_code: str, + read_writable_code: CodeStringsMarkdown, read_only_context_code: str, helper_functions: list[FunctionSource], generated_test_paths: list[Path], @@ -777,7 +778,7 @@ def generate_tests_and_optimizations( ) future_optimization_candidates = executor.submit( self.aiservice_client.optimize_python_code, - read_writable_code, + read_writable_code.__str__, read_only_context_code, self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id, N_CANDIDATES, @@ -796,7 +797,7 @@ def generate_tests_and_optimizations( if run_experiment: future_candidates_exp = executor.submit( self.local_aiservice_client.optimize_python_code, - read_writable_code, + read_writable_code.__str__, read_only_context_code, self.function_trace_id[:-4] + "EXP1", N_CANDIDATES, From e504c879c5b549c834c88a4ab4b41e221bb6d1f3 Mon Sep 17 00:00:00 2001 From: mohammed Date: Fri, 25 Jul 2025 01:44:15 +0300 Subject: [PATCH 03/25] split & apply --- codeflash/models/models.py | 19 ++++++++++++++++++- codeflash/optimization/function_optimizer.py | 9 +++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 569606c35..8d9e9df13 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -139,8 +139,11 @@ class CodeString(BaseModel): file_path: Optional[Path] = None +SPLITTER_MARKER = "# codeflash-splitter__" + + def get_code_block_splitter(file_path: Path) -> str: - return f"# codeflash-splitter__{file_path}" + return f"{SPLITTER_MARKER}{file_path}" class CodeStringsMarkdown(BaseModel): @@ -166,6 +169,20 @@ def markdown(self) -> str: ] ) + @staticmethod + def from_str_with_markers(code_with_markers: str) -> list[CodeString]: + pattern = rf"{SPLITTER_MARKER}([^\n]+)\n" + matches = list(re.finditer(pattern, code_with_markers)) + + results = [] + for i, match in enumerate(matches): + start = match.end() + end = matches[i + 1].start() if i + 1 < len(matches) else len(code_with_markers) + file_path = match.group(1).strip() + code = code_with_markers[start:end].lstrip("\n") + results.append(CodeString(file_path=file_path, code=code)) + return results + class CodeOptimizationContext(BaseModel): testgen_context_code: str = "" diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 260228a85..cb93e96ba 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -62,6 +62,7 @@ from codeflash.models.models import ( BestOptimization, CodeOptimizationContext, + CodeStringsMarkdown, GeneratedTests, GeneratedTestsList, OptimizationSet, @@ -93,7 +94,6 @@ from codeflash.either import Result from codeflash.models.models import ( BenchmarkKey, - CodeStringsMarkdown, CoverageData, FunctionCalledInTest, FunctionSource, @@ -621,13 +621,18 @@ def replace_function_and_helpers_with_optimized_code( 
read_writable_functions_by_file_path[self.function_to_optimize.file_path].add( self.function_to_optimize.qualified_name ) + code_strings = CodeStringsMarkdown.from_str_with_markers(optimized_code) + optimized_code_dict = {code_string.file_path: code_string.code for code_string in code_strings} + logger.debug(f"Optimized code: {optimized_code_dict}") for helper_function in code_context.helper_functions: if helper_function.jedi_definition.type != "class": read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name) for module_abspath, qualified_names in read_writable_functions_by_file_path.items(): + relative_module_path = module_abspath.relative_to(self.project_root) + logger.debug(f"applying optimized code to: {relative_module_path}") did_update |= replace_function_definitions_in_module( function_names=list(qualified_names), - optimized_code=optimized_code, + optimized_code=optimized_code_dict.get(relative_module_path), module_abspath=module_abspath, preexisting_objects=code_context.preexisting_objects, project_root_path=self.project_root, From 99cd9dc706e1ad1588dd7a37427590e0d493f2a4 Mon Sep 17 00:00:00 2001 From: mohammed Date: Fri, 25 Jul 2025 15:13:10 +0300 Subject: [PATCH 04/25] fix tests for context extractor --- codeflash/models/models.py | 2 +- codeflash/optimization/function_optimizer.py | 2 +- tests/test_code_context_extractor.py | 132 ++++++++++--------- 3 files changed, 73 insertions(+), 63 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 8d9e9df13..c69e5deaf 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -154,7 +154,7 @@ class CodeStringsMarkdown(BaseModel): def __str__(self) -> str: if self.cached_code is not None: return self.cached_code - self.cached_code = "\n\n".join( + self.cached_code = "\n".join( get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings ) return self.cached_code diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index cb93e96ba..a6770c532 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -376,7 +376,7 @@ def determine_best_candidate( ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client future_line_profile_results = executor.submit( ai_service_client.optimize_python_code_line_profiler, - source_code=code_context.read_writable_code, + source_code=code_context.read_writable_code.__str__, dependency_code=code_context.read_only_context_code, trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, line_profiler_results=original_code_baseline.line_profile_results["str_out"], diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 9b2706245..fc76a68ff 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -146,7 +146,8 @@ def test_class_method_dependencies() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter(file_path.relative_to(file_path.parent))} from __future__ import annotations from collections import defaultdict @@ -199,7 +200,7 @@ def topologicalSort(self): ``` """ - assert read_write_context.strip() == 
expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -225,9 +226,9 @@ def test_bubble_sort_helper() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter("code_to_optimize/code_directories/retriever/bubble_sort_with_math.py")} import math -from bubble_sort_with_math import sorter def sorter(arr): arr.sort() @@ -235,7 +236,8 @@ def sorter(arr): print(x) return arr - +{get_code_block_splitter("code_to_optimize/code_directories/retriever/bubble_sort_imported.py")} +from bubble_sort_with_math import sorter def sort_from_another_file(arr): sorted_arr = sorter(arr) @@ -258,8 +260,7 @@ def sort_from_another_file(arr): return sorted_arr ``` """ - - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -456,7 +457,8 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter(file_path.relative_to(opt.args.project_root))} class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): def __init__(self) -> None: ... @@ -645,7 +647,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__) ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -697,7 +699,8 @@ def helper_method(self): read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter(file_path.relative_to(opt.args.project_root))} class MyClass: def __init__(self): self.x = 1 @@ -737,7 +740,7 @@ def helper_method(self): ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -794,7 +797,8 @@ def helper_method(self): read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context # In this scenario, the read-only code context is too long, so the read-only docstrings are removed. 
- expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter(file_path.relative_to(opt.args.project_root))} class MyClass: def __init__(self): self.x = 1 @@ -832,7 +836,7 @@ def helper_method(self): return self.x ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -889,7 +893,8 @@ def helper_method(self): read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely. - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter(file_path.relative_to(opt.args.project_root))} class MyClass: def __init__(self): self.x = 1 @@ -918,7 +923,7 @@ def helper_method(self): return self.x ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1042,11 +1047,9 @@ def test_repo_helper() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter(path_to_utils.relative_to(project_root))} import math -import requests -from globals import API_URL -from utils import DataProcessor class DataProcessor: @@ -1063,7 +1066,10 @@ def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str: \"\"\"Add a prefix to the processed data.\"\"\" return prefix + data - +{get_code_block_splitter(path_to_file.relative_to(project_root))} +import requests +from globals import API_URL +from utils import DataProcessor def fetch_and_process_data(): # Use the global variable for the request @@ -1078,8 +1084,7 @@ def fetch_and_process_data(): processed = processor.add_prefix(processed) return processed - - """ +""" expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} class DataProcessor: @@ -1113,7 +1118,7 @@ def fetch_and_process_data(): return processed ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1134,12 +1139,10 @@ def test_repo_helper_of_helper() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter(path_to_utils.relative_to(project_root))} import math from transform_utils import DataTransformer -import requests -from globals import API_URL -from utils import DataProcessor class DataProcessor: @@ -1156,7 +1159,10 @@ 
def transform_data(self, data: str) -> str: \"\"\"Transform the processed data\"\"\" return DataTransformer().transform(data) - +{get_code_block_splitter(path_to_file.relative_to(project_root))} +import requests +from globals import API_URL +from utils import DataProcessor def fetch_and_transform_data(): # Use the global variable for the request @@ -1211,8 +1217,7 @@ def fetch_and_transform_data(): return transformed ``` """ - - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1232,10 +1237,8 @@ def test_repo_helper_of_helper_same_class() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ -import math -from transform_utils import DataTransformer - + expected_read_write_context = f""" +{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))} class DataTransformer: def __init__(self): self.data = None @@ -1243,7 +1246,9 @@ def __init__(self): def transform_using_own_method(self, data): return self.transform(data) - +{get_code_block_splitter(path_to_utils.relative_to(project_root))} +import math +from transform_utils import DataTransformer class DataProcessor: @@ -1292,7 +1297,7 @@ def transform_data_own_method(self, data: str) -> str: ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1312,10 +1317,8 @@ def test_repo_helper_of_helper_same_file() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ -import math -from transform_utils import DataTransformer - + expected_read_write_context = f""" +{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))} class DataTransformer: def __init__(self): self.data = None @@ -1323,7 +1326,9 @@ def __init__(self): def transform_using_same_file_function(self, data): return update_data(data) - +{get_code_block_splitter(path_to_utils.relative_to(project_root))} +import math +from transform_utils import DataTransformer class DataProcessor: @@ -1367,7 +1372,7 @@ def transform_data_same_file_function(self, data: str) -> str: ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1386,7 +1391,8 @@ def test_repo_helper_all_same_file() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" 
+{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))} class DataTransformer: def __init__(self): self.data = None @@ -1428,7 +1434,7 @@ def update_data(data): ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1448,10 +1454,10 @@ def test_repo_helper_circular_dependency() -> None: code_ctx = get_code_optimization_context(function_to_optimize, project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter(path_to_utils.relative_to(project_root))} import math from transform_utils import DataTransformer -from code_to_optimize.code_directories.retriever.utils import DataProcessor class DataProcessor: @@ -1464,7 +1470,8 @@ def circular_dependency(self, data: str) -> str: \"\"\"Test circular dependency\"\"\" return DataTransformer().circular_dependency(data) - +{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))} +from code_to_optimize.code_directories.retriever.utils import DataProcessor class DataTransformer: def __init__(self): @@ -1503,7 +1510,7 @@ def circular_dependency(self, data): ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1546,7 +1553,8 @@ def outside_method(): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter(file_path.relative_to(opt.args.project_root))} class MyClass: def __init__(self): self.x = 1 @@ -1568,7 +1576,7 @@ def target_method(self): return self.x + self.y ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1625,11 +1633,11 @@ def function_to_optimize(): return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() ``` """ - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter(path_to_main.relative_to(project_root))} import requests from globals import API_URL from utils import DataProcessor -import code_to_optimize.code_directories.retriever.main def fetch_and_transform_data(): # Use the global variable for the request @@ -1644,12 +1652,13 @@ def fetch_and_transform_data(): return transformed - +{get_code_block_splitter(path_to_fto.relative_to(project_root))} +import code_to_optimize.code_directories.retriever.main def function_to_optimize(): return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert 
read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1808,7 +1817,8 @@ def get_system_details(): read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context # The expected contexts - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter(main_file_path.relative_to(opt.args.project_root))} import utility_module class Calculator: @@ -1892,7 +1902,7 @@ def calculate(self, operation, x, y): ``` """ # Verify the contexts match the expected values - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -2050,11 +2060,10 @@ def get_system_details(): code_ctx = get_code_optimization_context(function_to_optimize, opt.args.project_root) read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code # The expected contexts - expected_read_write_context = """ + expected_read_write_context = f""" +{get_code_block_splitter("utility_module.py")} # Function that will be used in the main code -import utility_module - def select_precision(precision, fallback_precision): if precision is None: return fallback_precision or DEFAULT_PRECISION @@ -2077,7 +2086,8 @@ def select_precision(precision, fallback_precision): else: return DEFAULT_PRECISION - +{get_code_block_splitter(main_file_path.relative_to(opt.args.project_root))} +import utility_module class Calculator: def __init__(self, precision="high", fallback_precision=None, mode="standard"): @@ -2103,7 +2113,7 @@ def __init__(self, precision="high", fallback_precision=None, mode="standard"): CALCULATION_BACKEND = "python" ``` """ - assert read_write_context.strip() == expected_read_write_context.strip() + assert read_write_context.__str__.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() From 330bf91e738ac521b06d738dc25e557b5b5b8064 Mon Sep 17 00:00:00 2001 From: mohammed Date: Fri, 25 Jul 2025 15:39:47 +0300 Subject: [PATCH 05/25] fix code replacement tests --- codeflash/models/models.py | 7 +++-- codeflash/optimization/function_optimizer.py | 12 +++++--- tests/test_code_replacement.py | 32 ++++++++++++++------ 3 files changed, 35 insertions(+), 16 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index c69e5deaf..74c70c89c 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -169,8 +169,11 @@ def markdown(self) -> str: ] ) + def path_to_code_string(self) -> dict[str, str]: + return {code_string.file_path: code_string.code for code_string in self.code_strings} + @staticmethod - def from_str_with_markers(code_with_markers: str) -> list[CodeString]: + def from_str_with_markers(code_with_markers: str) -> CodeStringsMarkdown: pattern = rf"{SPLITTER_MARKER}([^\n]+)\n" matches = list(re.finditer(pattern, code_with_markers)) @@ -181,7 +184,7 @@ def from_str_with_markers(code_with_markers: str) -> list[CodeString]: file_path = match.group(1).strip() code = code_with_markers[start:end].lstrip("\n") results.append(CodeString(file_path=file_path, code=code)) - return results + return 
CodeStringsMarkdown(code_strings=results) class CodeOptimizationContext(BaseModel): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index a6770c532..528e60071 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -621,18 +621,22 @@ def replace_function_and_helpers_with_optimized_code( read_writable_functions_by_file_path[self.function_to_optimize.file_path].add( self.function_to_optimize.qualified_name ) - code_strings = CodeStringsMarkdown.from_str_with_markers(optimized_code) - optimized_code_dict = {code_string.file_path: code_string.code for code_string in code_strings} - logger.debug(f"Optimized code: {optimized_code_dict}") + file_to_code_context = CodeStringsMarkdown.from_str_with_markers(optimized_code).path_to_code_string() + logger.debug(f"Optimized code: {file_to_code_context}") for helper_function in code_context.helper_functions: if helper_function.jedi_definition.type != "class": read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name) for module_abspath, qualified_names in read_writable_functions_by_file_path.items(): relative_module_path = module_abspath.relative_to(self.project_root) logger.debug(f"applying optimized code to: {relative_module_path}") + + optimized_code = file_to_code_context.get(relative_module_path) + if not optimized_code: + msg = f"Optimized code not found for {relative_module_path}, existing files in the context are: {list(file_to_code_context.keys())}, re-check your 'split markers'" + raise ValueError(msg) did_update |= replace_function_definitions_in_module( function_names=list(qualified_names), - optimized_code=optimized_code_dict.get(relative_module_path), + optimized_code=optimized_code, module_abspath=module_abspath, preexisting_objects=code_context.preexisting_objects, project_root_path=self.project_root, diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 7272163d3..c845f7879 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -13,7 +13,7 @@ replace_functions_in_file, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeOptimizationContext, FunctionParent +from codeflash.models.models import CodeOptimizationContext, FunctionParent, get_code_block_splitter from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -41,11 +41,14 @@ class Args: def test_code_replacement_global_statements(): - optimized_code = """import numpy as np + project_root = Path(__file__).parent.parent.resolve() + code_path = (project_root / "code_to_optimize/bubble_sort_optimized.py").resolve() + optimized_code = f"""{get_code_block_splitter(code_path.relative_to(project_root))} +import numpy as np + inconsequential_var = '123' def sorter(arr): return arr.sort()""" - code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_optimized.py").resolve() original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").read_text( encoding="utf-8" ) @@ -1666,6 +1669,9 @@ def new_function2(value): def test_global_reassignment() -> None: + root_dir = Path(__file__).parent.parent.resolve() + code_path = (root_dir / "code_to_optimize/global_var_original.py").resolve() + original_code = """a=1 print("Hello world") def some_fn(): @@ -1678,7 +1684,9 @@ def __call__(self, value): 
def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = """import numpy as np + optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))} +import numpy as np + def some_fn(): a=np.zeros(10) print("did something") @@ -1713,7 +1721,6 @@ def __call__(self, value): return "I am still old" def new_function2(value): return cst.ensure_type(value, str)""" - code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/global_var_original.py").resolve() code_path.write_text(original_code, encoding="utf-8") tests_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/tests/pytest/") project_root_path = (Path(__file__).parent / "..").resolve() @@ -1753,7 +1760,8 @@ def new_function2(value): return cst.ensure_type(value, str) a=1 """ - optimized_code = """a=2 + optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))} +a=2 import numpy as np def some_fn(): a=np.zeros(10) @@ -1829,7 +1837,8 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = """import numpy as np + optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))} +import numpy as np a=2 def some_fn(): a=np.zeros(10) @@ -1906,7 +1915,8 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = """a=2 + optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))} +a=2 import numpy as np def some_fn(): a=np.zeros(10) @@ -1982,7 +1992,8 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = """import numpy as np + optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))} +import numpy as np a=2 def some_fn(): a=np.zeros(10) @@ -2062,7 +2073,8 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = """import numpy as np + optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))} +import numpy as np if 1<2: a=2 else: From 886616f0e4e2558913caa55e4acf039a87d2b113 Mon Sep 17 00:00:00 2001 From: mohammed Date: Fri, 25 Jul 2025 16:24:54 +0300 Subject: [PATCH 06/25] fix unused helper tests --- codeflash/models/models.py | 2 +- codeflash/optimization/function_optimizer.py | 15 +++--- tests/test_unused_helper_revert.py | 49 +++++++++++++------- 3 files changed, 43 insertions(+), 23 deletions(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 74c70c89c..5272e21bc 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -170,7 +170,7 @@ def markdown(self) -> str: ) def path_to_code_string(self) -> dict[str, str]: - return {code_string.file_path: code_string.code for code_string in self.code_strings} + return {str(code_string.file_path): code_string.code for code_string in self.code_strings} @staticmethod def from_str_with_markers(code_with_markers: str) -> CodeStringsMarkdown: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 528e60071..1a08bb66e 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -627,16 +627,19 @@ def replace_function_and_helpers_with_optimized_code( if helper_function.jedi_definition.type != "class": read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name) for module_abspath, qualified_names in 
read_writable_functions_by_file_path.items(): - relative_module_path = module_abspath.relative_to(self.project_root) + relative_module_path = str(module_abspath.relative_to(self.project_root)) logger.debug(f"applying optimized code to: {relative_module_path}") - optimized_code = file_to_code_context.get(relative_module_path) - if not optimized_code: - msg = f"Optimized code not found for {relative_module_path}, existing files in the context are: {list(file_to_code_context.keys())}, re-check your 'split markers'" - raise ValueError(msg) + scoped_optimized_code = file_to_code_context.get(relative_module_path, None) + if scoped_optimized_code is None: + logger.warning( + f"Optimized code not found for {relative_module_path}, existing files in the context are: {list(file_to_code_context.keys())}, re-check your 'split markers'" + ) + scoped_optimized_code = "" + did_update |= replace_function_definitions_in_module( function_names=list(qualified_names), - optimized_code=optimized_code, + optimized_code=scoped_optimized_code, module_abspath=module_abspath, preexisting_objects=code_context.preexisting_objects, project_root_path=self.project_root, diff --git a/tests/test_unused_helper_revert.py b/tests/test_unused_helper_revert.py index 5642cd618..90a809ccc 100644 --- a/tests/test_unused_helper_revert.py +++ b/tests/test_unused_helper_revert.py @@ -6,6 +6,7 @@ import pytest from codeflash.context.unused_definition_remover import detect_unused_helper_functions from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import get_code_block_splitter from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -55,7 +56,8 @@ def test_detect_unused_helper_functions(temp_project): temp_dir, main_file, test_cfg = temp_project # Optimized version that only calls one helper - optimized_code = """ + optimized_code = f""" +{get_code_block_splitter("main.py")} def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" result1 = helper_function_1(n) @@ -99,7 +101,8 @@ def helper_function_2(x): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # First modify the optimized code to include a MODIFIED unused helper - optimized_code_with_modified_helper = """ + optimized_code_with_modified_helper = f""" +{get_code_block_splitter("main.py")} def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" result1 = helper_function_1(n) @@ -159,7 +162,8 @@ def test_revert_unused_helper_functions(temp_project): temp_dir, main_file, test_cfg = temp_project # Optimized version that only calls one helper and modifies the unused one - optimized_code = """ + optimized_code = f""" +{get_code_block_splitter("main.py") } def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" result1 = helper_function_1(n) @@ -221,7 +225,8 @@ def test_no_unused_helpers_no_revert(temp_project): temp_dir, main_file, test_cfg = temp_project # Optimized version that still calls both helpers - optimized_code = """ + optimized_code = f""" +{get_code_block_splitter("main.py")} def entrypoint_function(n): \"\"\"Optimized function that still calls both helpers.\"\"\" result1 = helper_function_1(n) @@ -303,13 +308,16 @@ def helper_function_2(x): """) # Optimized version that only calls one helper - optimized_code = """ + optimized_code = f""" +{get_code_block_splitter("main.py")} from helpers import helper_function_1 
def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" result1 = helper_function_1(n) return result1 + n * 3 # Inlined helper_function_2 + +{get_code_block_splitter("helpers.py")} """ # Create test config @@ -384,7 +392,6 @@ def helper_function_2(x): # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) - # Check main file content main_content = main_file.read_text() assert "result1 + n * 3" in main_content, "Entrypoint function should be optimized" @@ -478,7 +485,8 @@ def helper_method_2(self, x): """) # Optimized version that only calls one helper method - optimized_code = """ + optimized_code = f""" +{get_code_block_splitter("main.py") } class Calculator: def entrypoint_method(self, n): \"\"\"Optimized method that only calls one helper.\"\"\" @@ -537,7 +545,8 @@ def helper_method_2(self, x): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # Update optimized code to include a MODIFIED unused helper - optimized_code_with_modified_helper = """ + optimized_code_with_modified_helper = f""" +{get_code_block_splitter("main.py")} class Calculator: def entrypoint_method(self, n): \"\"\"Optimized method that only calls one helper.\"\"\" @@ -619,7 +628,8 @@ def process_data(self, n): """) # Optimized version that only calls one external helper - optimized_code = """ + optimized_code = f""" +{get_code_block_splitter("main.py") } def external_helper_1(x): \"\"\"External helper function.\"\"\" return x * 2 @@ -678,7 +688,8 @@ def process_data(self, n): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # Update optimized code to include a MODIFIED unused helper - optimized_code_with_modified_helper = """ + optimized_code_with_modified_helper = f""" +{get_code_block_splitter("main.py")} def external_helper_1(x): \"\"\"External helper function.\"\"\" return x * 2 @@ -716,7 +727,8 @@ def process_data(self, n): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # Update optimized code to include a MODIFIED unused helper - optimized_code_with_modified_helper = """ + optimized_code_with_modified_helper = f""" +{get_code_block_splitter("main.py")} def external_helper_1(x): \"\"\"External helper function.\"\"\" return x * 2 @@ -786,7 +798,8 @@ def local_helper(self, x): """) # Optimized version that inlines one helper - optimized_code = """ + optimized_code = f""" +{get_code_block_splitter("main.py") } def global_helper_1(x): return x * 2 @@ -954,7 +967,8 @@ def clean_data(x): """) # Optimized version that only uses some functions - optimized_code = """ + optimized_code = f""" +{get_code_block_splitter("main.py") } import utils from math_helpers import add @@ -1115,7 +1129,8 @@ def divide_numbers(x, y): """) # Optimized version that only uses add_numbers - optimized_code = """ + optimized_code = f""" +{get_code_block_splitter("main.py") } import calculator def entrypoint_function(n): @@ -1317,7 +1332,8 @@ def calculate_class(cls, n): """) # Optimized static method that inlines one utility - optimized_static_code = """ + optimized_static_code = f""" +{get_code_block_splitter("main.py")} def utility_function_1(x): return x * 2 @@ -1384,7 +1400,8 @@ def calculate_class(cls, n): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # Update optimized code to include a MODIFIED unused helper - optimized_static_code_with_modified_helper = """ + 
optimized_static_code_with_modified_helper = f""" +{get_code_block_splitter("main.py")} def utility_function_1(x): return x * 2 From d3e5e6f49e6d83a841ac456e2382fe42fe0f95d1 Mon Sep 17 00:00:00 2001 From: mohammed Date: Fri, 25 Jul 2025 16:34:00 +0300 Subject: [PATCH 07/25] fix for python 3.9 --- codeflash/models/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 5272e21bc..0891242d6 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -148,7 +148,7 @@ def get_code_block_splitter(file_path: Path) -> str: class CodeStringsMarkdown(BaseModel): code_strings: list[CodeString] = [] - cached_code: str | None = None + cached_code: Optional[str] = None @property def __str__(self) -> str: From f48c77df38c09ed1abade07badddace4c819b83d Mon Sep 17 00:00:00 2001 From: mohammed Date: Fri, 25 Jul 2025 17:50:13 +0300 Subject: [PATCH 08/25] flat method rename --- codeflash/context/code_context_extractor.py | 4 +-- codeflash/models/models.py | 8 ++--- codeflash/optimization/function_optimizer.py | 10 +++--- tests/test_code_context_extractor.py | 34 ++++++++++---------- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index c86b8650c..97befcb4a 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -85,14 +85,14 @@ def get_code_optimization_context( ) # Handle token limits - final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.__str__) + final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.flat) if final_read_writable_tokens > optim_token_limit: raise ValueError("Read-writable code has exceeded token limit, cannot proceed") # Setup preexisting objects for code replacer preexisting_objects = set( chain( - find_preexisting_objects(final_read_writable_code.__str__), + find_preexisting_objects(final_read_writable_code.flat), *(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings), ) ) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 0891242d6..30499bfcc 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -139,11 +139,11 @@ class CodeString(BaseModel): file_path: Optional[Path] = None -SPLITTER_MARKER = "# codeflash-splitter__" +LINE_SPLITTER_MARKER_PREFIX = "# codeflash-splitter__" def get_code_block_splitter(file_path: Path) -> str: - return f"{SPLITTER_MARKER}{file_path}" + return f"{LINE_SPLITTER_MARKER_PREFIX}{file_path}" class CodeStringsMarkdown(BaseModel): @@ -151,7 +151,7 @@ class CodeStringsMarkdown(BaseModel): cached_code: Optional[str] = None @property - def __str__(self) -> str: + def flat(self) -> str: if self.cached_code is not None: return self.cached_code self.cached_code = "\n".join( @@ -174,7 +174,7 @@ def path_to_code_string(self) -> dict[str, str]: @staticmethod def from_str_with_markers(code_with_markers: str) -> CodeStringsMarkdown: - pattern = rf"{SPLITTER_MARKER}([^\n]+)\n" + pattern = rf"{LINE_SPLITTER_MARKER_PREFIX}([^\n]+)\n" matches = list(re.finditer(pattern, code_with_markers)) results = [] diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 1a08bb66e..85152748c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -165,7 +165,7 @@ def can_be_optimized(self) -> 
Result[tuple[bool, CodeOptimizationContext, dict[P helper_code = f.read() original_helper_code[helper_function_path] = helper_code - if has_any_async_functions(code_context.read_writable_code.__str__): + if has_any_async_functions(code_context.read_writable_code.flat): return Failure("Codeflash does not support async functions in the code to optimize.") # Random here means that we still attempt optimization with a fractional chance to see if # last time we could not find an optimization, maybe this time we do. @@ -284,7 +284,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() - code_print(code_context.read_writable_code.__str__) + code_print(code_context.read_writable_code.flat) test_setup_result = self.generate_and_instrument_tests( # also generates optimizations code_context, should_run_experiment=should_run_experiment @@ -376,7 +376,7 @@ def determine_best_candidate( ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client future_line_profile_results = executor.submit( ai_service_client.optimize_python_code_line_profiler, - source_code=code_context.read_writable_code.__str__, + source_code=code_context.read_writable_code.flat, dependency_code=code_context.read_only_context_code, trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, line_profiler_results=original_code_baseline.line_profile_results["str_out"], @@ -790,7 +790,7 @@ def generate_tests_and_optimizations( ) future_optimization_candidates = executor.submit( self.aiservice_client.optimize_python_code, - read_writable_code.__str__, + read_writable_code.flat, read_only_context_code, self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id, N_CANDIDATES, @@ -809,7 +809,7 @@ def generate_tests_and_optimizations( if run_experiment: future_candidates_exp = executor.submit( self.local_aiservice_client.optimize_python_code, - read_writable_code.__str__, + read_writable_code.flat, read_only_context_code, self.function_trace_id[:-4] + "EXP1", N_CANDIDATES, diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index fc76a68ff..346e02f39 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -126,7 +126,7 @@ def main_method(self): ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -200,7 +200,7 @@ def topologicalSort(self): ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -260,7 +260,7 @@ def sort_from_another_file(arr): return sorted_arr ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -647,7 +647,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: return 
self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__) ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -740,7 +740,7 @@ def helper_method(self): ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -836,7 +836,7 @@ def helper_method(self): return self.x ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -923,7 +923,7 @@ def helper_method(self): return self.x ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1118,7 +1118,7 @@ def fetch_and_process_data(): return processed ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1217,7 +1217,7 @@ def fetch_and_transform_data(): return transformed ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1297,7 +1297,7 @@ def transform_data_own_method(self, data: str) -> str: ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1372,7 +1372,7 @@ def transform_data_same_file_function(self, data: str) -> str: ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1434,7 +1434,7 @@ def update_data(data): ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1510,7 +1510,7 @@ def circular_dependency(self, data): ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() 
+ assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1576,7 +1576,7 @@ def target_method(self): return self.x + self.y ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1658,7 +1658,7 @@ def fetch_and_transform_data(): def function_to_optimize(): return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1902,7 +1902,7 @@ def calculate(self, operation, x, y): ``` """ # Verify the contexts match the expected values - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -2113,7 +2113,7 @@ def __init__(self, precision="high", fallback_precision=None, mode="standard"): CALCULATION_BACKEND = "python" ``` """ - assert read_write_context.__str__.strip() == expected_read_write_context.strip() + assert read_write_context.flat.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() From 57f8af0860ae486b565e7144956898d00b128a6d Mon Sep 17 00:00:00 2001 From: mohammed Date: Sat, 26 Jul 2025 07:24:20 +0300 Subject: [PATCH 09/25] test multifile replacement --- tests/test_multi_file_code_replacement.py | 170 ++++++++++++++++++++++ tests/test_unused_helper_revert.py | 2 - 2 files changed, 170 insertions(+), 2 deletions(-) create mode 100644 tests/test_multi_file_code_replacement.py diff --git a/tests/test_multi_file_code_replacement.py b/tests/test_multi_file_code_replacement.py new file mode 100644 index 000000000..e2d437206 --- /dev/null +++ b/tests/test_multi_file_code_replacement.py @@ -0,0 +1,170 @@ +from pathlib import Path +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import CodeOptimizationContext, get_code_block_splitter +from codeflash.optimization.function_optimizer import FunctionOptimizer +from codeflash.verification.verification_utils import TestConfig + + +class Args: + disable_imports_sorting = True + formatter_cmds = ["disabled"] + +def test_multi_file_replcement01() -> None: + root_dir = Path(__file__).parent.parent.resolve() + helper_file = (root_dir / "code_to_optimize/temp_helper.py").resolve() + + helper_file.write_text("""import re +from collections.abc import Sequence + +from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent + +def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int: + if not content: + return 0 + + if isinstance(content, str): + return len(_TOKEN_SPLIT_RE.split(content.strip())) + + tokens = 0 + for part in content: + if isinstance(part, str): + tokens += len(_TOKEN_SPLIT_RE.split(part.strip())) + elif 
isinstance(part, BinaryContent): + tokens += len(part.data) + # TODO(Marcelo): We need to study how we can estimate the tokens for AudioUrl or ImageUrl. + + return tokens + + +_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+') +""", encoding="utf-8") + + main_file = (root_dir / "code_to_optimize/temp_main.py").resolve() + + original_main = """from temp_helper import _estimate_string_tokens +from pydantic_ai_slim.pydantic_ai.usage import Usage + +def _get_string_usage(text: str) -> Usage: + response_tokens = _estimate_string_tokens(text) + return Usage(response_tokens=response_tokens, total_tokens=response_tokens) +""" + main_file.write_text(original_main, encoding="utf-8") + + optimized_code = f"""{get_code_block_splitter(helper_file.relative_to(root_dir))} +import re +from collections.abc import Sequence + +from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent + +# Compile regex once, as in original +_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+') + +# Precompute translation table for fast token splitting for string input +# This covers the chars: whitespace (\\x09-\\x0d, space), " (0x22), , (0x2c), +# Map those codepoints to ' ' +_translate_table = {{ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}} + +def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int: + if not content: + return 0 + + if isinstance(content, str): + # Fast path using translate and split instead of regex when separat + s = content.strip() + if s: + s = s.translate(_translate_table) + # Split on whitespace (default). This handles multiple consecut + return len(s.split()) + return 0 + + tokens = 0 + for part in content: + if isinstance(part, str): + s = part.strip() + if s: + s = s.translate(_translate_table) + tokens += len(s.split()) + elif isinstance(part, BinaryContent): + tokens += len(part.data) + + return tokens + +{get_code_block_splitter(main_file.relative_to(root_dir))} +from temp_helper import _estimate_string_tokens +from pydantic_ai_slim.pydantic_ai.usage import Usage + +def _get_string_usage(text: str) -> Usage: + response_tokens = _estimate_string_tokens(text) + return Usage(response_tokens=response_tokens, total_tokens=response_tokens) +""" + + + + func = FunctionToOptimize(function_name="_get_string_usage", parents=[], file_path=main_file) + test_config = TestConfig( + tests_root=root_dir / "tests/pytest", + tests_project_rootdir=root_dir, + project_root_path=root_dir, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + + + + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + + func_optimizer.args = Args() + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + ) + new_code = main_file.read_text(encoding="utf-8") + new_helper_code = helper_file.read_text(encoding="utf-8") + + helper_file.unlink(missing_ok=True) + main_file.unlink(missing_ok=True) + + expected_helper = """import re +from collections.abc import Sequence + +from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent + +def 
_estimate_string_tokens(content: str | Sequence[UserContent]) -> int: + if not content: + return 0 + + if isinstance(content, str): + # Fast path using translate and split instead of regex when separat + s = content.strip() + if s: + s = s.translate(_translate_table) + # Split on whitespace (default). This handles multiple consecut + return len(s.split()) + return 0 + + tokens = 0 + for part in content: + if isinstance(part, str): + s = part.strip() + if s: + s = s.translate(_translate_table) + tokens += len(s.split()) + elif isinstance(part, BinaryContent): + tokens += len(part.data) + + return tokens + + +_TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+') + +_translate_table = {ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'} +""" + + assert new_code.rstrip() == original_main.rstrip() # No Change + assert new_helper_code.rstrip() == expected_helper.rstrip() \ No newline at end of file diff --git a/tests/test_unused_helper_revert.py b/tests/test_unused_helper_revert.py index 90a809ccc..a8f6f59ec 100644 --- a/tests/test_unused_helper_revert.py +++ b/tests/test_unused_helper_revert.py @@ -316,8 +316,6 @@ def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" result1 = helper_function_1(n) return result1 + n * 3 # Inlined helper_function_2 - -{get_code_block_splitter("helpers.py")} """ # Create test config From 654a6ec2512ce04c0529894a5d0e86491e35b0bf Mon Sep 17 00:00:00 2001 From: mohammed Date: Sat, 26 Jul 2025 08:49:23 +0300 Subject: [PATCH 10/25] refactoring --- .../code_directories/circular_deps/constants.py | 6 ------ codeflash/models/models.py | 11 ++++------- codeflash/optimization/function_optimizer.py | 9 ++++++--- tests/test_code_context_extractor.py | 4 ++-- tests/test_multi_file_code_replacement.py | 5 ----- 5 files changed, 12 insertions(+), 23 deletions(-) diff --git a/code_to_optimize/code_directories/circular_deps/constants.py b/code_to_optimize/code_directories/circular_deps/constants.py index dc4b0638e..be8fdac15 100644 --- a/code_to_optimize/code_directories/circular_deps/constants.py +++ b/code_to_optimize/code_directories/circular_deps/constants.py @@ -1,8 +1,2 @@ DEFAULT_API_URL = "https://api.galileo.ai/" DEFAULT_APP_URL = "https://app.galileo.ai/" - - -# function_names: GalileoApiClient.get_console_url -# module_abs_path : /home/mohammed/Work/galileo-python/src/galileo/api_client.py -# preexisting_objects: {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))} -# project_root_path: /home/mohammed/Work/galileo-python/src diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 30499bfcc..af57a97b9 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -169,22 +169,19 @@ def markdown(self) -> str: ] ) - def path_to_code_string(self) -> dict[str, str]: - return {str(code_string.file_path): code_string.code for code_string in self.code_strings} - @staticmethod - def from_str_with_markers(code_with_markers: str) -> CodeStringsMarkdown: + def parse_splitter_markers(code_with_markers: str) -> dict[str, str]: pattern = rf"{LINE_SPLITTER_MARKER_PREFIX}([^\n]+)\n" matches = list(re.finditer(pattern, code_with_markers)) - results = [] + results = {} for i, match in enumerate(matches): start = match.end() end = matches[i + 1].start() if i + 1 < len(matches) else len(code_with_markers) file_path = match.group(1).strip() code = code_with_markers[start:end].lstrip("\n") - results.append(CodeString(file_path=file_path, code=code)) - 
return CodeStringsMarkdown(code_strings=results) + results[file_path] = code + return results class CodeOptimizationContext(BaseModel): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 85152748c..1ee8eded1 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -621,11 +621,13 @@ def replace_function_and_helpers_with_optimized_code( read_writable_functions_by_file_path[self.function_to_optimize.file_path].add( self.function_to_optimize.qualified_name ) - file_to_code_context = CodeStringsMarkdown.from_str_with_markers(optimized_code).path_to_code_string() - logger.debug(f"Optimized code: {file_to_code_context}") + + file_to_code_context = CodeStringsMarkdown.parse_splitter_markers(optimized_code) + for helper_function in code_context.helper_functions: if helper_function.jedi_definition.type != "class": read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name) + for module_abspath, qualified_names in read_writable_functions_by_file_path.items(): relative_module_path = str(module_abspath.relative_to(self.project_root)) logger.debug(f"applying optimized code to: {relative_module_path}") @@ -633,7 +635,8 @@ def replace_function_and_helpers_with_optimized_code( scoped_optimized_code = file_to_code_context.get(relative_module_path, None) if scoped_optimized_code is None: logger.warning( - f"Optimized code not found for {relative_module_path}, existing files in the context are: {list(file_to_code_context.keys())}, re-check your 'split markers'" + f"Optimized code not found for {relative_module_path} In the context\n-------\n{optimized_code}\n-------\n" + "Existing files in the context are: {list(file_to_code_context.keys())}, re-check your 'split markers'" ) scoped_optimized_code = "" diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 346e02f39..627b1755c 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -2449,8 +2449,8 @@ def simple_method(self): assert "return 42" in code_content - -def test_replace_functions_and_add_imports(): +# This shouldn't happen as we are now using a scoped optimization context, but keep it just in case +def test_circular_deps(): path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "circular_deps" file_abs_path = path_to_root / "api_client.py" optimized_code = Path(path_to_root / "optimized.py").read_text(encoding="utf-8") diff --git a/tests/test_multi_file_code_replacement.py b/tests/test_multi_file_code_replacement.py index e2d437206..a53655c6a 100644 --- a/tests/test_multi_file_code_replacement.py +++ b/tests/test_multi_file_code_replacement.py @@ -56,12 +56,7 @@ def _get_string_usage(text: str) -> Usage: from pydantic_ai_slim.pydantic_ai.messages import BinaryContent, UserContent -# Compile regex once, as in original _TOKEN_SPLIT_RE = re.compile(r'[\\s",.:]+') - -# Precompute translation table for fast token splitting for string input -# This covers the chars: whitespace (\\x09-\\x0d, space), " (0x22), , (0x2c), -# Map those codepoints to ' ' _translate_table = {{ord(c): ord(' ') for c in ' \\t\\n\\r\\x0b\\x0c",.:'}} def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int: From 86913d46eec8bc35aeaa8d8a896e21c2fbd40e1d Mon Sep 17 00:00:00 2001 From: mohammed Date: Wed, 30 Jul 2025 01:12:47 +0300 Subject: [PATCH 11/25] flatten the context for refinement changes 
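# A minimal sketch of the flattened read-writable context that this change
# starts passing to diff_length and the refiner request, assuming the
# splitter prefix in use at this point in the series; the file name and code
# below are invented for the example.
from pathlib import Path

_PREFIX = "# codeflash-splitter__"

def _flatten(blocks: dict[Path, str]) -> str:
    # One marker line per file, then that file's code, joined into one string.
    return "\n".join(f"{_PREFIX}{path}\n{code}" for path, code in blocks.items())

_flat = _flatten({Path("pkg/util.py"): "def util():\n    return 0\n"})
assert _flat.startswith(_PREFIX + "pkg/util.py")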
--- codeflash/optimization/function_optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 71cfb006e..3fe5cea6f 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -578,7 +578,7 @@ def determine_best_candidate( runtimes_list = [] for valid_opt in self.valid_optimizations: diff_lens_list.append( - diff_length(valid_opt.candidate.source_code, code_context.read_writable_code) + diff_length(valid_opt.candidate.source_code, code_context.read_writable_code.flat) ) # char level diff runtimes_list.append(valid_opt.runtime) diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list) @@ -610,7 +610,7 @@ def refine_optimizations( request = [ AIServiceRefinerRequest( optimization_id=opt.candidate.optimization_id, - original_source_code=code_context.read_writable_code, + original_source_code=code_context.read_writable_code.flat, read_only_dependency_code=code_context.read_only_context_code, original_code_runtime=humanize_runtime(original_code_baseline.runtime), optimized_source_code=opt.candidate.source_code, From 84324f849101e6ae8a22377aae8b82e95918643c Mon Sep 17 00:00:00 2001 From: mohammed Date: Fri, 1 Aug 2025 20:55:04 +0300 Subject: [PATCH 12/25] markdown multi context Signed-off-by: mohammed --- codeflash/api/aiservice.py | 11 ++++-- codeflash/code_utils/code_replacer.py | 2 + codeflash/code_utils/formatter.py | 5 ++- codeflash/lsp/beta.py | 4 +- codeflash/models/models.py | 9 +++-- codeflash/optimization/function_optimizer.py | 39 ++++++++++++-------- tests/test_formatter.py | 2 +- 7 files changed, 43 insertions(+), 29 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 4ec4f3aec..a3962397f 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -13,7 +13,7 @@ from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name from codeflash.models.ExperimentMetadata import ExperimentMetadata -from codeflash.models.models import AIServiceRefinerRequest, OptimizedCandidate +from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate from codeflash.telemetry.posthog_cf import ph from codeflash.version import __version__ as codeflash_version @@ -73,6 +73,9 @@ def make_ai_service_request( url = f"{self.base_url}/ai{endpoint}" if method.upper() == "POST": json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) + print(f"------------------------JSON PAYLOAD for {url}--------------------") + print(json_payload) + print("-------------------END OF JSON PAYLOAD--------------------") headers = {**self.headers, "Content-Type": "application/json"} response = requests.post(url, data=json_payload, headers=headers, timeout=timeout) else: @@ -136,7 +139,7 @@ def optimize_python_code( # noqa: D417 logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.") return [ OptimizedCandidate( - source_code=opt["source_code"], + source_code=CodeStringsMarkdown.parse_splitter_markers(opt["source_code"]), explanation=opt["explanation"], optimization_id=opt["optimization_id"], ) @@ -206,7 +209,7 @@ def optimize_python_code_line_profiler( # noqa: D417 console.rule() return [ OptimizedCandidate( - source_code=opt["source_code"], + 
source_code=CodeStringsMarkdown.parse_splitter_markers(opt["source_code"]), explanation=opt["explanation"], optimization_id=opt["optimization_id"], ) @@ -263,7 +266,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] console.rule() return [ OptimizedCandidate( - source_code=opt["source_code"], + source_code=CodeStringsMarkdown.parse_splitter_markers(opt["source_code"]), explanation=opt["explanation"], optimization_id=opt["optimization_id"][:-4] + "refi", ) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 3c73c5919..cd7a0b384 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -4,6 +4,7 @@ from collections import defaultdict from functools import lru_cache from typing import TYPE_CHECKING, Optional, TypeVar +from warnings import deprecated import isort import libcst as cst @@ -432,6 +433,7 @@ def is_zero_diff(original_code: str, new_code: str) -> bool: return normalize_code(original_code) == normalize_code(new_code) +@deprecated("") def replace_optimized_code( callee_module_paths: set[Path], candidates: list[OptimizedCandidate], diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index cd8c2efb3..4cdcea8f0 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -104,7 +104,7 @@ def is_diff_line(line: str) -> bool: def format_code( formatter_cmds: list[str], path: Union[str, Path], - optimized_function: str = "", + optimized_code: str = "", check_diff: bool = False, # noqa print_status: bool = True, # noqa exit_on_failure: bool = True, # noqa @@ -121,7 +121,8 @@ def format_code( if check_diff and original_code_lines > 50: # we dont' count the formatting diff for the optimized function as it should be well-formatted - original_code_without_opfunc = original_code.replace(optimized_function, "") + # TODO: This is not correct, optimized_code is not continuous, Think of a better way for doing this. 
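# A small illustration of the TODO above, with invented strings: once the
# optimized code spans several files it is one concatenated blob that never
# appears verbatim in any single source file, so one str.replace() against a
# file's contents can no longer strip it before the formatting diff is measured.
_main_src = "def f():\n    return 1\n"
_optimized_blob = (
    "# marker main.py\ndef f():\n    return 1\n"
    "# marker helper.py\ndef g():\n    return 2\n"
)
assert _main_src.replace(_optimized_blob, "") == _main_src   # nothing removed
_per_file_chunk = "def f():\n    return 1\n"
assert _main_src.replace(_per_file_chunk, "") == ""          # per-file chunk works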
+ original_code_without_opfunc = original_code.replace(optimized_code, "") original_temp = Path(test_dir_str) / "original_temp.py" original_temp.write_text(original_code_without_opfunc, encoding="utf8") diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 2fc91f0a0..d549f826c 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -168,7 +168,7 @@ def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimization generated_test.generated_original_test_source for generated_test in generated_tests_list.generated_tests ] optimizations_dict = { - candidate.optimization_id: {"source_code": candidate.source_code, "explanation": candidate.explanation} + candidate.optimization_id: {"source_code": candidate.source_code.flat, "explanation": candidate.explanation} for candidate in optimizations_set.control + optimizations_set.experiment } @@ -276,7 +276,7 @@ def perform_function_optimization( # noqa: PLR0911 "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", } - optimized_source = best_optimization.candidate.source_code + optimized_source = best_optimization.candidate.source_code.flat speedup = original_code_baseline.runtime / best_optimization.runtime server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 997e99aa4..71aeaa6e6 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -157,6 +157,7 @@ class CodeString(BaseModel): file_path: Optional[Path] = None +# Used to split files by adding a marker at the start of each file followed by the file path. LINE_SPLITTER_MARKER_PREFIX = "# codeflash-splitter__" @@ -188,17 +189,17 @@ def markdown(self) -> str: ) @staticmethod - def parse_splitter_markers(code_with_markers: str) -> dict[str, str]: + def parse_splitter_markers(code_with_markers: str) -> CodeStringsMarkdown: pattern = rf"{LINE_SPLITTER_MARKER_PREFIX}([^\n]+)\n" matches = list(re.finditer(pattern, code_with_markers)) - results = {} + results = CodeStringsMarkdown() for i, match in enumerate(matches): start = match.end() end = matches[i + 1].start() if i + 1 < len(matches) else len(code_with_markers) file_path = match.group(1).strip() code = code_with_markers[start:end].lstrip("\n") - results[file_path] = code + results.code_strings.append(CodeString(code=code, file_path=Path(file_path))) return results @@ -303,7 +304,7 @@ class TestsInFile: @dataclass(frozen=True) class OptimizedCandidate: - source_code: str + source_code: CodeStringsMarkdown explanation: str optimization_id: str diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 20517543e..414badf12 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -62,6 +62,7 @@ from codeflash.either import Failure, Success, is_successful from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( + LINE_SPLITTER_MARKER_PREFIX, BestOptimization, CodeOptimizationContext, CodeStringsMarkdown, @@ -216,7 +217,7 @@ def generate_and_instrument_tests( revert_to_print=bool(get_pr_number()), ): generated_results = self.generate_tests_and_optimizations( - testgen_context_code=code_context.testgen_context_code, + testgen_context_code=code_context.testgen_context_code, # TODO: should we send the markdow context for the testgen instead. 
read_writable_code=code_context.read_writable_code, read_only_context_code=code_context.read_only_context_code, helper_functions=code_context.helper_functions, @@ -289,7 +290,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() - code_print(code_context.read_writable_code.flat) + code_print(code_context.read_writable_code.flat) # Should we print the markdown or the flattened code? test_setup_result = self.generate_and_instrument_tests( # also generates optimizations code_context, should_run_experiment=should_run_experiment @@ -414,11 +415,11 @@ def determine_best_candidate( get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) logger.info(f"Optimization candidate {candidate_index}/{original_len}:") - code_print(candidate.source_code) + code_print(candidate.source_code.flat) try: did_update = self.replace_function_and_helpers_with_optimized_code( code_context=code_context, - optimized_code=candidate.source_code, + optimized_code=candidate.source_code.flat, original_helper_code=original_helper_code, ) if not did_update: @@ -578,7 +579,7 @@ def determine_best_candidate( runtimes_list = [] for valid_opt in self.valid_optimizations: diff_lens_list.append( - diff_length(valid_opt.candidate.source_code, code_context.read_writable_code.flat) + diff_length(valid_opt.candidate.source_code.flat, code_context.read_writable_code.flat) ) # char level diff runtimes_list.append(valid_opt.runtime) diff_lens_ranking = create_rank_dictionary_compact(diff_lens_list) @@ -613,7 +614,7 @@ def refine_optimizations( original_source_code=code_context.read_writable_code.flat, read_only_dependency_code=code_context.read_only_context_code, original_code_runtime=humanize_runtime(original_code_baseline.runtime), - optimized_source_code=opt.candidate.source_code, + optimized_source_code=opt.candidate.source_code.flat, optimized_explanation=opt.candidate.explanation, optimized_code_runtime=humanize_runtime(opt.runtime), speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=opt.runtime) * 100)}%", @@ -679,13 +680,13 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, f.write(helper_code) def reformat_code_and_helpers( - self, helper_functions: list[FunctionSource], path: Path, original_code: str, optimized_function: str + self, helper_functions: list[FunctionSource], path: Path, original_code: str, optimized_code: str ) -> tuple[str, dict[Path, str]]: should_sort_imports = not self.args.disable_imports_sorting if should_sort_imports and isort.code(original_code) != original_code: should_sort_imports = False - new_code = format_code(self.args.formatter_cmds, path, optimized_function=optimized_function, check_diff=True) + new_code = format_code(self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True) if should_sort_imports: new_code = sort_imports(new_code) @@ -694,7 +695,7 @@ def reformat_code_and_helpers( module_abspath = hp.file_path hp_source_code = hp.source_code formatted_helper_code = format_code( - self.args.formatter_cmds, module_abspath, optimized_function=hp_source_code, check_diff=True + self.args.formatter_cmds, module_abspath, optimized_code=hp_source_code, check_diff=True ) if should_sort_imports: formatted_helper_code = sort_imports(formatted_helper_code) @@ -711,7 +712,8 
@@ def replace_function_and_helpers_with_optimized_code( self.function_to_optimize.qualified_name ) - file_to_code_context = CodeStringsMarkdown.parse_splitter_markers(optimized_code) + code_strings = CodeStringsMarkdown.parse_splitter_markers(optimized_code).code_strings + file_to_code_context = {str(code_string.file_path): code_string.code for code_string in code_strings} for helper_function in code_context.helper_functions: if helper_function.jedi_definition.type != "class": @@ -721,11 +723,12 @@ def replace_function_and_helpers_with_optimized_code( relative_module_path = str(module_abspath.relative_to(self.project_root)) logger.debug(f"applying optimized code to: {relative_module_path}") - scoped_optimized_code = file_to_code_context.get(relative_module_path, None) + scoped_optimized_code = file_to_code_context.get(relative_module_path) if scoped_optimized_code is None: logger.warning( f"Optimized code not found for {relative_module_path} In the context\n-------\n{optimized_code}\n-------\n" "Existing files in the context are: {list(file_to_code_context.keys())}, re-check your 'split markers'" + f"existing files are {file_to_code_context.keys()}" ) scoped_optimized_code = "" @@ -1063,7 +1066,7 @@ def find_and_process_best_optimization( if best_optimization: logger.info("Best candidate:") - code_print(best_optimization.candidate.source_code) + code_print(best_optimization.candidate.source_code.flat) console.print( Panel( best_optimization.candidate.explanation, title="Best Candidate Explanation", border_style="blue" @@ -1089,7 +1092,7 @@ def find_and_process_best_optimization( self.replace_function_and_helpers_with_optimized_code( code_context=code_context, - optimized_code=best_optimization.candidate.source_code, + optimized_code=best_optimization.candidate.source_code.flat, original_helper_code=original_helper_code, ) @@ -1097,7 +1100,7 @@ def find_and_process_best_optimization( code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code, - optimized_function=best_optimization.candidate.source_code, + optimized_code=best_optimization.candidate.source_code.flat, ) original_code_combined = original_helper_code.copy() @@ -1169,10 +1172,14 @@ def process_review( optimized_runtimes_all=optimized_runtime_by_test, ) new_explanation_raw_str = self.aiservice_client.get_new_explanation( - source_code=code_context.read_writable_code, + source_code=code_context.read_writable_code.flat.replace( + LINE_SPLITTER_MARKER_PREFIX, "# file: " + ), # for better readability to the LLM dependency_code=code_context.read_only_context_code, trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, - optimized_code=best_optimization.candidate.source_code, + optimized_code=best_optimization.candidate.source_code.flat.replace( + LINE_SPLITTER_MARKER_PREFIX, "# file: " + ), original_line_profiler_results=original_code_baseline.line_profile_results["str_out"], optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"], original_code_runtime=humanize_runtime(original_code_baseline.runtime), diff --git a/tests/test_formatter.py b/tests/test_formatter.py index fbd7d0b9d..c407cd0ec 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -263,7 +263,7 @@ def _run_formatting_test(source_code: str, should_content_change: bool, expected helper_functions=[], path=target_path, original_code=optimizer.function_to_optimize_source_code, - optimized_function=optimized_function, + optimized_code=optimized_function, 
) content = target_path.read_text(encoding="utf8") From a81b1cc0bbc736699e37d87c206b1efea0c7ee9a Mon Sep 17 00:00:00 2001 From: mohammed Date: Fri, 1 Aug 2025 21:37:36 +0300 Subject: [PATCH 13/25] fix import issue Signed-off-by: mohammed --- codeflash/code_utils/code_replacer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index cd7a0b384..3c73c5919 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -4,7 +4,6 @@ from collections import defaultdict from functools import lru_cache from typing import TYPE_CHECKING, Optional, TypeVar -from warnings import deprecated import isort import libcst as cst @@ -433,7 +432,6 @@ def is_zero_diff(original_code: str, new_code: str) -> bool: return normalize_code(original_code) == normalize_code(new_code) -@deprecated("") def replace_optimized_code( callee_module_paths: set[Path], candidates: list[OptimizedCandidate], From 6b5c4a5643618e8d684c1f752f692f4074982bea Mon Sep 17 00:00:00 2001 From: mohammed Date: Mon, 4 Aug 2025 03:00:50 +0300 Subject: [PATCH 14/25] change the splitter marker --- codeflash/models/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 71aeaa6e6..51440d18e 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -158,7 +158,7 @@ class CodeString(BaseModel): # Used to split files by adding a marker at the start of each file followed by the file path. -LINE_SPLITTER_MARKER_PREFIX = "# codeflash-splitter__" +LINE_SPLITTER_MARKER_PREFIX = "# --codeflash:file--" def get_code_block_splitter(file_path: Path) -> str: From 3eee162d41165c233122383ece83e9ed92180000 Mon Sep 17 00:00:00 2001 From: mohammed Date: Mon, 4 Aug 2025 15:35:28 +0300 Subject: [PATCH 15/25] fix markdown context for formatting and more refactoring --- codeflash/api/aiservice.py | 12 ++++---- codeflash/code_utils/formatter.py | 1 - codeflash/models/models.py | 28 +++++++++++------- codeflash/optimization/function_optimizer.py | 31 +++++++++++++------- tests/test_code_replacement.py | 16 +++++----- tests/test_formatter.py | 10 +++++-- tests/test_multi_file_code_replacement.py | 4 +-- tests/test_unused_helper_revert.py | 31 ++++++++++---------- 8 files changed, 77 insertions(+), 56 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index a3962397f..d7c934fb6 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -73,9 +73,9 @@ def make_ai_service_request( url = f"{self.base_url}/ai{endpoint}" if method.upper() == "POST": json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) - print(f"------------------------JSON PAYLOAD for {url}--------------------") - print(json_payload) - print("-------------------END OF JSON PAYLOAD--------------------") + # print(f"------------------------JSON PAYLOAD for {url}--------------------") + # print(json_payload) + # print("-------------------END OF JSON PAYLOAD--------------------") headers = {**self.headers, "Content-Type": "application/json"} response = requests.post(url, data=json_payload, headers=headers, timeout=timeout) else: @@ -139,7 +139,7 @@ def optimize_python_code( # noqa: D417 logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.") return [ OptimizedCandidate( - source_code=CodeStringsMarkdown.parse_splitter_markers(opt["source_code"]), + 
source_code=CodeStringsMarkdown.parse_flattened_code(opt["source_code"]), explanation=opt["explanation"], optimization_id=opt["optimization_id"], ) @@ -209,7 +209,7 @@ def optimize_python_code_line_profiler( # noqa: D417 console.rule() return [ OptimizedCandidate( - source_code=CodeStringsMarkdown.parse_splitter_markers(opt["source_code"]), + source_code=CodeStringsMarkdown.parse_flattened_code(opt["source_code"]), explanation=opt["explanation"], optimization_id=opt["optimization_id"], ) @@ -266,7 +266,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] console.rule() return [ OptimizedCandidate( - source_code=CodeStringsMarkdown.parse_splitter_markers(opt["source_code"]), + source_code=CodeStringsMarkdown.parse_flattened_code(opt["source_code"]), explanation=opt["explanation"], optimization_id=opt["optimization_id"][:-4] + "refi", ) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 4cdcea8f0..db7fa4257 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -121,7 +121,6 @@ def format_code( if check_diff and original_code_lines > 50: # we dont' count the formatting diff for the optimized function as it should be well-formatted - # TODO: This is not correct, optimized_code is not continuous, Think of a better way for doing this. original_code_without_opfunc = original_code.replace(optimized_code, "") original_temp = Path(test_dir_str) / "original_temp.py" diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 51440d18e..9a3c302ef 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -19,7 +19,7 @@ from typing import Annotated, Optional, cast from jedi.api.classes import Name -from pydantic import AfterValidator, BaseModel, ConfigDict +from pydantic import AfterValidator, BaseModel, ConfigDict, PrivateAttr from pydantic.dataclasses import dataclass from codeflash.cli_cmds.console import console, logger @@ -167,16 +167,16 @@ def get_code_block_splitter(file_path: Path) -> str: class CodeStringsMarkdown(BaseModel): code_strings: list[CodeString] = [] - cached_code: Optional[str] = None + _cache: dict = PrivateAttr(default_factory=dict) @property def flat(self) -> str: - if self.cached_code is not None: - return self.cached_code - self.cached_code = "\n".join( + if self._cache.get("flat") is not None: + return self._cache["flat"] + self._cache["flat"] = "\n".join( get_code_block_splitter(block.file_path) + "\n" + block.code for block in self.code_strings ) - return self.cached_code + return self._cache["flat"] @property def markdown(self) -> str: @@ -188,17 +188,25 @@ def markdown(self) -> str: ] ) + def file_to_path(self) -> dict[str, str]: + if self._cache.get("file_to_path") is not None: + return self._cache["file_to_path"] + self._cache["file_to_path"] = { + str(code_string.file_path): code_string.code for code_string in self.code_strings + } + return self._cache["file_to_path"] + @staticmethod - def parse_splitter_markers(code_with_markers: str) -> CodeStringsMarkdown: + def parse_flattened_code(flat_code: str) -> CodeStringsMarkdown: pattern = rf"{LINE_SPLITTER_MARKER_PREFIX}([^\n]+)\n" - matches = list(re.finditer(pattern, code_with_markers)) + matches = list(re.finditer(pattern, flat_code)) results = CodeStringsMarkdown() for i, match in enumerate(matches): start = match.end() - end = matches[i + 1].start() if i + 1 < len(matches) else len(code_with_markers) + end = matches[i + 1].start() if i + 1 < len(matches) else len(flat_code) 
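# A self-contained sketch of what the marker parsing above produces, using
# the "# --codeflash:file--" prefix introduced earlier in the series; the
# file names and code are invented for the example.
import re

_PREFIX = "# --codeflash:file--"
_flat_example = (
    f"{_PREFIX}pkg/a.py\n"
    "def a():\n    return 1\n"
    f"{_PREFIX}pkg/b.py\n"
    "def b():\n    return 2\n"
)
_marks = list(re.finditer(f"{_PREFIX}([^\n]+)\n", _flat_example))
_blocks = {
    m.group(1).strip(): _flat_example[
        m.end(): _marks[i + 1].start() if i + 1 < len(_marks) else len(_flat_example)
    ].lstrip("\n")
    for i, m in enumerate(_marks)
}
assert _blocks == {"pkg/a.py": "def a():\n    return 1\n", "pkg/b.py": "def b():\n    return 2\n"}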
file_path = match.group(1).strip() - code = code_with_markers[start:end].lstrip("\n") + code = flat_code[start:end].lstrip("\n") results.code_strings.append(CodeString(code=code, file_path=Path(file_path))) return results diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 414badf12..663b124cd 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -65,7 +65,6 @@ LINE_SPLITTER_MARKER_PREFIX, BestOptimization, CodeOptimizationContext, - CodeStringsMarkdown, GeneratedTests, GeneratedTestsList, OptimizationSet, @@ -97,6 +96,7 @@ from codeflash.either import Result from codeflash.models.models import ( BenchmarkKey, + CodeStringsMarkdown, CoverageData, FunctionCalledInTest, FunctionSource, @@ -419,7 +419,7 @@ def determine_best_candidate( try: did_update = self.replace_function_and_helpers_with_optimized_code( code_context=code_context, - optimized_code=candidate.source_code.flat, + optimized_code=candidate.source_code, original_helper_code=original_helper_code, ) if not did_update: @@ -680,12 +680,21 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, f.write(helper_code) def reformat_code_and_helpers( - self, helper_functions: list[FunctionSource], path: Path, original_code: str, optimized_code: str + self, + helper_functions: list[FunctionSource], + path: Path, + original_code: str, + optimized_context: CodeStringsMarkdown, ) -> tuple[str, dict[Path, str]]: should_sort_imports = not self.args.disable_imports_sorting if should_sort_imports and isort.code(original_code) != original_code: should_sort_imports = False + optimized_code = "" + if optimized_context is not None: + file_to_code_context = optimized_context.file_to_path() + optimized_code = file_to_code_context.get(str(path.relative_to(self.project_root)), "") + new_code = format_code(self.args.formatter_cmds, path, optimized_code=optimized_code, check_diff=True) if should_sort_imports: new_code = sort_imports(new_code) @@ -704,7 +713,7 @@ def reformat_code_and_helpers( return new_code, new_helper_code def replace_function_and_helpers_with_optimized_code( - self, code_context: CodeOptimizationContext, optimized_code: str, original_helper_code: str + self, code_context: CodeOptimizationContext, optimized_code: CodeStringsMarkdown, original_helper_code: str ) -> bool: did_update = False read_writable_functions_by_file_path = defaultdict(set) @@ -712,8 +721,7 @@ def replace_function_and_helpers_with_optimized_code( self.function_to_optimize.qualified_name ) - code_strings = CodeStringsMarkdown.parse_splitter_markers(optimized_code).code_strings - file_to_code_context = {str(code_string.file_path): code_string.code for code_string in code_strings} + file_to_code_context = optimized_code.file_to_path() for helper_function in code_context.helper_functions: if helper_function.jedi_definition.type != "class": @@ -739,7 +747,7 @@ def replace_function_and_helpers_with_optimized_code( preexisting_objects=code_context.preexisting_objects, project_root_path=self.project_root, ) - unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code.flat) # Revert unused helper functions to their original definitions if unused_helpers: @@ -1092,7 +1100,7 @@ def find_and_process_best_optimization( self.replace_function_and_helpers_with_optimized_code( 
code_context=code_context, - optimized_code=best_optimization.candidate.source_code.flat, + optimized_code=best_optimization.candidate.source_code, original_helper_code=original_helper_code, ) @@ -1100,7 +1108,7 @@ def find_and_process_best_optimization( code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code, - optimized_code=best_optimization.candidate.source_code.flat, + optimized_context=best_optimization.candidate.source_code, ) original_code_combined = original_helper_code.copy() @@ -1173,8 +1181,9 @@ def process_review( ) new_explanation_raw_str = self.aiservice_client.get_new_explanation( source_code=code_context.read_writable_code.flat.replace( - LINE_SPLITTER_MARKER_PREFIX, "# file: " - ), # for better readability to the LLM + LINE_SPLITTER_MARKER_PREFIX, + "# file: ", # for better readability + ), dependency_code=code_context.read_only_context_code, trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, optimized_code=best_optimization.candidate.source_code.flat.replace( diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index c845f7879..4f49dfc28 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -13,7 +13,7 @@ replace_functions_in_file, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeOptimizationContext, FunctionParent, get_code_block_splitter +from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, get_code_block_splitter from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -73,7 +73,7 @@ def sorter(arr): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code ) final_output = code_path.read_text(encoding="utf-8") assert "inconsequential_var = '123'" in final_output @@ -1742,7 +1742,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -1819,7 +1819,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -1897,7 +1897,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - 
code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -1974,7 +1974,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -2052,7 +2052,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -2141,7 +2141,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index c407cd0ec..5afc4630e 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -10,6 +10,7 @@ from codeflash.code_utils.formatter import format_code, sort_imports from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import CodeString, CodeStringsMarkdown from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -263,7 +264,12 @@ def _run_formatting_test(source_code: str, should_content_change: bool, expected helper_functions=[], path=target_path, original_code=optimizer.function_to_optimize_source_code, - optimized_code=optimized_function, + optimized_context=CodeStringsMarkdown(code_strings=[ + CodeString( + code=optimized_function, + file_path=target_path.relative_to(test_dir) + ) + ]), ) content = target_path.read_text(encoding="utf8") @@ -796,7 +802,7 @@ def _is_valid(self, item): return isinstance(item, dict) and "id" in item ''' - optimization_function = """ def process(self,data): + optimization_function = """def process(self,data): '''Single quote docstring with formatting issues.''' return{'result':[item for item in data if self._is_valid(item)]}""" _run_formatting_test(source_code, True, optimized_function=optimization_function, expected=expected) \ No newline at end of file diff --git a/tests/test_multi_file_code_replacement.py b/tests/test_multi_file_code_replacement.py index a53655c6a..90355d243 100644 --- a/tests/test_multi_file_code_replacement.py +++ 
b/tests/test_multi_file_code_replacement.py @@ -1,6 +1,6 @@ from pathlib import Path from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeOptimizationContext, get_code_block_splitter +from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, get_code_block_splitter from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -117,7 +117,7 @@ def _get_string_usage(text: str) -> Usage: func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code ) new_code = main_file.read_text(encoding="utf-8") new_helper_code = helper_file.read_text(encoding="utf-8") diff --git a/tests/test_unused_helper_revert.py b/tests/test_unused_helper_revert.py index a8f6f59ec..8a121eb17 100644 --- a/tests/test_unused_helper_revert.py +++ b/tests/test_unused_helper_revert.py @@ -6,7 +6,7 @@ import pytest from codeflash.context.unused_definition_remover import detect_unused_helper_functions from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import get_code_block_splitter +from codeflash.models.models import CodeStringsMarkdown, get_code_block_splitter from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -121,9 +121,8 @@ def helper_function_2(x): # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, optimized_code_with_modified_helper, original_helper_code + code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code_with_modified_helper), original_helper_code ) - # Check final file content final_content = main_file.read_text() @@ -141,7 +140,7 @@ def helper_function_2(x): original_helper_code = {main_file: main_file.read_text()} # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) # Check final file content final_content = main_file.read_text() @@ -204,7 +203,7 @@ def helper_function_2(x): # 1. Apply the optimization # 2. Detect unused helpers # 3. 
Revert unused helpers to original definitions - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) # Check final file content final_content = main_file.read_text() @@ -268,7 +267,7 @@ def helper_function_2(x): assert len(unused_helpers) == 0, "No helpers should be detected as unused" # Apply optimization - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) # Check final file content - should contain the optimized versions final_content = main_file.read_text() @@ -389,7 +388,7 @@ def helper_function_2(x): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() assert "result1 + n * 3" in main_content, "Entrypoint function should be optimized" @@ -437,7 +436,7 @@ def helper_function_2(x): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() @@ -564,7 +563,7 @@ def helper_method_2(self, x): # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, optimized_code_with_modified_helper, original_helper_code + code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code_with_modified_helper), original_helper_code ) # Check final file content @@ -583,7 +582,7 @@ def helper_method_2(self, x): # Test reversion original_helper_code = {main_file: main_file.read_text()} - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) # Check final file content final_content = main_file.read_text() @@ -707,7 +706,7 @@ def process_data(self, n): # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, optimized_code_with_modified_helper, original_helper_code + code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code_with_modified_helper), original_helper_code ) # Check final file content @@ -746,7 +745,7 @@ def process_data(self, n): # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, optimized_code_with_modified_helper, original_helper_code + code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code_with_modified_helper), original_helper_code ) # Check final file content @@ -1067,7 +1066,7 @@ def subtract(x, y): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, 
original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() @@ -1217,7 +1216,7 @@ def divide_numbers(x, y): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() @@ -1276,7 +1275,7 @@ def divide_numbers(x, y): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, optimized_code, original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() @@ -1425,7 +1424,7 @@ def calculate_class(cls, n): # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, optimized_static_code_with_modified_helper, original_helper_code + code_context, CodeStringsMarkdown.parse_flattened_code(optimized_static_code_with_modified_helper), original_helper_code ) # Check final file content From b3bd888d66b39f53a160a774693faf5b4a5c13d1 Mon Sep 17 00:00:00 2001 From: mohammed Date: Mon, 4 Aug 2025 16:03:19 +0300 Subject: [PATCH 16/25] fix splitter pattern --- code_to_optimize/bubble_sort.py | 7 +------ codeflash/models/models.py | 2 +- codeflash/version.py | 2 +- 3 files changed, 3 insertions(+), 8 deletions(-) diff --git a/code_to_optimize/bubble_sort.py b/code_to_optimize/bubble_sort.py index 9e97f63a0..7dc644cde 100644 --- a/code_to_optimize/bubble_sort.py +++ b/code_to_optimize/bubble_sort.py @@ -1,10 +1,5 @@ def sorter(arr): print("codeflash stdout: Sorting list") - for i in range(len(arr)): - for j in range(len(arr) - 1): - if arr[j] > arr[j + 1]: - temp = arr[j] - arr[j] = arr[j + 1] - arr[j + 1] = temp + arr.sort() print(f"result: {arr}") return arr diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 9a3c302ef..071779343 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -198,7 +198,7 @@ def file_to_path(self) -> dict[str, str]: @staticmethod def parse_flattened_code(flat_code: str) -> CodeStringsMarkdown: - pattern = rf"{LINE_SPLITTER_MARKER_PREFIX}([^\n]+)\n" + pattern = rf"^{LINE_SPLITTER_MARKER_PREFIX}([^\n]+)\n" matches = list(re.finditer(pattern, flat_code)) results = CodeStringsMarkdown() diff --git a/codeflash/version.py b/codeflash/version.py index 386d0212a..73578d9dc 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. 
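# A quick check of the newly anchored pattern from this change, with an
# invented two-file input: without re.MULTILINE, "^" only matches at the very
# start of the string, so every marker after the first file is skipped.
import re

_PREFIX = "# --codeflash:file--"
_two_files = f"{_PREFIX}a.py\nx = 1\n{_PREFIX}b.py\ny = 2\n"
assert len(re.findall(f"^{_PREFIX}([^\n]+)\n", _two_files)) == 1
assert len(re.findall(f"^{_PREFIX}([^\n]+)\n", _two_files, re.MULTILINE)) == 2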
-__version__ = "0.16.1" +__version__ = "0.16.0.post75.dev0+3eee162d" From b6e3c0df3e29dda617e9283f2f2b7919c5c16aa4 Mon Sep 17 00:00:00 2001 From: mohammed Date: Mon, 4 Aug 2025 16:29:17 +0300 Subject: [PATCH 17/25] fix unit tests --- code_to_optimize/bubble_sort.py | 8 +++++++- codeflash/models/models.py | 7 ++++--- codeflash/optimization/function_optimizer.py | 2 +- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/code_to_optimize/bubble_sort.py b/code_to_optimize/bubble_sort.py index 7dc644cde..74bf5cf85 100644 --- a/code_to_optimize/bubble_sort.py +++ b/code_to_optimize/bubble_sort.py @@ -1,5 +1,11 @@ def sorter(arr): print("codeflash stdout: Sorting list") - arr.sort() + for i in range(len(arr)): + for j in range(len(arr) - 1): + if arr[j] > arr[j + 1]: + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp print(f"result: {arr}") return arr + \ No newline at end of file diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 071779343..1ba117a19 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -165,6 +165,9 @@ def get_code_block_splitter(file_path: Path) -> str: return f"{LINE_SPLITTER_MARKER_PREFIX}{file_path}" +splitter_pattern = re.compile(f"^{LINE_SPLITTER_MARKER_PREFIX}([^\n]+)\n", re.MULTILINE | re.DOTALL) + + class CodeStringsMarkdown(BaseModel): code_strings: list[CodeString] = [] _cache: dict = PrivateAttr(default_factory=dict) @@ -198,9 +201,7 @@ def file_to_path(self) -> dict[str, str]: @staticmethod def parse_flattened_code(flat_code: str) -> CodeStringsMarkdown: - pattern = rf"^{LINE_SPLITTER_MARKER_PREFIX}([^\n]+)\n" - matches = list(re.finditer(pattern, flat_code)) - + matches = list(splitter_pattern.finditer(flat_code)) results = CodeStringsMarkdown() for i, match in enumerate(matches): start = match.end() diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 663b124cd..f287e0b41 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -735,7 +735,7 @@ def replace_function_and_helpers_with_optimized_code( if scoped_optimized_code is None: logger.warning( f"Optimized code not found for {relative_module_path} In the context\n-------\n{optimized_code}\n-------\n" - "Existing files in the context are: {list(file_to_code_context.keys())}, re-check your 'split markers'" + "re-check your 'split markers'" f"existing files are {file_to_code_context.keys()}" ) scoped_optimized_code = "" From 307e6bbb0ab5d3b36edc03ac626b742fcc4b935a Mon Sep 17 00:00:00 2001 From: mohammed Date: Mon, 4 Aug 2025 18:43:30 +0300 Subject: [PATCH 18/25] revert unwanted changes --- code_to_optimize/bubble_sort.py | 3 +-- codeflash/version.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/code_to_optimize/bubble_sort.py b/code_to_optimize/bubble_sort.py index 74bf5cf85..787cc4a90 100644 --- a/code_to_optimize/bubble_sort.py +++ b/code_to_optimize/bubble_sort.py @@ -7,5 +7,4 @@ def sorter(arr): arr[j] = arr[j + 1] arr[j + 1] = temp print(f"result: {arr}") - return arr - \ No newline at end of file + return arr \ No newline at end of file diff --git a/codeflash/version.py b/codeflash/version.py index 73578d9dc..386d0212a 100644 --- a/codeflash/version.py +++ b/codeflash/version.py @@ -1,2 +1,2 @@ # These version placeholders will be replaced by uv-dynamic-versioning during build. 
-__version__ = "0.16.0.post75.dev0+3eee162d" +__version__ = "0.16.1" From f1874562dc382498fdbbc180beb637909fbda324 Mon Sep 17 00:00:00 2001 From: mohammed Date: Tue, 5 Aug 2025 03:07:40 +0300 Subject: [PATCH 19/25] cleanup --- codeflash/api/aiservice.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index d7c934fb6..ae1864988 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -73,9 +73,6 @@ def make_ai_service_request( url = f"{self.base_url}/ai{endpoint}" if method.upper() == "POST": json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) - # print(f"------------------------JSON PAYLOAD for {url}--------------------") - # print(json_payload) - # print("-------------------END OF JSON PAYLOAD--------------------") headers = {**self.headers, "Content-Type": "application/json"} response = requests.post(url, data=json_payload, headers=headers, timeout=timeout) else: From 07a93659875943bd1714ddd7935902cbd4b2937c Mon Sep 17 00:00:00 2001 From: mohammed Date: Wed, 6 Aug 2025 01:09:42 +0300 Subject: [PATCH 20/25] send&recieve markdown code --- codeflash/api/aiservice.py | 9 ++++--- codeflash/models/models.py | 15 +++++------ codeflash/optimization/function_optimizer.py | 10 +++---- tests/test_code_replacement.py | 14 +++++----- tests/test_multi_file_code_replacement.py | 2 +- tests/test_unused_helper_revert.py | 28 ++++++++++---------- 6 files changed, 39 insertions(+), 39 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index ae1864988..236d7d98f 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -73,6 +73,9 @@ def make_ai_service_request( url = f"{self.base_url}/ai{endpoint}" if method.upper() == "POST": json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) + print(f"========JSON PAYLOAD FOR {url}==============") + print(f"Payload: {json_payload}") + print("======================") headers = {**self.headers, "Content-Type": "application/json"} response = requests.post(url, data=json_payload, headers=headers, timeout=timeout) else: @@ -136,7 +139,7 @@ def optimize_python_code( # noqa: D417 logger.debug(f"Generating optimizations took {end_time - start_time:.2f} seconds.") return [ OptimizedCandidate( - source_code=CodeStringsMarkdown.parse_flattened_code(opt["source_code"]), + source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]), explanation=opt["explanation"], optimization_id=opt["optimization_id"], ) @@ -206,7 +209,7 @@ def optimize_python_code_line_profiler( # noqa: D417 console.rule() return [ OptimizedCandidate( - source_code=CodeStringsMarkdown.parse_flattened_code(opt["source_code"]), + source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]), explanation=opt["explanation"], optimization_id=opt["optimization_id"], ) @@ -263,7 +266,7 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest] console.rule() return [ OptimizedCandidate( - source_code=CodeStringsMarkdown.parse_flattened_code(opt["source_code"]), + source_code=CodeStringsMarkdown.parse_markdown_code(opt["source_code"]), explanation=opt["explanation"], optimization_id=opt["optimization_id"][:-4] + "refi", ) diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 1ba117a19..9e663959f 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -165,7 +165,7 @@ def get_code_block_splitter(file_path: Path) -> str: return 
f"{LINE_SPLITTER_MARKER_PREFIX}{file_path}" -splitter_pattern = re.compile(f"^{LINE_SPLITTER_MARKER_PREFIX}([^\n]+)\n", re.MULTILINE | re.DOTALL) +markdown_pattern = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL) class CodeStringsMarkdown(BaseModel): @@ -200,15 +200,12 @@ def file_to_path(self) -> dict[str, str]: return self._cache["file_to_path"] @staticmethod - def parse_flattened_code(flat_code: str) -> CodeStringsMarkdown: - matches = list(splitter_pattern.finditer(flat_code)) + def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown: + matches = markdown_pattern.findall(markdown_code) results = CodeStringsMarkdown() - for i, match in enumerate(matches): - start = match.end() - end = matches[i + 1].start() if i + 1 < len(matches) else len(flat_code) - file_path = match.group(1).strip() - code = flat_code[start:end].lstrip("\n") - results.code_strings.append(CodeString(code=code, file_path=Path(file_path))) + for file_path, code in matches: + path = file_path.strip() + results.code_strings.append(CodeString(code=code, file_path=Path(path))) return results diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index e76b1b8b4..7c829ca8e 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -384,7 +384,7 @@ def determine_best_candidate( ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client future_line_profile_results = executor.submit( ai_service_client.optimize_python_code_line_profiler, - source_code=code_context.read_writable_code.flat, + source_code=code_context.read_writable_code.markdown, dependency_code=code_context.read_only_context_code, trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, line_profiler_results=original_code_baseline.line_profile_results["str_out"], @@ -611,10 +611,10 @@ def refine_optimizations( request = [ AIServiceRefinerRequest( optimization_id=opt.candidate.optimization_id, - original_source_code=code_context.read_writable_code.flat, + original_source_code=code_context.read_writable_code.markdown, read_only_dependency_code=code_context.read_only_context_code, original_code_runtime=humanize_runtime(original_code_baseline.runtime), - optimized_source_code=opt.candidate.source_code.flat, + optimized_source_code=opt.candidate.source_code.markdown, optimized_explanation=opt.candidate.explanation, optimized_code_runtime=humanize_runtime(opt.runtime), speedup=f"{int(performance_gain(original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=opt.runtime) * 100)}%", @@ -894,7 +894,7 @@ def generate_tests_and_optimizations( ) future_optimization_candidates = executor.submit( self.aiservice_client.optimize_python_code, - read_writable_code.flat, + read_writable_code.markdown, read_only_context_code, self.function_trace_id[:-4] + "EXP0" if run_experiment else self.function_trace_id, N_CANDIDATES, @@ -913,7 +913,7 @@ def generate_tests_and_optimizations( if run_experiment: future_candidates_exp = executor.submit( self.local_aiservice_client.optimize_python_code, - read_writable_code.flat, + read_writable_code.markdown, read_only_context_code, self.function_trace_id[:-4] + "EXP1", N_CANDIDATES, diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 4f49dfc28..28e6dc3d5 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -73,7 +73,7 @@ def sorter(arr): 
original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) final_output = code_path.read_text(encoding="utf-8") assert "inconsequential_var = '123'" in final_output @@ -1742,7 +1742,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -1819,7 +1819,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -1897,7 +1897,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -1974,7 +1974,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -2052,7 +2052,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) @@ -2141,7 +2141,7 @@ def new_function2(value): original_helper_code[helper_function_path] = helper_code func_optimizer.args = Args() 
func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = code_path.read_text(encoding="utf-8") code_path.unlink(missing_ok=True) diff --git a/tests/test_multi_file_code_replacement.py b/tests/test_multi_file_code_replacement.py index 90355d243..1edd7045b 100644 --- a/tests/test_multi_file_code_replacement.py +++ b/tests/test_multi_file_code_replacement.py @@ -117,7 +117,7 @@ def _get_string_usage(text: str) -> Usage: func_optimizer.args = Args() func_optimizer.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code=original_helper_code + code_context=code_context, optimized_code=CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code=original_helper_code ) new_code = main_file.read_text(encoding="utf-8") new_helper_code = helper_file.read_text(encoding="utf-8") diff --git a/tests/test_unused_helper_revert.py b/tests/test_unused_helper_revert.py index 8a121eb17..6a4a0827e 100644 --- a/tests/test_unused_helper_revert.py +++ b/tests/test_unused_helper_revert.py @@ -121,7 +121,7 @@ def helper_function_2(x): # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code_with_modified_helper), original_helper_code + code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code_with_modified_helper), original_helper_code ) # Check final file content final_content = main_file.read_text() @@ -140,7 +140,7 @@ def helper_function_2(x): original_helper_code = {main_file: main_file.read_text()} # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check final file content final_content = main_file.read_text() @@ -203,7 +203,7 @@ def helper_function_2(x): # 1. Apply the optimization # 2. Detect unused helpers # 3. 
Revert unused helpers to original definitions - optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check final file content final_content = main_file.read_text() @@ -267,7 +267,7 @@ def helper_function_2(x): assert len(unused_helpers) == 0, "No helpers should be detected as unused" # Apply optimization - optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check final file content - should contain the optimized versions final_content = main_file.read_text() @@ -388,7 +388,7 @@ def helper_function_2(x): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() assert "result1 + n * 3" in main_content, "Entrypoint function should be optimized" @@ -436,7 +436,7 @@ def helper_function_2(x): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() @@ -563,7 +563,7 @@ def helper_method_2(self, x): # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code_with_modified_helper), original_helper_code + code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code_with_modified_helper), original_helper_code ) # Check final file content @@ -582,7 +582,7 @@ def helper_method_2(self, x): # Test reversion original_helper_code = {main_file: main_file.read_text()} - optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check final file content final_content = main_file.read_text() @@ -706,7 +706,7 @@ def process_data(self, n): # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code_with_modified_helper), original_helper_code + code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code_with_modified_helper), original_helper_code ) # Check final file content @@ -745,7 +745,7 @@ def process_data(self, n): # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code_with_modified_helper), 
original_helper_code + code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code_with_modified_helper), original_helper_code ) # Check final file content @@ -1066,7 +1066,7 @@ def subtract(x, y): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() @@ -1216,7 +1216,7 @@ def divide_numbers(x, y): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() @@ -1275,7 +1275,7 @@ def divide_numbers(x, y): } # Apply optimization and test reversion - optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_flattened_code(optimized_code), original_helper_code) + optimizer.replace_function_and_helpers_with_optimized_code(code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code), original_helper_code) # Check main file content main_content = main_file.read_text() @@ -1424,7 +1424,7 @@ def calculate_class(cls, n): # Apply optimization and test reversion optimizer.replace_function_and_helpers_with_optimized_code( - code_context, CodeStringsMarkdown.parse_flattened_code(optimized_static_code_with_modified_helper), original_helper_code + code_context, CodeStringsMarkdown.parse_markdown_code(optimized_static_code_with_modified_helper), original_helper_code ) # Check final file content From 989b1f30a2dc167736c1b629966918176ea02e24 Mon Sep 17 00:00:00 2001 From: mohammed Date: Wed, 6 Aug 2025 03:33:46 +0300 Subject: [PATCH 21/25] unit tests fixing --- codeflash/api/aiservice.py | 6 +- codeflash/code_utils/code_replacer.py | 22 ++- .../context/unused_definition_remover.py | 12 +- codeflash/models/models.py | 36 ++++- codeflash/optimization/function_optimizer.py | 18 +-- tests/test_code_context_extractor.py | 139 ++++++++++-------- tests/test_code_replacement.py | 36 +++-- tests/test_multi_file_code_replacement.py | 9 +- tests/test_unused_helper_revert.py | 95 +++++++----- 9 files changed, 221 insertions(+), 152 deletions(-) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index 236d7d98f..65bc92bcc 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -73,9 +73,9 @@ def make_ai_service_request( url = f"{self.base_url}/ai{endpoint}" if method.upper() == "POST": json_payload = json.dumps(payload, indent=None, default=pydantic_encoder) - print(f"========JSON PAYLOAD FOR {url}==============") - print(f"Payload: {json_payload}") - print("======================") + logger.debug(f"========JSON PAYLOAD FOR {url}==============") + logger.debug(json_payload) + logger.debug("======================") headers = {**self.headers, "Content-Type": "application/json"} response = requests.post(url, data=json_payload, headers=headers, timeout=timeout) else: diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 3c73c5919..740e578b6 100644 --- 
a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -19,7 +19,7 @@ from pathlib import Path from codeflash.discovery.functions_to_optimize import FunctionToOptimize - from codeflash.models.models import CodeOptimizationContext, OptimizedCandidate, ValidCode + from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST) @@ -408,16 +408,17 @@ def replace_functions_and_add_imports( def replace_function_definitions_in_module( function_names: list[str], - optimized_code: str, + optimized_code: CodeStringsMarkdown, module_abspath: Path, preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]], project_root_path: Path, ) -> bool: source_code: str = module_abspath.read_text(encoding="utf8") + code_to_apply = get_optimized_code_for_module(module_abspath.relative_to(project_root_path), optimized_code) new_code: str = replace_functions_and_add_imports( - add_global_assignments(optimized_code, source_code), + add_global_assignments(code_to_apply, source_code), function_names, - optimized_code, + code_to_apply, module_abspath, preexisting_objects, project_root_path, @@ -428,6 +429,19 @@ def replace_function_definitions_in_module( return True +def get_optimized_code_for_module(relative_path: Path, optimized_code: CodeStringsMarkdown) -> str: + file_to_code_context = optimized_code.file_to_path() + module_optimized_code = file_to_code_context.get(str(relative_path)) + if module_optimized_code is None: + logger.warning( + f"Optimized code not found for {relative_path} In the context\n-------\n{optimized_code}\n-------\n" + "re-check your 'markdown code structure'" + f"existing files are {file_to_code_context.keys()}" + ) + module_optimized_code = "" + return module_optimized_code + + def is_zero_diff(original_code: str, new_code: str) -> bool: return normalize_code(original_code) == normalize_code(new_code) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 977c72bd3..72586b5eb 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -3,16 +3,14 @@ import ast from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from pathlib import Path +from pathlib import Path from typing import TYPE_CHECKING, Optional import libcst as cst from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_replacer import replace_function_definitions_in_module +from codeflash.models.models import CodeString, CodeStringsMarkdown if TYPE_CHECKING: from codeflash.discovery.functions_to_optimize import FunctionToOptimize @@ -530,7 +528,11 @@ def revert_unused_helper_functions( helper_names = [helper.qualified_name for helper in helpers_in_file] reverted_code = replace_function_definitions_in_module( function_names=helper_names, - optimized_code=original_code, # Use original code as the "optimized" code to revert + optimized_code=CodeStringsMarkdown( + code_strings=[ + CodeString(code=original_code, file_path=Path(file_path).relative_to(project_root)) + ] + ), # Use original code as the "optimized" code to revert module_abspath=file_path, preexisting_objects=set(), # Empty set since we're reverting project_root_path=project_root, diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 9e663959f..13f539e45 100644 --- a/codeflash/models/models.py +++ 
b/codeflash/models/models.py @@ -174,6 +174,15 @@ class CodeStringsMarkdown(BaseModel): @property def flat(self) -> str: + """Returns the combined Python module from all code blocks. + + Each block is prefixed by a file path comment to indicate its origin. + This representation is syntactically valid Python code. + + Returns: + str: The concatenated code of all blocks with file path annotations. + + """ if self._cache.get("flat") is not None: return self._cache["flat"] self._cache["flat"] = "\n".join( @@ -183,7 +192,15 @@ def flat(self) -> str: @property def markdown(self) -> str: - """Returns the markdown representation of the code, including the file path where possible.""" + """Returns a Markdown-formatted string containing all code blocks. + + Each block is enclosed in a triple-backtick code block with an optional + file path suffix (e.g., ```python:filename.py). + + Returns: + str: Markdown representation of the code blocks. + + """ return "\n".join( [ f"```python{':' + str(code_string.file_path) if code_string.file_path else ''}\n{code_string.code.strip()}\n```" @@ -192,6 +209,12 @@ def markdown(self) -> str: ) def file_to_path(self) -> dict[str, str]: + """Return a dictionary mapping file paths to their corresponding code blocks. + + Returns: + dict[str, str]: Mapping from file path (as string) to code. + + """ if self._cache.get("file_to_path") is not None: return self._cache["file_to_path"] self._cache["file_to_path"] = { @@ -201,6 +224,17 @@ def file_to_path(self) -> dict[str, str]: @staticmethod def parse_markdown_code(markdown_code: str) -> CodeStringsMarkdown: + """Parse a Markdown string into a CodeStringsMarkdown object. + + Extracts code blocks and their associated file paths and constructs a new CodeStringsMarkdown instance. + + Args: + markdown_code (str): The Markdown-formatted string to parse. + + Returns: + CodeStringsMarkdown: Parsed object containing code blocks. 
+ + """ matches = markdown_pattern.findall(markdown_code) results = CodeStringsMarkdown() for file_path, code in matches: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 7c829ca8e..6e7bc5a9b 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -720,29 +720,13 @@ def replace_function_and_helpers_with_optimized_code( read_writable_functions_by_file_path[self.function_to_optimize.file_path].add( self.function_to_optimize.qualified_name ) - - file_to_code_context = optimized_code.file_to_path() - for helper_function in code_context.helper_functions: if helper_function.jedi_definition.type != "class": read_writable_functions_by_file_path[helper_function.file_path].add(helper_function.qualified_name) - for module_abspath, qualified_names in read_writable_functions_by_file_path.items(): - relative_module_path = str(module_abspath.relative_to(self.project_root)) - logger.debug(f"applying optimized code to: {relative_module_path}") - - scoped_optimized_code = file_to_code_context.get(relative_module_path) - if scoped_optimized_code is None: - logger.warning( - f"Optimized code not found for {relative_module_path} In the context\n-------\n{optimized_code}\n-------\n" - "re-check your 'split markers'" - f"existing files are {file_to_code_context.keys()}" - ) - scoped_optimized_code = "" - did_update |= replace_function_definitions_in_module( function_names=list(qualified_names), - optimized_code=scoped_optimized_code, + optimized_code=optimized_code, module_abspath=module_abspath, preexisting_objects=code_context.preexisting_objects, project_root_path=self.project_root, diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index 627b1755c..a9751eb6f 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -89,7 +89,7 @@ def test_code_replacement10() -> None: hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" -{get_code_block_splitter(file_path.relative_to(file_path.parent))} +```python:{file_path.relative_to(file_path.parent)} from __future__ import annotations class HelperClass: @@ -107,6 +107,7 @@ def __init__(self, name): def main_method(self): self.name = HelperClass.NestedClass("test").nested_method() return HelperClass(self.name).helper_method() +``` """ expected_read_only_context = """ """ @@ -126,7 +127,7 @@ def main_method(self): ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -147,7 +148,7 @@ def test_class_method_dependencies() -> None: hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" -{get_code_block_splitter(file_path.relative_to(file_path.parent))} +```python:{file_path.relative_to(file_path.parent)} from __future__ import annotations from collections import defaultdict @@ -175,7 +176,7 @@ def topologicalSort(self): # Print contents of stack return stack - +``` """ expected_read_only_context = "" @@ -200,7 +201,7 @@ def topologicalSort(self): ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == 
expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -227,7 +228,7 @@ def test_bubble_sort_helper() -> None: hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" -{get_code_block_splitter("code_to_optimize/code_directories/retriever/bubble_sort_with_math.py")} +```python:code_to_optimize/code_directories/retriever/bubble_sort_with_math.py import math def sorter(arr): @@ -235,14 +236,14 @@ def sorter(arr): x = math.sqrt(2) print(x) return arr - -{get_code_block_splitter("code_to_optimize/code_directories/retriever/bubble_sort_imported.py")} +``` +```python:code_to_optimize/code_directories/retriever/bubble_sort_imported.py from bubble_sort_with_math import sorter def sort_from_another_file(arr): sorted_arr = sorter(arr) return sorted_arr - +``` """ expected_read_only_context = "" @@ -260,7 +261,7 @@ def sort_from_another_file(arr): return sorted_arr ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -458,7 +459,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" -{get_code_block_splitter(file_path.relative_to(opt.args.project_root))} +```python:{file_path.relative_to(opt.args.project_root)} class AbstractCacheBackend(CacheBackend, Protocol[_KEY_T, _STORE_T]): def __init__(self) -> None: ... @@ -553,7 +554,8 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: kwargs=kwargs, lifespan=self.__duration__, ) - """ +``` +""" expected_read_only_context = f''' ```python:{file_path.relative_to(opt.args.project_root)} _P = ParamSpec("_P") @@ -647,7 +649,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: return self.__backend__.get_cache_or_call(func=self.__wrapped__, args=args, kwargs=kwargs, lifespan=self.__duration__) ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -700,7 +702,7 @@ def helper_method(self): hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" -{get_code_block_splitter(file_path.relative_to(opt.args.project_root))} +```python:{file_path.relative_to(opt.args.project_root)} class MyClass: def __init__(self): self.x = 1 @@ -713,7 +715,8 @@ def __init__(self): self.x = 1 def helper_method(self): return self.x - """ +``` +""" expected_read_only_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: @@ -740,7 +743,7 @@ def helper_method(self): ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -798,7 +801,7 @@ def helper_method(self): hashing_context = code_ctx.hashing_code_context # In this scenario, the read-only code context is too long, so the 
read-only docstrings are removed. expected_read_write_context = f""" -{get_code_block_splitter(file_path.relative_to(opt.args.project_root))} +```python:{file_path.relative_to(opt.args.project_root)} class MyClass: def __init__(self): self.x = 1 @@ -812,7 +815,8 @@ def __init__(self): self.x = 1 def helper_method(self): return self.x - """ +``` +""" expected_read_only_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} class MyClass: @@ -836,7 +840,7 @@ def helper_method(self): return self.x ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -894,7 +898,7 @@ def helper_method(self): hashing_context = code_ctx.hashing_code_context # In this scenario, the read-only code context is too long even after removing docstrings, hence we remove it completely. expected_read_write_context = f""" -{get_code_block_splitter(file_path.relative_to(opt.args.project_root))} +```python:{file_path.relative_to(opt.args.project_root)} class MyClass: def __init__(self): self.x = 1 @@ -908,7 +912,8 @@ def __init__(self): self.x = 1 def helper_method(self): return self.x - """ +``` +""" expected_read_only_context = "" expected_hashing_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} @@ -923,7 +928,7 @@ def helper_method(self): return self.x ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1048,7 +1053,7 @@ def test_repo_helper() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" -{get_code_block_splitter(path_to_utils.relative_to(project_root))} +```python:{path_to_utils.relative_to(project_root)} import math class DataProcessor: @@ -1065,8 +1070,8 @@ def process_data(self, raw_data: str) -> str: def add_prefix(self, data: str, prefix: str = "PREFIX_") -> str: \"\"\"Add a prefix to the processed data.\"\"\" return prefix + data - -{get_code_block_splitter(path_to_file.relative_to(project_root))} +``` +```python:{path_to_file.relative_to(project_root)} import requests from globals import API_URL from utils import DataProcessor @@ -1084,6 +1089,7 @@ def fetch_and_process_data(): processed = processor.add_prefix(processed) return processed +``` """ expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} @@ -1118,7 +1124,7 @@ def fetch_and_process_data(): return processed ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1140,7 +1146,7 @@ def test_repo_helper_of_helper() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" -{get_code_block_splitter(path_to_utils.relative_to(project_root))} 
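
The expected fixtures above switch from line-splitter markers to fenced blocks of the form ```python:<relative path>. As a rough illustration of how such a payload maps back to per-file code, here is a minimal, self-contained sketch that reuses the same regex shape as the markdown_pattern added in models.py; the helper name parse_blocks is invented for this sketch and is not part of the codeflash API.

import re
from pathlib import Path

# Same regex shape as the markdown_pattern introduced in models.py.
block_re = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL)

sample = (
    "```python:main.py\n"
    "def entrypoint(n):\n"
    "    return helper(n)\n"
    "```\n"
    "```python:utils.py\n"
    "def helper(x):\n"
    "    return x + 1\n"
    "```\n"
)

def parse_blocks(markdown_code: str) -> dict[Path, str]:
    # Each fenced block yields a (path, code) pair; paths key the result.
    return {Path(path.strip()): code for path, code in block_re.findall(markdown_code)}

blocks = parse_blocks(sample)
assert blocks[Path("main.py")].startswith("def entrypoint")
assert blocks[Path("utils.py")].endswith("return x + 1")
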
+```python:{path_to_utils.relative_to(project_root)} import math from transform_utils import DataTransformer @@ -1158,8 +1164,8 @@ def process_data(self, raw_data: str) -> str: def transform_data(self, data: str) -> str: \"\"\"Transform the processed data\"\"\" return DataTransformer().transform(data) - -{get_code_block_splitter(path_to_file.relative_to(project_root))} +``` +```python:{path_to_file.relative_to(project_root)} import requests from globals import API_URL from utils import DataProcessor @@ -1176,8 +1182,8 @@ def fetch_and_transform_data(): transformed = processor.transform_data(processed) return transformed - - """ +``` +""" expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} class DataProcessor: @@ -1217,7 +1223,7 @@ def fetch_and_transform_data(): return transformed ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1238,15 +1244,15 @@ def test_repo_helper_of_helper_same_class() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" -{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))} +```python:{path_to_transform_utils.relative_to(project_root)} class DataTransformer: def __init__(self): self.data = None def transform_using_own_method(self, data): return self.transform(data) - -{get_code_block_splitter(path_to_utils.relative_to(project_root))} +``` +```python:{path_to_utils.relative_to(project_root)} import math from transform_utils import DataTransformer @@ -1260,7 +1266,7 @@ def __init__(self, default_prefix: str = "PREFIX_"): def transform_data_own_method(self, data: str) -> str: \"\"\"Transform the processed data using own method\"\"\" return DataTransformer().transform_using_own_method(data) - +``` """ expected_read_only_context = f""" ```python:{path_to_transform_utils.relative_to(project_root)} @@ -1297,7 +1303,7 @@ def transform_data_own_method(self, data: str) -> str: ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1318,15 +1324,15 @@ def test_repo_helper_of_helper_same_file() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" -{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))} +```python:{path_to_transform_utils.relative_to(project_root)} class DataTransformer: def __init__(self): self.data = None def transform_using_same_file_function(self, data): return update_data(data) - -{get_code_block_splitter(path_to_utils.relative_to(project_root))} +``` +```python:{path_to_utils.relative_to(project_root)} import math from transform_utils import DataTransformer @@ -1340,7 +1346,8 @@ def __init__(self, default_prefix: str = "PREFIX_"): def transform_data_same_file_function(self, data: str) -> str: \"\"\"Transform the processed data using a function from the same file\"\"\" return 
DataTransformer().transform_using_same_file_function(data) - """ +``` +""" expected_read_only_context = f""" ```python:{path_to_transform_utils.relative_to(project_root)} def update_data(data): @@ -1372,7 +1379,7 @@ def transform_data_same_file_function(self, data: str) -> str: ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1392,7 +1399,7 @@ def test_repo_helper_all_same_file() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" -{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))} +```python:{path_to_transform_utils.relative_to(project_root)} class DataTransformer: def __init__(self): self.data = None @@ -1407,7 +1414,8 @@ def transform_data_all_same_file(self, data): def update_data(data): return data + " updated" - """ +``` +""" expected_read_only_context = f""" ```python:{path_to_transform_utils.relative_to(project_root)} class DataTransformer: @@ -1434,7 +1442,7 @@ def update_data(data): ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1455,7 +1463,7 @@ def test_repo_helper_circular_dependency() -> None: read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" -{get_code_block_splitter(path_to_utils.relative_to(project_root))} +```python:{path_to_utils.relative_to(project_root)} import math from transform_utils import DataTransformer @@ -1469,8 +1477,8 @@ def __init__(self, default_prefix: str = "PREFIX_"): def circular_dependency(self, data: str) -> str: \"\"\"Test circular dependency\"\"\" return DataTransformer().circular_dependency(data) - -{get_code_block_splitter(path_to_transform_utils.relative_to(project_root))} +``` +```python:{path_to_transform_utils.relative_to(project_root)} from code_to_optimize.code_directories.retriever.utils import DataProcessor class DataTransformer: @@ -1479,9 +1487,8 @@ def __init__(self): def circular_dependency(self, data): return DataProcessor().circular_dependency(data) - - - """ +``` +""" expected_read_only_context = f""" ```python:{path_to_utils.relative_to(project_root)} class DataProcessor: @@ -1510,7 +1517,7 @@ def circular_dependency(self, data): ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1554,13 +1561,14 @@ def outside_method(): read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code hashing_context = code_ctx.hashing_code_context expected_read_write_context = f""" -{get_code_block_splitter(file_path.relative_to(opt.args.project_root))} +```python:{file_path.relative_to(opt.args.project_root)} class MyClass: def 
__init__(self): self.x = 1 self.y = outside_method() def target_method(self): return self.x + self.y +``` """ expected_read_only_context = f""" ```python:{file_path.relative_to(opt.args.project_root)} @@ -1576,7 +1584,7 @@ def target_method(self): return self.x + self.y ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1634,7 +1642,7 @@ def function_to_optimize(): ``` """ expected_read_write_context = f""" -{get_code_block_splitter(path_to_main.relative_to(project_root))} +```python:{path_to_main.relative_to(project_root)} import requests from globals import API_URL from utils import DataProcessor @@ -1651,14 +1659,15 @@ def fetch_and_transform_data(): transformed = processor.transform_data(processed) return transformed - -{get_code_block_splitter(path_to_fto.relative_to(project_root))} +``` +```python:{path_to_fto.relative_to(project_root)} import code_to_optimize.code_directories.retriever.main def function_to_optimize(): return code_to_optimize.code_directories.retriever.main.fetch_and_transform_data() +``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -1818,7 +1827,7 @@ def get_system_details(): hashing_context = code_ctx.hashing_code_context # The expected contexts expected_read_write_context = f""" -{get_code_block_splitter(main_file_path.relative_to(opt.args.project_root))} +```python:{main_file_path.relative_to(opt.args.project_root)} import utility_module class Calculator: @@ -1845,6 +1854,7 @@ def calculate(self, operation, x, y): return self.subtract(x, y) else: return None +``` """ expected_read_only_context = """ ```python:utility_module.py @@ -1902,7 +1912,7 @@ def calculate(self, operation, x, y): ``` """ # Verify the contexts match the expected values - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() assert hashing_context.strip() == expected_hashing_context.strip() @@ -2061,7 +2071,7 @@ def get_system_details(): read_write_context, read_only_context = code_ctx.read_writable_code, code_ctx.read_only_context_code # The expected contexts expected_read_write_context = f""" -{get_code_block_splitter("utility_module.py")} +```python:utility_module.py # Function that will be used in the main code def select_precision(precision, fallback_precision): @@ -2085,8 +2095,8 @@ def select_precision(precision, fallback_precision): return precision.lower() else: return DEFAULT_PRECISION - -{get_code_block_splitter(main_file_path.relative_to(opt.args.project_root))} +``` +```python:{main_file_path.relative_to(opt.args.project_root)} import utility_module class Calculator: @@ -2099,6 +2109,7 @@ def __init__(self, precision="high", fallback_precision=None, mode="standard"): self.backend = utility_module.CALCULATION_BACKEND self.system = utility_module.SYSTEM_TYPE self.default_precision = utility_module.DEFAULT_PRECISION +``` """ expected_read_only_context = """ ```python:utility_module.py @@ 
-2113,7 +2124,7 @@ def __init__(self, precision="high", fallback_precision=None, mode="standard"): CALCULATION_BACKEND = "python" ``` """ - assert read_write_context.flat.strip() == expected_read_write_context.strip() + assert read_write_context.markdown.strip() == expected_read_write_context.strip() assert read_only_context.strip() == expected_read_only_context.strip() diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 28e6dc3d5..d7c7772d7 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -13,7 +13,7 @@ replace_functions_in_file, ) from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent, get_code_block_splitter +from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, FunctionParent from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -43,12 +43,14 @@ class Args: def test_code_replacement_global_statements(): project_root = Path(__file__).parent.parent.resolve() code_path = (project_root / "code_to_optimize/bubble_sort_optimized.py").resolve() - optimized_code = f"""{get_code_block_splitter(code_path.relative_to(project_root))} + optimized_code = f"""```python:{code_path.relative_to(project_root)} import numpy as np inconsequential_var = '123' def sorter(arr): - return arr.sort()""" + return arr.sort() +``` +""" original_code_str = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").read_text( encoding="utf-8" ) @@ -1684,7 +1686,7 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))} + optimized_code = f"""```python:{code_path.relative_to(root_dir)} import numpy as np def some_fn(): @@ -1699,7 +1701,8 @@ def new_function2(value): return cst.ensure_type(value, str) a=2 print("Hello world") - """ +``` +""" expected_code = """import numpy as np print("Hello world") @@ -1760,7 +1763,7 @@ def new_function2(value): return cst.ensure_type(value, str) a=1 """ - optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))} + optimized_code = f"""```python:{code_path.relative_to(root_dir)} a=2 import numpy as np def some_fn(): @@ -1774,7 +1777,8 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) print("Hello world") - """ +``` +""" expected_code = """import numpy as np print("Hello world") @@ -1837,7 +1841,7 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))} + optimized_code = f"""```python:{code_path.relative_to(root_dir)} import numpy as np a=2 def some_fn(): @@ -1852,7 +1856,8 @@ def new_function2(value): return cst.ensure_type(value, str) a=3 print("Hello world") - """ +``` +""" expected_code = """import numpy as np print("Hello world") @@ -1915,7 +1920,7 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))} + optimized_code = f"""```python:{code_path.relative_to(root_dir)} a=2 import numpy as np def some_fn(): @@ -1929,7 +1934,8 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) print("Hello world") - """ +``` +""" 
expected_code = """import numpy as np print("Hello world") @@ -1992,7 +1998,7 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))} + optimized_code = f"""```python:{code_path.relative_to(root_dir)} import numpy as np a=2 def some_fn(): @@ -2007,7 +2013,8 @@ def new_function2(value): return cst.ensure_type(value, str) a=3 print("Hello world") - """ +``` +""" expected_code = """import numpy as np print("Hello world") @@ -2073,7 +2080,7 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) """ - optimized_code = f"""{get_code_block_splitter(code_path.relative_to(root_dir))} + optimized_code = f"""```python:{code_path.relative_to(root_dir)} import numpy as np if 1<2: a=2 @@ -2091,6 +2098,7 @@ def __call__(self, value): def new_function2(value): return cst.ensure_type(value, str) print("Hello world") +``` """ expected_code = """import numpy as np diff --git a/tests/test_multi_file_code_replacement.py b/tests/test_multi_file_code_replacement.py index 1edd7045b..05a9c01c0 100644 --- a/tests/test_multi_file_code_replacement.py +++ b/tests/test_multi_file_code_replacement.py @@ -1,6 +1,6 @@ from pathlib import Path from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, get_code_block_splitter +from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -50,7 +50,7 @@ def _get_string_usage(text: str) -> Usage: """ main_file.write_text(original_main, encoding="utf-8") - optimized_code = f"""{get_code_block_splitter(helper_file.relative_to(root_dir))} + optimized_code = f"""```python:{helper_file.relative_to(root_dir)} import re from collections.abc import Sequence @@ -83,14 +83,15 @@ def _estimate_string_tokens(content: str | Sequence[UserContent]) -> int: tokens += len(part.data) return tokens - -{get_code_block_splitter(main_file.relative_to(root_dir))} +``` +```python:{main_file.relative_to(root_dir)} from temp_helper import _estimate_string_tokens from pydantic_ai_slim.pydantic_ai.usage import Usage def _get_string_usage(text: str) -> Usage: response_tokens = _estimate_string_tokens(text) return Usage(response_tokens=response_tokens, total_tokens=response_tokens) +``` """ diff --git a/tests/test_unused_helper_revert.py b/tests/test_unused_helper_revert.py index 6a4a0827e..0c2756c3b 100644 --- a/tests/test_unused_helper_revert.py +++ b/tests/test_unused_helper_revert.py @@ -6,7 +6,7 @@ import pytest from codeflash.context.unused_definition_remover import detect_unused_helper_functions from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeStringsMarkdown, get_code_block_splitter +from codeflash.models.models import CodeStringsMarkdown from codeflash.optimization.function_optimizer import FunctionOptimizer from codeflash.verification.verification_utils import TestConfig @@ -56,8 +56,8 @@ def test_detect_unused_helper_functions(temp_project): temp_dir, main_file, test_cfg = temp_project # Optimized version that only calls one helper - optimized_code = f""" -{get_code_block_splitter("main.py")} + optimized_code = """ +```python:main.py def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" 
result1 = helper_function_1(n) @@ -70,6 +70,7 @@ def helper_function_1(x): def helper_function_2(x): \"\"\"Second helper function - MODIFIED VERSION should be reverted.\"\"\" return x * 4 # This change should be reverted to original x * 3 +``` """ # Create FunctionToOptimize instance @@ -91,7 +92,7 @@ def helper_function_2(x): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) # Should detect helper_function_2 as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -101,8 +102,8 @@ def helper_function_2(x): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # First modify the optimized code to include a MODIFIED unused helper - optimized_code_with_modified_helper = f""" -{get_code_block_splitter("main.py")} + optimized_code_with_modified_helper = """ +```python:main.py def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" result1 = helper_function_1(n) @@ -115,6 +116,7 @@ def helper_function_1(x): def helper_function_2(x): \"\"\"Second helper function - MODIFIED VERSION should be reverted.\"\"\" return x * 7 # This should be reverted to x * 3 +``` """ original_helper_code = {main_file: main_file.read_text()} @@ -161,8 +163,8 @@ def test_revert_unused_helper_functions(temp_project): temp_dir, main_file, test_cfg = temp_project # Optimized version that only calls one helper and modifies the unused one - optimized_code = f""" -{get_code_block_splitter("main.py") } + optimized_code = """ +```python:main.py def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" result1 = helper_function_1(n) @@ -175,6 +177,7 @@ def helper_function_1(x): def helper_function_2(x): \"\"\"Modified helper function - should be reverted.\"\"\" return x * 4 # This change should be reverted +``` """ # Create FunctionToOptimize instance @@ -224,8 +227,8 @@ def test_no_unused_helpers_no_revert(temp_project): temp_dir, main_file, test_cfg = temp_project # Optimized version that still calls both helpers - optimized_code = f""" -{get_code_block_splitter("main.py")} + optimized_code = """ +```python:main.py def entrypoint_function(n): \"\"\"Optimized function that still calls both helpers.\"\"\" result1 = helper_function_1(n) @@ -239,6 +242,7 @@ def helper_function_1(x): def helper_function_2(x): \"\"\"Second helper function - optimized.\"\"\" return x * 3 +``` """ # Create FunctionToOptimize instance @@ -263,7 +267,7 @@ def helper_function_2(x): original_helper_code = {main_file: main_file.read_text()} # Test detection - should find no unused helpers - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) assert len(unused_helpers) == 0, "No helpers should be detected as unused" # Apply optimization @@ -307,14 +311,15 @@ def helper_function_2(x): """) # Optimized version that only calls one helper - optimized_code = f""" -{get_code_block_splitter("main.py")} + optimized_code = """ +```python:main.py from helpers import helper_function_1 def entrypoint_function(n): \"\"\"Optimized function that only calls one helper.\"\"\" 
result1 = helper_function_1(n) return result1 + n * 3 # Inlined helper_function_2 +``` """ # Create test config @@ -345,7 +350,7 @@ def entrypoint_function(n): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) # Should detect helper_function_2 as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -482,8 +487,8 @@ def helper_method_2(self, x): """) # Optimized version that only calls one helper method - optimized_code = f""" -{get_code_block_splitter("main.py") } + optimized_code = """ +```python:main.py class Calculator: def entrypoint_method(self, n): \"\"\"Optimized method that only calls one helper.\"\"\" @@ -497,6 +502,7 @@ def helper_method_1(self, x): def helper_method_2(self, x): \"\"\"Second helper method - should be reverted.\"\"\" return x * 4 +``` """ # Create test config @@ -532,7 +538,7 @@ def helper_method_2(self, x): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) # Should detect Calculator.helper_method_2 as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -542,8 +548,8 @@ def helper_method_2(self, x): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # Update optimized code to include a MODIFIED unused helper - optimized_code_with_modified_helper = f""" -{get_code_block_splitter("main.py")} + optimized_code_with_modified_helper = """ +```python:main.py class Calculator: def entrypoint_method(self, n): \"\"\"Optimized method that only calls one helper.\"\"\" @@ -557,6 +563,7 @@ def helper_method_1(self, x): def helper_method_2(self, x): \"\"\"Second helper method - MODIFIED VERSION should be reverted.\"\"\" return x * 8 # This should be reverted to x * 3 +``` """ original_helper_code = {main_file: main_file.read_text()} @@ -625,8 +632,8 @@ def process_data(self, n): """) # Optimized version that only calls one external helper - optimized_code = f""" -{get_code_block_splitter("main.py") } + optimized_code = """ +```python:main.py def external_helper_1(x): \"\"\"External helper function.\"\"\" return x * 2 @@ -640,6 +647,7 @@ def process_data(self, n): \"\"\"Optimized method that only calls one helper.\"\"\" result1 = external_helper_1(n) return result1 + n * 3 # Inlined external_helper_2 +``` """ # Create test config @@ -675,7 +683,7 @@ def process_data(self, n): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) # Should detect external_helper_2 as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -685,8 +693,8 @@ def process_data(self, n): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # Update optimized code to include a MODIFIED unused helper - optimized_code_with_modified_helper = f""" 
-{get_code_block_splitter("main.py")} + optimized_code_with_modified_helper = """ +```python:main.py def external_helper_1(x): \"\"\"External helper function.\"\"\" return x * 2 @@ -700,6 +708,7 @@ def process_data(self, n): \"\"\"Optimized method that only calls one helper.\"\"\" result1 = external_helper_1(n) return result1 + n * 3 # Inlined external_helper_2 +``` """ original_helper_code = {main_file: main_file.read_text()} @@ -724,8 +733,8 @@ def process_data(self, n): # Also test the complete replace_function_and_helpers_with_optimized_code workflow # Update optimized code to include a MODIFIED unused helper - optimized_code_with_modified_helper = f""" -{get_code_block_splitter("main.py")} + optimized_code_with_modified_helper = """ +```python:main.py def external_helper_1(x): \"\"\"External helper function.\"\"\" return x * 2 @@ -739,6 +748,7 @@ def process_data(self, n): \"\"\"Optimized method that only calls one helper.\"\"\" result1 = external_helper_1(n) return result1 + n * 3 # Inlined external_helper_2 +``` """ original_helper_code = {main_file: main_file.read_text()} @@ -795,8 +805,8 @@ def local_helper(self, x): """) # Optimized version that inlines one helper - optimized_code = f""" -{get_code_block_splitter("main.py") } + optimized_code = """ +```python:main.py def global_helper_1(x): return x * 2 @@ -812,6 +822,7 @@ def compute(self, n): def local_helper(self, x): return x + 1 +``` """ # Create test config @@ -878,7 +889,7 @@ def local_helper(self, x): ] }, )(), - optimized_code, + CodeStringsMarkdown.parse_markdown_code(optimized_code).flat, ) # Should detect global_helper_2 as unused @@ -964,8 +975,8 @@ def clean_data(x): """) # Optimized version that only uses some functions - optimized_code = f""" -{get_code_block_splitter("main.py") } + optimized_code = """ +```python:main.py import utils from math_helpers import add @@ -976,6 +987,7 @@ def entrypoint_function(n): # Inlined multiply: result3 = n * 2 # Inlined process_data: result4 = n ** 2 return result1 + result2 + (n * 2) + (n ** 2) +``` """ # Create test config @@ -1006,7 +1018,7 @@ def entrypoint_function(n): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) # Should detect multiply, process_data as unused (at minimum) unused_names = {uh.qualified_name for uh in unused_helpers} @@ -1126,8 +1138,8 @@ def divide_numbers(x, y): """) # Optimized version that only uses add_numbers - optimized_code = f""" -{get_code_block_splitter("main.py") } + optimized_code = """ +```python:main.py import calculator def entrypoint_function(n): @@ -1135,6 +1147,7 @@ def entrypoint_function(n): result1 = calculator.add_numbers(n, 10) # Inlined: result2 = n * 5 return result1 + (n * 5) +``` """ # Create test config @@ -1165,7 +1178,7 @@ def entrypoint_function(n): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, optimized_code) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) # Should detect multiply_numbers and divide_numbers as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -1329,8 +1342,8 @@ def 
calculate_class(cls, n):
 """)

     # Optimized static method that inlines one utility
-    optimized_static_code = f"""
-{get_code_block_splitter("main.py")}
+    optimized_static_code = """
+```python:main.py
 def utility_function_1(x):
     return x * 2

@@ -1350,6 +1363,7 @@ def calculate_class(cls, n):
         result1 = utility_function_1(n)
         result2 = utility_function_2(n)
         return result1 - result2
+```
 """

     # Create test config
@@ -1386,7 +1400,7 @@ def calculate_class(cls, n):

     # Test unused helper detection for static method
     unused_helpers = detect_unused_helper_functions(
-        optimizer.function_to_optimize, code_context, optimized_static_code
+        optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_static_code).flat
     )

     # Should detect utility_function_2 as unused
@@ -1397,8 +1411,8 @@ def calculate_class(cls, n):

     # Also test the complete replace_function_and_helpers_with_optimized_code workflow
     # Update optimized code to include a MODIFIED unused helper
-    optimized_static_code_with_modified_helper = f"""
-{get_code_block_splitter("main.py")}
+    optimized_static_code_with_modified_helper = """
+```python:main.py
 def utility_function_1(x):
     return x * 2

@@ -1418,6 +1432,7 @@ def calculate_class(cls, n):
         result1 = utility_function_1(n)
         result2 = utility_function_2(n)
         return result1 - result2
+```
 """

     original_helper_code = {main_file: main_file.read_text()}

From 684661e78ac68914d1867edd372ddf19b6a86dfd Mon Sep 17 00:00:00 2001
From: mohammed
Date: Wed, 6 Aug 2025 03:40:06 +0300
Subject: [PATCH 22/25] typo

---
 code_to_optimize/bubble_sort.py             | 2 +-
 codeflash/context/code_context_extractor.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/code_to_optimize/bubble_sort.py b/code_to_optimize/bubble_sort.py
index 787cc4a90..9e97f63a0 100644
--- a/code_to_optimize/bubble_sort.py
+++ b/code_to_optimize/bubble_sort.py
@@ -7,4 +7,4 @@ def sorter(arr):
                 arr[j] = arr[j + 1]
                 arr[j + 1] = temp
     print(f"result: {arr}")
-    return arr
\ No newline at end of file
+    return arr
diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py
index 97befcb4a..cd3728038 100644
--- a/codeflash/context/code_context_extractor.py
+++ b/codeflash/context/code_context_extractor.py
@@ -63,7 +63,7 @@ def get_code_optimization_context(
     # Extract code context for optimization
     final_read_writable_code = extract_code_markdown_context_from_files(
         helpers_of_fto_dict,
-        helpers_of_helpers_dict,
+        {},
         project_root_path,
         remove_docstrings=False,
         code_context_type=CodeContextType.READ_WRITABLE,

From a7ff7013091070971955a944ac9568cd75d2f1d4 Mon Sep 17 00:00:00 2001
From: mohammed
Date: Wed, 6 Aug 2025 22:48:03 +0300
Subject: [PATCH 23/25] eliminate the use of flat code for parsing

---
 codeflash/api/aiservice.py                     |  3 ---
 codeflash/context/code_context_extractor.py    |  4 ++--
 codeflash/context/unused_definition_remover.py | 13 ++++++++++++-
 codeflash/lsp/beta.py                          |  4 ++--
 codeflash/models/models.py                     | 11 ++++++-----
 codeflash/optimization/function_optimizer.py   | 17 +++++++----------
 tests/test_code_context_extractor.py           |  2 +-
 tests/test_code_replacement.py                 |  1 +
 tests/test_unused_helper_revert.py             | 18 +++++++++---------
 9 files changed, 40 insertions(+), 33 deletions(-)

diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py
index 65bc92bcc..d921469d1 100644
--- a/codeflash/api/aiservice.py
+++ b/codeflash/api/aiservice.py
@@ -73,9 +73,6 @@ def make_ai_service_request(
         url = f"{self.base_url}/ai{endpoint}"
         if method.upper() == "POST":
             json_payload =
json.dumps(payload, indent=None, default=pydantic_encoder) - logger.debug(f"========JSON PAYLOAD FOR {url}==============") - logger.debug(json_payload) - logger.debug("======================") headers = {**self.headers, "Content-Type": "application/json"} response = requests.post(url, data=json_payload, headers=headers, timeout=timeout) else: diff --git a/codeflash/context/code_context_extractor.py b/codeflash/context/code_context_extractor.py index cd3728038..09c0c564a 100644 --- a/codeflash/context/code_context_extractor.py +++ b/codeflash/context/code_context_extractor.py @@ -85,14 +85,14 @@ def get_code_optimization_context( ) # Handle token limits - final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.flat) + final_read_writable_tokens = encoded_tokens_len(final_read_writable_code.markdown) if final_read_writable_tokens > optim_token_limit: raise ValueError("Read-writable code has exceeded token limit, cannot proceed") # Setup preexisting objects for code replacer preexisting_objects = set( chain( - find_preexisting_objects(final_read_writable_code.flat), + *(find_preexisting_objects(codestring.code) for codestring in final_read_writable_code.code_strings), *(find_preexisting_objects(codestring.code) for codestring in read_only_code_markdown.code_strings), ) ) diff --git a/codeflash/context/unused_definition_remover.py b/codeflash/context/unused_definition_remover.py index 72586b5eb..cf57af031 100644 --- a/codeflash/context/unused_definition_remover.py +++ b/codeflash/context/unused_definition_remover.py @@ -3,6 +3,7 @@ import ast from collections import defaultdict from dataclasses import dataclass, field +from itertools import chain from pathlib import Path from typing import TYPE_CHECKING, Optional @@ -611,7 +612,9 @@ def _analyze_imports_in_optimized_code( def detect_unused_helper_functions( - function_to_optimize: FunctionToOptimize, code_context: CodeOptimizationContext, optimized_code: str + function_to_optimize: FunctionToOptimize, + code_context: CodeOptimizationContext, + optimized_code: str | CodeStringsMarkdown, ) -> list[FunctionSource]: """Detect helper functions that are no longer called by the optimized entrypoint function. 
@@ -624,6 +627,14 @@ def detect_unused_helper_functions( List of FunctionSource objects representing unused helper functions """ + if isinstance(optimized_code, CodeStringsMarkdown) and len(optimized_code.code_strings) > 0: + return list( + chain.from_iterable( + detect_unused_helper_functions(function_to_optimize, code_context, code.code) + for code in optimized_code.code_strings + ) + ) + try: # Parse the optimized code to analyze function calls and imports optimized_ast = ast.parse(optimized_code) diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index 6081270fb..4ed0b7a62 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -222,7 +222,7 @@ def generate_tests(server: CodeflashLanguageServer, params: FunctionOptimization generated_test.generated_original_test_source for generated_test in generated_tests_list.generated_tests ] optimizations_dict = { - candidate.optimization_id: {"source_code": candidate.source_code.flat, "explanation": candidate.explanation} + candidate.optimization_id: {"source_code": candidate.source_code.markdown, "explanation": candidate.explanation} for candidate in optimizations_set.control + optimizations_set.experiment } @@ -330,7 +330,7 @@ def perform_function_optimization( # noqa: PLR0911 "message": f"No best optimizations found for function {function_to_optimize_qualified_name}", } - optimized_source = best_optimization.candidate.source_code.flat + optimized_source = best_optimization.candidate.source_code.markdown speedup = original_code_baseline.runtime / best_optimization.runtime server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 0d236e9f4..1636d6889 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -157,12 +157,8 @@ class CodeString(BaseModel): file_path: Optional[Path] = None -# Used to split files by adding a marker at the start of each file followed by the file path. -LINE_SPLITTER_MARKER_PREFIX = "# --codeflash:file--" - - def get_code_block_splitter(file_path: Path) -> str: - return f"{LINE_SPLITTER_MARKER_PREFIX}{file_path}" + return f"# file: {file_path}" markdown_pattern = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL) @@ -182,6 +178,11 @@ def flat(self) -> str: Returns: str: The concatenated code of all blocks with file path annotations. + !! Important !!: + Avoid parsing the flat code with multiple files, + parsing may result in unexpected behavior. 
+ + """ if self._cache.get("flat") is not None: return self._cache["flat"] diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 6e7bc5a9b..6c97283a7 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -62,7 +62,6 @@ from codeflash.either import Failure, Success, is_successful from codeflash.models.ExperimentMetadata import ExperimentMetadata from codeflash.models.models import ( - LINE_SPLITTER_MARKER_PREFIX, BestOptimization, CodeOptimizationContext, GeneratedTests, @@ -171,7 +170,10 @@ def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[P helper_code = f.read() original_helper_code[helper_function_path] = helper_code - if has_any_async_functions(code_context.read_writable_code.flat): + async_code = any( + has_any_async_functions(code_string.code) for code_string in code_context.read_writable_code.code_strings + ) + if async_code: return Failure("Codeflash does not support async functions in the code to optimize.") # Random here means that we still attempt optimization with a fractional chance to see if # last time we could not find an optimization, maybe this time we do. @@ -731,7 +733,7 @@ def replace_function_and_helpers_with_optimized_code( preexisting_objects=code_context.preexisting_objects, project_root_path=self.project_root, ) - unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code.flat) + unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code) # Revert unused helper functions to their original definitions if unused_helpers: @@ -1165,15 +1167,10 @@ def process_review( optimized_runtimes_all=optimized_runtime_by_test, ) new_explanation_raw_str = self.aiservice_client.get_new_explanation( - source_code=code_context.read_writable_code.flat.replace( - LINE_SPLITTER_MARKER_PREFIX, - "# file: ", # for better readability - ), + source_code=code_context.read_writable_code.flat, dependency_code=code_context.read_only_context_code, trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id, - optimized_code=best_optimization.candidate.source_code.flat.replace( - LINE_SPLITTER_MARKER_PREFIX, "# file: " - ), + optimized_code=best_optimization.candidate.source_code.flat, original_line_profiler_results=original_code_baseline.line_profile_results["str_out"], optimized_line_profiler_results=best_optimization.line_profiler_test_results["str_out"], original_code_runtime=humanize_runtime(original_code_baseline.runtime), diff --git a/tests/test_code_context_extractor.py b/tests/test_code_context_extractor.py index a9751eb6f..3a7de5d1c 100644 --- a/tests/test_code_context_extractor.py +++ b/tests/test_code_context_extractor.py @@ -9,7 +9,7 @@ import pytest from codeflash.context.code_context_extractor import get_code_optimization_context from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionParent, get_code_block_splitter +from codeflash.models.models import FunctionParent from codeflash.optimization.optimizer import Optimizer from codeflash.code_utils.code_replacer import replace_functions_and_add_imports from codeflash.code_utils.code_extractor import add_global_assignments diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index d7c7772d7..d77d6a43e 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ 
-123,6 +123,7 @@ def totally_new_function(value): function_name: str = "NewClass.new_function" preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]] = find_preexisting_objects(original_code) + print(f"Preexisting objects: {preexisting_objects}") new_code: str = replace_functions_and_add_imports( source_code=original_code, function_names=[function_name], diff --git a/tests/test_unused_helper_revert.py b/tests/test_unused_helper_revert.py index 0c2756c3b..30f291e62 100644 --- a/tests/test_unused_helper_revert.py +++ b/tests/test_unused_helper_revert.py @@ -92,7 +92,7 @@ def helper_function_2(x): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) # Should detect helper_function_2 as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -267,7 +267,7 @@ def helper_function_2(x): original_helper_code = {main_file: main_file.read_text()} # Test detection - should find no unused helpers - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) assert len(unused_helpers) == 0, "No helpers should be detected as unused" # Apply optimization @@ -350,7 +350,7 @@ def entrypoint_function(n): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) # Should detect helper_function_2 as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -538,7 +538,7 @@ def helper_method_2(self, x): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) # Should detect Calculator.helper_method_2 as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -683,7 +683,7 @@ def process_data(self, n): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) # Should detect external_helper_2 as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -889,7 +889,7 @@ def local_helper(self, x): ] }, )(), - CodeStringsMarkdown.parse_markdown_code(optimized_code).flat, + CodeStringsMarkdown.parse_markdown_code(optimized_code), ) # Should detect global_helper_2 as unused @@ -1018,7 +1018,7 @@ def entrypoint_function(n): code_context = 
ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) # Should detect multiply, process_data as unused (at minimum) unused_names = {uh.qualified_name for uh in unused_helpers} @@ -1178,7 +1178,7 @@ def entrypoint_function(n): code_context = ctx_result.unwrap() # Test unused helper detection - unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code).flat) + unused_helpers = detect_unused_helper_functions(optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_code)) # Should detect multiply_numbers and divide_numbers as unused unused_names = {uh.qualified_name for uh in unused_helpers} @@ -1400,7 +1400,7 @@ def calculate_class(cls, n): # Test unused helper detection for static method unused_helpers = detect_unused_helper_functions( - optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_static_code).flat + optimizer.function_to_optimize, code_context, CodeStringsMarkdown.parse_markdown_code(optimized_static_code) ) # Should detect utility_function_2 as unused From c8d4e05a8b32482b6893b0d9b95a0c2df4e6a93d Mon Sep 17 00:00:00 2001 From: mohammed Date: Thu, 7 Aug 2025 01:24:01 +0300 Subject: [PATCH 24/25] chore: trigger CI From ef083ecd04733142a56147d6cd04e627cccbf225 Mon Sep 17 00:00:00 2001 From: mohammed Date: Thu, 7 Aug 2025 05:48:34 +0300 Subject: [PATCH 25/25] remove comments --- codeflash/optimization/function_optimizer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 6c97283a7..6905bf47c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -219,7 +219,7 @@ def generate_and_instrument_tests( revert_to_print=bool(get_pr_number()), ): generated_results = self.generate_tests_and_optimizations( - testgen_context_code=code_context.testgen_context_code, # TODO: should we send the markdow context for the testgen instead. + testgen_context_code=code_context.testgen_context_code, read_writable_code=code_context.read_writable_code, read_only_context_code=code_context.read_only_context_code, helper_functions=code_context.helper_functions, @@ -292,7 +292,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() - code_print(code_context.read_writable_code.flat) # Should we print the markdown or the flattened code? + code_print(code_context.read_writable_code.flat) test_setup_result = self.generate_and_instrument_tests( # also generates optimizations code_context, should_run_experiment=should_run_experiment
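
The format this series converges on is the fenced payload PATCH 01 introduces and PATCH 23 finishes migrating to: optimized code travels as a sequence of ```python:<relative/path> blocks, which CodeStringsMarkdown keeps as per-file code_strings and only joins into a single string, headed by "# file: <path>" markers, for display. Below is a minimal, self-contained sketch of that round trip, assuming the markdown_pattern regex added to codeflash/models/models.py is what backs the parsing; split_markdown_blocks and render_flat are illustrative stand-ins, not the actual CodeStringsMarkdown.parse_markdown_code / .flat implementations.

import re

# The same pattern PATCH 23 leaves in codeflash/models/models.py.
markdown_pattern = re.compile(r"```python:([^\n]+)\n(.*?)\n```", re.DOTALL)


def split_markdown_blocks(markdown: str) -> list[tuple[str, str]]:
    """Return (file_path, code) pairs, one per fenced block in the payload."""
    return markdown_pattern.findall(markdown)


def render_flat(blocks: list[tuple[str, str]]) -> str:
    """Join blocks for display, prefixing each with a '# file: <path>' header."""
    return "\n\n".join(f"# file: {path}\n{code}" for path, code in blocks)


optimized = (
    "```python:helpers.py\n"
    "def helper_function_1(x):\n"
    "    return x * 2\n"
    "```\n"
    "```python:main.py\n"
    "from helpers import helper_function_1\n"
    "def entrypoint_function(n):\n"
    "    return helper_function_1(n) + n * 3\n"
    "```\n"
)
blocks = split_markdown_blocks(optimized)
assert [path for path, _ in blocks] == ["helpers.py", "main.py"]
print(render_flat(blocks))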
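
PATCH 23's subject, eliminating the use of flat code for parsing, matches the warning the same patch adds to the flat property docstring and the new per-block recursion in detect_unused_helper_functions: the flattened string is fine to print, but it erases file boundaries, so anything AST-based has to run on each block separately. A small illustration of the failure mode, using made-up file contents rather than anything from the repo:

import ast

helper_block = "def compute(x):\n    return x * 2\n"
main_block = (
    "def compute(x):\n"
    "    return x + 1\n"
    "\n"
    "def entrypoint(n):\n"
    "    return compute(n)\n"
)

# Parsing per block keeps every definition attributable to its file.
per_file_defs = {
    path: [node.name for node in ast.walk(ast.parse(code)) if isinstance(node, ast.FunctionDef)]
    for path, code in (("helpers.py", helper_block), ("main.py", main_block))
}
print(per_file_defs)  # {'helpers.py': ['compute'], 'main.py': ['compute', 'entrypoint']}

# Parsing the flattened concatenation still succeeds, but both compute()
# definitions now live in one module, so an analyzer can no longer tell which
# file's compute() the entrypoint actually calls.
flat = "# file: helpers.py\n" + helper_block + "\n# file: main.py\n" + main_block
flat_defs = [node.name for node in ast.walk(ast.parse(flat)) if isinstance(node, ast.FunctionDef)]
print(flat_defs)  # ['compute', 'compute', 'entrypoint'] -- the file boundary is gone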