From 3cbd6b71280e73e82bb54b16fe092e8d3b9053e0 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 14 May 2025 11:16:01 -0700 Subject: [PATCH 01/29] feat(optimizer): Implement targeted formatting (CF-637) --- codeflash/optimization/function_optimizer.py | 77 ++++++++++++++------ 1 file changed, 53 insertions(+), 24 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 56124a9cb..419d63848 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -5,6 +5,7 @@ import os import shutil import subprocess +import tempfile import time import uuid from collections import defaultdict, deque @@ -124,6 +125,7 @@ def __init__( self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {} self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {} self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None + self.optimizer_temp_dir = Path(tempfile.mkdtemp(prefix="codeflash_opt_fmt_")) def optimize_function(self) -> Result[BestOptimization, str]: should_run_experiment = self.experiment_id is not None @@ -301,9 +303,30 @@ def optimize_function(self) -> Result[BestOptimization, str]: code_context=code_context, optimized_code=best_optimization.candidate.source_code ) - new_code, new_helper_code = self.reformat_code_and_helpers( - code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code - ) + if not self.args.disable_imports_sorting: + main_file_path = self.function_to_optimize.file_path + if main_file_path.exists(): + current_main_content = main_file_path.read_text(encoding="utf8") + sorted_main_content = sort_imports(current_main_content) + if sorted_main_content != current_main_content: + main_file_path.write_text(sorted_main_content, encoding="utf8") + + writable_helper_file_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_file_path in writable_helper_file_paths: + if helper_file_path.exists(): + current_helper_content = helper_file_path.read_text(encoding="utf8") + sorted_helper_content = sort_imports(current_helper_content) + if sorted_helper_content != current_helper_content: + helper_file_path.write_text(sorted_helper_content, encoding="utf8") + + new_code = self.function_to_optimize.file_path.read_text(encoding="utf8") + new_helper_code: dict[Path, str] = {} + for helper_file_path_key in original_helper_code: + if helper_file_path_key.exists(): + new_helper_code[helper_file_path_key] = helper_file_path_key.read_text(encoding="utf8") + else: + logger.warning(f"Helper file {helper_file_path_key} not found after optimization. It will not be included in new_helper_code for PR.") + existing_tests = existing_tests_source_for( self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), @@ -405,6 +428,33 @@ def determine_best_candidate( future_line_profile_results = None candidate_index += 1 candidate = candidates.popleft() + + formatted_candidate_code = candidate.source_code + if self.args.formatter_cmds: + temp_code_file_path: Path | None = None + try: + with tempfile.NamedTemporaryFile( + mode="w", + suffix=".py", + delete=False, + encoding="utf8", + dir=self.optimizer_temp_dir + ) as tmp_file: + tmp_file.write(candidate.source_code) + temp_code_file_path = Path(tmp_file.name) + + formatted_candidate_code = format_code( + formatter_cmds=self.args.formatter_cmds, + path=temp_code_file_path + ) + except Exception as e: + logger.error(f"Error during formatting candidate code via temp file: {e}. Using original candidate code.") + finally: + if temp_code_file_path and temp_code_file_path.exists(): + temp_code_file_path.unlink(missing_ok=True) + + candidate.source_code = formatted_candidate_code + get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) logger.info(f"Optimization candidate {candidate_index}/{original_len}:") @@ -580,27 +630,6 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, with Path(module_abspath).open("w", encoding="utf8") as f: f.write(original_helper_code[module_abspath]) - def reformat_code_and_helpers( - self, helper_functions: list[FunctionSource], path: Path, original_code: str - ) -> tuple[str, dict[Path, str]]: - should_sort_imports = not self.args.disable_imports_sorting - if should_sort_imports and isort.code(original_code) != original_code: - should_sort_imports = False - - new_code = format_code(self.args.formatter_cmds, path) - if should_sort_imports: - new_code = sort_imports(new_code) - - new_helper_code: dict[Path, str] = {} - helper_functions_paths = {hf.file_path for hf in helper_functions} - for module_abspath in helper_functions_paths: - formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) - if should_sort_imports: - formatted_helper_code = sort_imports(formatted_helper_code) - new_helper_code[module_abspath] = formatted_helper_code - - return new_code, new_helper_code - def replace_function_and_helpers_with_optimized_code( self, code_context: CodeOptimizationContext, optimized_code: str ) -> bool: From a10600c3e3956bce1d995bdbf9feb10996fb0fb9 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 14 May 2025 11:50:39 -0700 Subject: [PATCH 02/29] Fixed changes to the FunctionOptimizer --- codeflash/optimization/function_optimizer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 419d63848..b8b8b8388 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -2,6 +2,7 @@ import ast import concurrent.futures +import dataclasses import os import shutil import subprocess @@ -453,7 +454,7 @@ def determine_best_candidate( if temp_code_file_path and temp_code_file_path.exists(): temp_code_file_path.unlink(missing_ok=True) - candidate.source_code = formatted_candidate_code + candidate = dataclasses.replace(candidate, source_code=formatted_candidate_code) get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) From bcba52723dee1fb77196cc1985833ab6b199aaab Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 14 May 2025 11:52:50 -0700 Subject: [PATCH 03/29] CODEFLASH_DISABLE_TELEMETRY environment variable can be set to disable telemetry --- codeflash/main.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/codeflash/main.py b/codeflash/main.py index 02b13d5aa..78a66acf1 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -2,6 +2,7 @@ solved problem, please reach out to us at careers@codeflash.ai. We're hiring! """ +import os from pathlib import Path from codeflash.cli_cmds.cli import parse_args, process_pyproject_config @@ -20,12 +21,12 @@ def main() -> None: CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"} ) args = parse_args() + if args.command: - if args.config_file and Path.exists(args.config_file): + disable_telemetry = os.environ.get("CODEFLASH_DISABLE_TELEMETRY", "").lower() in {"true", "t", "1", "yes", "y"} + if (not disable_telemetry) and args.config_file and Path.exists(args.config_file): pyproject_config, _ = parse_config_file(args.config_file) disable_telemetry = pyproject_config.get("disable_telemetry", False) - else: - disable_telemetry = False init_sentry(not disable_telemetry, exclude_errors=True) posthog_cf.initialize_posthog(not disable_telemetry) args.func() From 36da6403e2a320b50ba4f1dbf3fba9b8fd7c7f79 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 14 May 2025 13:20:21 -0700 Subject: [PATCH 04/29] Added bubble sort implementation with bad formatting in non-optimized sections of the file... This is to test the new formatting changes --- ...ve_bad_formatting_for_nonoptimized_code.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 code_to_optimize/bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py diff --git a/code_to_optimize/bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py b/code_to_optimize/bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py new file mode 100644 index 000000000..b506ddfbb --- /dev/null +++ b/code_to_optimize/bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py @@ -0,0 +1,19 @@ +def lol(): + print( "lol" ) + + + + + + + +def sorter(arr): + print("codeflash stdout: Sorting list") + for i in range(len(arr)): + for j in range(len(arr) - 1): + if arr[j] > arr[j + 1]: + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp + print(f"result: {arr}") + return arr From 85fd3c0155494186137c2fdfc9dbfc50583041e9 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 14 May 2025 14:53:30 -0700 Subject: [PATCH 05/29] Added a file containing a bubble sort method in a class To test that the new formatting logic correctly handles indentation --- ...ve_bad_formatting_for_nonoptimized_code.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 code_to_optimize/bubble_sort_method_preserve_bad_formatting_for_nonoptimized_code.py diff --git a/code_to_optimize/bubble_sort_method_preserve_bad_formatting_for_nonoptimized_code.py b/code_to_optimize/bubble_sort_method_preserve_bad_formatting_for_nonoptimized_code.py new file mode 100644 index 000000000..29a00a922 --- /dev/null +++ b/code_to_optimize/bubble_sort_method_preserve_bad_formatting_for_nonoptimized_code.py @@ -0,0 +1,38 @@ +import sys + + +def lol(): + print( "lol" ) + + + + + + + + + +class BubbleSorter: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" ) + + + + + + + + + def sorter(self, arr): + print("codeflash stdout : BubbleSorter.sorter() called") + for i in range(len(arr)): + for j in range(len(arr) - 1): + if arr[j] > arr[j + 1]: + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp + print("stderr test", file=sys.stderr) + return arr From 2c4001827e70e299e66b925cc703f2d787de5c72 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 14 May 2025 15:15:36 -0700 Subject: [PATCH 06/29] Added "scratch/" directory to .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 535acfb3e..b4a99e8c2 100644 --- a/.gitignore +++ b/.gitignore @@ -254,3 +254,5 @@ fabric.properties # Mac .DS_Store + +scratch/ From 82b9d416e7322e027d9ee0ba43ce10202901ab01 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 14 May 2025 17:18:21 -0700 Subject: [PATCH 07/29] Cleaned up the import sorting code in FunctionOptimizer --- codeflash/code_utils/formatter.py | 9 +++++++++ codeflash/optimization/function_optimizer.py | 18 +++--------------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 875fd0a1f..8f426ad8a 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -55,3 +55,12 @@ def sort_imports(code: str) -> str: return code # Fall back to original code if isort fails return sorted_code + + +def sort_imports_in_place(paths: list[Path]) -> None: + for path in paths: + if path.exists(): + content = path.read_text(encoding="utf8") + sorted_content = sort_imports(content) + if sorted_content != content: + path.write_text(sorted_content, encoding="utf8") diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index b8b8b8388..369b081fd 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -38,7 +38,7 @@ N_TESTS_TO_GENERATE, TOTAL_LOOPING_TIME, ) -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, sort_imports_in_place from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.line_profile_utils import add_decorator_imports from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests @@ -305,20 +305,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: ) if not self.args.disable_imports_sorting: - main_file_path = self.function_to_optimize.file_path - if main_file_path.exists(): - current_main_content = main_file_path.read_text(encoding="utf8") - sorted_main_content = sort_imports(current_main_content) - if sorted_main_content != current_main_content: - main_file_path.write_text(sorted_main_content, encoding="utf8") - - writable_helper_file_paths = {hf.file_path for hf in code_context.helper_functions} - for helper_file_path in writable_helper_file_paths: - if helper_file_path.exists(): - current_helper_content = helper_file_path.read_text(encoding="utf8") - sorted_helper_content = sort_imports(current_helper_content) - if sorted_helper_content != current_helper_content: - helper_file_path.write_text(sorted_helper_content, encoding="utf8") + path_to_sort_imports_for = [self.function_to_optimize.file_path] + [hf.file_path for hf in code_context.helper_functions] + sort_imports_in_place(path_to_sort_imports_for) new_code = self.function_to_optimize.file_path.read_text(encoding="utf8") new_helper_code: dict[Path, str] = {} From ce8783292296d9546da37e5989533d3a83d46808 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 14 May 2025 17:28:33 -0700 Subject: [PATCH 08/29] Added test for sort_imports_in_place --- tests/test_formatter.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 5c0a91c38..4f2ac6d9d 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -5,7 +5,7 @@ import pytest from codeflash.code_utils.config_parser import parse_config_file -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, sort_imports, sort_imports_in_place def test_remove_duplicate_imports(): @@ -30,6 +30,23 @@ def test_sorting_imports(): new_code = sort_imports(original_code) assert new_code == "import os\nimport sys\nimport unittest\n" +def test_sort_imports_in_place(): + """Test that sorting imports in place in multiple files works.""" + original_code = "import sys\nimport unittest\nimport os\n" + expected_code = "import os\nimport sys\nimport unittest\n" + + with tempfile.TemporaryDirectory() as tmpdir: + file_paths = [] + for i in range(3): + file_path = Path(tmpdir) / f"test_file_{i}.py" + file_path.write_text(original_code, encoding="utf8") + file_paths.append(file_path) + + sort_imports_in_place(file_paths) + + for file_path in file_paths: + assert file_path.read_text(encoding="utf8") == expected_code + def test_sort_imports_without_formatting(): """Test that imports are sorted when formatting is disabled and should_sort_imports is True.""" From 0373bfa049399ec603834730eebd149000d7f0e6 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Fri, 16 May 2025 08:38:58 -0700 Subject: [PATCH 09/29] Reverted changes to optimizer and formatter on targeted-formatting branch We are going to take a gentler, CST-based approach to targeted formatting. --- codeflash/code_utils/formatter.py | 9 --- codeflash/optimization/function_optimizer.py | 68 +++++++------------- tests/test_formatter.py | 19 +----- 3 files changed, 26 insertions(+), 70 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 8f426ad8a..875fd0a1f 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -55,12 +55,3 @@ def sort_imports(code: str) -> str: return code # Fall back to original code if isort fails return sorted_code - - -def sort_imports_in_place(paths: list[Path]) -> None: - for path in paths: - if path.exists(): - content = path.read_text(encoding="utf8") - sorted_content = sort_imports(content) - if sorted_content != content: - path.write_text(sorted_content, encoding="utf8") diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 369b081fd..56124a9cb 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -2,11 +2,9 @@ import ast import concurrent.futures -import dataclasses import os import shutil import subprocess -import tempfile import time import uuid from collections import defaultdict, deque @@ -38,7 +36,7 @@ N_TESTS_TO_GENERATE, TOTAL_LOOPING_TIME, ) -from codeflash.code_utils.formatter import format_code, sort_imports_in_place +from codeflash.code_utils.formatter import format_code, sort_imports from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.line_profile_utils import add_decorator_imports from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests @@ -126,7 +124,6 @@ def __init__( self.function_benchmark_timings = function_benchmark_timings if function_benchmark_timings else {} self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {} self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None - self.optimizer_temp_dir = Path(tempfile.mkdtemp(prefix="codeflash_opt_fmt_")) def optimize_function(self) -> Result[BestOptimization, str]: should_run_experiment = self.experiment_id is not None @@ -304,18 +301,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: code_context=code_context, optimized_code=best_optimization.candidate.source_code ) - if not self.args.disable_imports_sorting: - path_to_sort_imports_for = [self.function_to_optimize.file_path] + [hf.file_path for hf in code_context.helper_functions] - sort_imports_in_place(path_to_sort_imports_for) - - new_code = self.function_to_optimize.file_path.read_text(encoding="utf8") - new_helper_code: dict[Path, str] = {} - for helper_file_path_key in original_helper_code: - if helper_file_path_key.exists(): - new_helper_code[helper_file_path_key] = helper_file_path_key.read_text(encoding="utf8") - else: - logger.warning(f"Helper file {helper_file_path_key} not found after optimization. It will not be included in new_helper_code for PR.") - + new_code, new_helper_code = self.reformat_code_and_helpers( + code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code + ) existing_tests = existing_tests_source_for( self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root), @@ -417,33 +405,6 @@ def determine_best_candidate( future_line_profile_results = None candidate_index += 1 candidate = candidates.popleft() - - formatted_candidate_code = candidate.source_code - if self.args.formatter_cmds: - temp_code_file_path: Path | None = None - try: - with tempfile.NamedTemporaryFile( - mode="w", - suffix=".py", - delete=False, - encoding="utf8", - dir=self.optimizer_temp_dir - ) as tmp_file: - tmp_file.write(candidate.source_code) - temp_code_file_path = Path(tmp_file.name) - - formatted_candidate_code = format_code( - formatter_cmds=self.args.formatter_cmds, - path=temp_code_file_path - ) - except Exception as e: - logger.error(f"Error during formatting candidate code via temp file: {e}. Using original candidate code.") - finally: - if temp_code_file_path and temp_code_file_path.exists(): - temp_code_file_path.unlink(missing_ok=True) - - candidate = dataclasses.replace(candidate, source_code=formatted_candidate_code) - get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) logger.info(f"Optimization candidate {candidate_index}/{original_len}:") @@ -619,6 +580,27 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, with Path(module_abspath).open("w", encoding="utf8") as f: f.write(original_helper_code[module_abspath]) + def reformat_code_and_helpers( + self, helper_functions: list[FunctionSource], path: Path, original_code: str + ) -> tuple[str, dict[Path, str]]: + should_sort_imports = not self.args.disable_imports_sorting + if should_sort_imports and isort.code(original_code) != original_code: + should_sort_imports = False + + new_code = format_code(self.args.formatter_cmds, path) + if should_sort_imports: + new_code = sort_imports(new_code) + + new_helper_code: dict[Path, str] = {} + helper_functions_paths = {hf.file_path for hf in helper_functions} + for module_abspath in helper_functions_paths: + formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) + if should_sort_imports: + formatted_helper_code = sort_imports(formatted_helper_code) + new_helper_code[module_abspath] = formatted_helper_code + + return new_code, new_helper_code + def replace_function_and_helpers_with_optimized_code( self, code_context: CodeOptimizationContext, optimized_code: str ) -> bool: diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 4f2ac6d9d..5c0a91c38 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -5,7 +5,7 @@ import pytest from codeflash.code_utils.config_parser import parse_config_file -from codeflash.code_utils.formatter import format_code, sort_imports, sort_imports_in_place +from codeflash.code_utils.formatter import format_code, sort_imports def test_remove_duplicate_imports(): @@ -30,23 +30,6 @@ def test_sorting_imports(): new_code = sort_imports(original_code) assert new_code == "import os\nimport sys\nimport unittest\n" -def test_sort_imports_in_place(): - """Test that sorting imports in place in multiple files works.""" - original_code = "import sys\nimport unittest\nimport os\n" - expected_code = "import os\nimport sys\nimport unittest\n" - - with tempfile.TemporaryDirectory() as tmpdir: - file_paths = [] - for i in range(3): - file_path = Path(tmpdir) / f"test_file_{i}.py" - file_path.write_text(original_code, encoding="utf8") - file_paths.append(file_path) - - sort_imports_in_place(file_paths) - - for file_path in file_paths: - assert file_path.read_text(encoding="utf8") == expected_code - def test_sort_imports_without_formatting(): """Test that imports are sorted when formatting is disabled and should_sort_imports is True.""" From 9345d8090cee78beb7092ceec338562f82db6980 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Fri, 16 May 2025 10:47:41 -0700 Subject: [PATCH 10/29] Started work on targeted formatting using the CST --- codeflash/code_utils/code_replacer.py | 13 ++++++++-- codeflash/code_utils/formatter.py | 35 +++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index ccb935f42..279b1df65 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -34,7 +34,7 @@ def normalize_code(code: str) -> str: class OptimFunctionCollector(cst.CSTVisitor): - METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,) + METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider, cst.metadata.PositionProvider) def __init__( self, @@ -52,8 +52,11 @@ def __init__( self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list) self.current_class = None self.modified_init_functions: dict[str, cst.FunctionDef] = {} + self.modification_code_ranges: list[tuple[int, int]] = [] def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + modification = True + if (self.current_class, node.name.value) in self.function_names: self.modified_functions[(self.current_class, node.name.value)] = node elif self.current_class and node.name.value == "__init__": @@ -64,6 +67,13 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: and self.current_class is None ): self.new_functions.append(node) + else: + modification = False + + if modification: + pos = self.get_metadata(cst.metadata.PositionProvider, node) + self.modification_code_ranges.append((pos.start, pos.end)) + return False def visit_ClassDef(self, node: cst.ClassDef) -> bool: @@ -154,7 +164,6 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c node = node.with_changes(body=(*self.new_functions, *node.body)) return node - def replace_functions_in_file( source_code: str, original_function_names: list[str], diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 875fd0a1f..8358e9ab4 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -6,8 +6,12 @@ from typing import TYPE_CHECKING import isort +import libcst as cst from codeflash.cli_cmds.console import console, logger +from codeflash.code_utils.code_replacer import OptimFunctionCollector +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import CodeOptimizationContext if TYPE_CHECKING: from pathlib import Path @@ -55,3 +59,34 @@ def sort_imports(code: str) -> str: return code # Fall back to original code if isort fails return sorted_code + +def get_modification_code_ranges( + modified_code: str, + fto: FunctionToOptimize, + code_context: CodeOptimizationContext, +) -> list[tuple[int, int]]: + """ + Returns the line number of modified and new functions in a string containing containing the code in a fully modified file. + """ + modified_functions = set() + modified_functions.add(fto.qualified_name) + for helper_function in code_context.helper_functions: + if helper_function.jedi_definition.type != "class": + modified_functions.add(helper_function.qualified_name) + + parsed_function_names = set() + for original_function_name in modified_functions: + if original_function_name.count(".") == 0: + class_name, function_name = None, original_function_name + elif original_function_name.count(".") == 1: + class_name, function_name = original_function_name.split(".") + else: + msg = f"Unable to find {original_function_name}. Returning unchanged source code." + logger.error(msg) + continue + parsed_function_names.add((class_name, function_name)) + + module = cst.metadata.MetadataWrapper(cst.parse_module(modified_code)) + visitor = OptimFunctionCollector(code_context.preexisting_objects, parsed_function_names) + module.visit(visitor) + return visitor.modification_code_ranges From bfc8423bec095d6dd6f592c935c239f208ea6b57 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Fri, 16 May 2025 10:48:33 -0700 Subject: [PATCH 11/29] TODO --- codeflash/code_utils/formatter.py | 1 + 1 file changed, 1 insertion(+) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 8358e9ab4..ebd178026 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -60,6 +60,7 @@ def sort_imports(code: str) -> str: return sorted_code +# TODO(zomglings): Write unit tests. def get_modification_code_ranges( modified_code: str, fto: FunctionToOptimize, From 91fe6a717c0f5a01da316eebdf2c474c8ca1451f Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Fri, 16 May 2025 11:24:21 -0700 Subject: [PATCH 12/29] Updated implementation of FunctionOptimizer.reformat_code_and_helpers --- codeflash/optimization/function_optimizer.py | 49 ++++++++++++++------ 1 file changed, 36 insertions(+), 13 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 56124a9cb..55c4fac9e 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -36,7 +36,7 @@ N_TESTS_TO_GENERATE, TOTAL_LOOPING_TIME, ) -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, get_modification_code_ranges, sort_imports from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.line_profile_utils import add_decorator_imports from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests @@ -581,25 +581,48 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, f.write(original_helper_code[module_abspath]) def reformat_code_and_helpers( - self, helper_functions: list[FunctionSource], path: Path, original_code: str + self, helper_functions: list[FunctionSource], fto_path: Path, original_code: str ) -> tuple[str, dict[Path, str]]: should_sort_imports = not self.args.disable_imports_sorting if should_sort_imports and isort.code(original_code) != original_code: should_sort_imports = False - - new_code = format_code(self.args.formatter_cmds, path) - if should_sort_imports: - new_code = sort_imports(new_code) - + + paths = [fto_path] + list({hf.file_path for hf in helper_functions}) + new_target_code = None new_helper_code: dict[Path, str] = {} - helper_functions_paths = {hf.file_path for hf in helper_functions} - for module_abspath in helper_functions_paths: - formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) + for i, path in enumerate(paths): + unformatted_code = path.read_text(encoding="utf8") + code_context = self.get_code_optimization_context() + code_ranges_unformatted = get_modification_code_ranges(unformatted_code, self.function_to_optimize, code_context) + + formatted_code = format_code(self.args.formatter_cmds, path) + # Note: We do not need to refresh the code_context because we only use it to refer to names of original + # functions (even before optimization was applied) and filepaths, none of which is changing. + code_ranges_formatted = get_modification_code_ranges(formatted_code, self.function_to_optimize, code_context) + + if len(code_ranges_formatted != code_ranges_unformatted): + raise Exception("Formatting had unexpected effects on code ranges") + + # It is important to sort in descending order so that the index arithmetic remains simple as we modify new_code + code_ranges_unformatted.sort(key=lambda range: range[0], reverse=True) + code_ranges_formatted.sort(key=lambda range: range[0], reverse=True) + new_code = unformatted_code + for range_0, range_1 in zip(code_ranges_unformatted, code_ranges_formatted): + range_0_0, range_0_1 = range_0 + range_1_0, range_1_1 = range_1 + new_code = new_code[:range_0_0] + new_code[range_1_0:range_1_1 + 1] + new_code[range_0_1 + 1] + + path.write_text(new_code, encoding="utf8") + if should_sort_imports: - formatted_helper_code = sort_imports(formatted_helper_code) - new_helper_code[module_abspath] = formatted_helper_code + new_code = sort_imports(new_code) + + if i == 0: + new_target_code = new_code + else: + new_helper_code[path] = new_code - return new_code, new_helper_code + return new_target_code, new_helper_code def replace_function_and_helpers_with_optimized_code( self, code_context: CodeOptimizationContext, optimized_code: str From b89622cfffa22a47204a49f1b9c4a69df664ef09 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Fri, 16 May 2025 11:35:52 -0700 Subject: [PATCH 13/29] Fixed a few bugs in reformat_code_and_helpers --- codeflash/optimization/function_optimizer.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 55c4fac9e..312e9ca2d 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -9,7 +9,7 @@ import uuid from collections import defaultdict, deque from pathlib import Path -from typing import TYPE_CHECKING +from typing import cast, TYPE_CHECKING import isort import libcst as cst @@ -592,7 +592,10 @@ def reformat_code_and_helpers( new_helper_code: dict[Path, str] = {} for i, path in enumerate(paths): unformatted_code = path.read_text(encoding="utf8") - code_context = self.get_code_optimization_context() + code_context_result = self.get_code_optimization_context() + if code_context_result.is_failure(): + raise Exception("Unable to generate code context for formatting purposes") + code_context = cast(CodeOptimizationContext, code_context_result.unwrap()) code_ranges_unformatted = get_modification_code_ranges(unformatted_code, self.function_to_optimize, code_context) formatted_code = format_code(self.args.formatter_cmds, path) @@ -600,17 +603,15 @@ def reformat_code_and_helpers( # functions (even before optimization was applied) and filepaths, none of which is changing. code_ranges_formatted = get_modification_code_ranges(formatted_code, self.function_to_optimize, code_context) - if len(code_ranges_formatted != code_ranges_unformatted): + if len(code_ranges_formatted) != len(code_ranges_unformatted): raise Exception("Formatting had unexpected effects on code ranges") # It is important to sort in descending order so that the index arithmetic remains simple as we modify new_code - code_ranges_unformatted.sort(key=lambda range: range[0], reverse=True) - code_ranges_formatted.sort(key=lambda range: range[0], reverse=True) + code_ranges_unformatted.sort(key=lambda range: range.start, reverse=True) + code_ranges_formatted.sort(key=lambda range: range.start, reverse=True) new_code = unformatted_code for range_0, range_1 in zip(code_ranges_unformatted, code_ranges_formatted): - range_0_0, range_0_1 = range_0 - range_1_0, range_1_1 = range_1 - new_code = new_code[:range_0_0] + new_code[range_1_0:range_1_1 + 1] + new_code[range_0_1 + 1] + new_code = new_code[:range_0.start] + new_code[range_1.start:range_1.end + 1] + new_code[range_0.end + 1] path.write_text(new_code, encoding="utf8") From 5a9265c9c6507ae8f5d7c3f4c55ac77b63430392 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Fri, 16 May 2025 11:40:35 -0700 Subject: [PATCH 14/29] More codeposition bugs --- codeflash/optimization/function_optimizer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 312e9ca2d..aad13a69c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -607,11 +607,13 @@ def reformat_code_and_helpers( raise Exception("Formatting had unexpected effects on code ranges") # It is important to sort in descending order so that the index arithmetic remains simple as we modify new_code - code_ranges_unformatted.sort(key=lambda range: range.start, reverse=True) - code_ranges_formatted.sort(key=lambda range: range.start, reverse=True) + code_ranges_unformatted.sort(key=lambda range: range[0].line, reverse=True) + code_ranges_formatted.sort(key=lambda range: range[0].line, reverse=True) new_code = unformatted_code for range_0, range_1 in zip(code_ranges_unformatted, code_ranges_formatted): - new_code = new_code[:range_0.start] + new_code[range_1.start:range_1.end + 1] + new_code[range_0.end + 1] + range_0_0, range_0_1 = range_0 + range_1_0, range_1_1 = range_1 + new_code = new_code[:range_0_0.line] + new_code[range_1_0.line:range_1_1.line + 1] + new_code[range_0_1.line + 1] path.write_text(new_code, encoding="utf8") From b903d1aa31c1003c5383f736687a757d3d3ce460 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Fri, 16 May 2025 11:45:12 -0700 Subject: [PATCH 15/29] Issue with splicing --- codeflash/code_utils/code_replacer.py | 4 ++-- codeflash/code_utils/formatter.py | 2 +- codeflash/optimization/function_optimizer.py | 6 +++--- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 279b1df65..ae973b286 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -52,7 +52,7 @@ def __init__( self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list) self.current_class = None self.modified_init_functions: dict[str, cst.FunctionDef] = {} - self.modification_code_ranges: list[tuple[int, int]] = [] + self.modification_code_range_lines: list[tuple[int, int]] = [] def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: modification = True @@ -72,7 +72,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: if modification: pos = self.get_metadata(cst.metadata.PositionProvider, node) - self.modification_code_ranges.append((pos.start, pos.end)) + self.modification_code_range_lines.append((pos.start.line, pos.end.line)) return False diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index ebd178026..71ef5773b 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -90,4 +90,4 @@ def get_modification_code_ranges( module = cst.metadata.MetadataWrapper(cst.parse_module(modified_code)) visitor = OptimFunctionCollector(code_context.preexisting_objects, parsed_function_names) module.visit(visitor) - return visitor.modification_code_ranges + return visitor.modification_code_range_lines diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index aad13a69c..16bf3b8da 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -607,13 +607,13 @@ def reformat_code_and_helpers( raise Exception("Formatting had unexpected effects on code ranges") # It is important to sort in descending order so that the index arithmetic remains simple as we modify new_code - code_ranges_unformatted.sort(key=lambda range: range[0].line, reverse=True) - code_ranges_formatted.sort(key=lambda range: range[0].line, reverse=True) + code_ranges_unformatted.sort(key=lambda range: range[0], reverse=True) + code_ranges_formatted.sort(key=lambda range: range[0], reverse=True) new_code = unformatted_code for range_0, range_1 in zip(code_ranges_unformatted, code_ranges_formatted): range_0_0, range_0_1 = range_0 range_1_0, range_1_1 = range_1 - new_code = new_code[:range_0_0.line] + new_code[range_1_0.line:range_1_1.line + 1] + new_code[range_0_1.line + 1] + new_code = new_code[:range_0_0] + new_code[range_1_0:range_1_1 + 1] + new_code[range_0_1 + 1] path.write_text(new_code, encoding="utf8") From 16ca27e27b1b2e3ae2098774ab05a50322325371 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Fri, 16 May 2025 11:56:42 -0700 Subject: [PATCH 16/29] Fixing more bugs, testing live... Also updated CODEFLASH_DISABLE_TELEMETRY behavior --- codeflash/main.py | 14 +++++++++----- codeflash/optimization/function_optimizer.py | 8 +++++--- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/codeflash/main.py b/codeflash/main.py index 78a66acf1..2d57ae9a1 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -22,8 +22,10 @@ def main() -> None: ) args = parse_args() + disable_telemetry_env = os.environ.get("CODEFLASH_DISABLE_TELEMETRY", "").lower() in {"true", "t", "1", "yes", "y"} + if args.command: - disable_telemetry = os.environ.get("CODEFLASH_DISABLE_TELEMETRY", "").lower() in {"true", "t", "1", "yes", "y"} + disable_telemetry = disable_telemetry_env if (not disable_telemetry) and args.config_file and Path.exists(args.config_file): pyproject_config, _ = parse_config_file(args.config_file) disable_telemetry = pyproject_config.get("disable_telemetry", False) @@ -32,14 +34,16 @@ def main() -> None: args.func() elif args.verify_setup: args = process_pyproject_config(args) - init_sentry(not args.disable_telemetry, exclude_errors=True) - posthog_cf.initialize_posthog(not args.disable_telemetry) + disable_telemetry = args.disable_telemetry or disable_telemetry_env + init_sentry(not disable_telemetry, exclude_errors=True) + posthog_cf.initialize_posthog(not disable_telemetry) ask_run_end_to_end_test(args) else: args = process_pyproject_config(args) args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args) - init_sentry(not args.disable_telemetry, exclude_errors=True) - posthog_cf.initialize_posthog(not args.disable_telemetry) + disable_telemetry = args.disable_telemetry or disable_telemetry_env + init_sentry(not disable_telemetry, exclude_errors=True) + posthog_cf.initialize_posthog(not disable_telemetry) optimizer.run_with_args(args) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 16bf3b8da..1665825f4 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -609,12 +609,14 @@ def reformat_code_and_helpers( # It is important to sort in descending order so that the index arithmetic remains simple as we modify new_code code_ranges_unformatted.sort(key=lambda range: range[0], reverse=True) code_ranges_formatted.sort(key=lambda range: range[0], reverse=True) - new_code = unformatted_code + formatted_code_lines = formatted_code.split("\n") + new_code_lines = unformatted_code.split("\n") for range_0, range_1 in zip(code_ranges_unformatted, code_ranges_formatted): range_0_0, range_0_1 = range_0 range_1_0, range_1_1 = range_1 - new_code = new_code[:range_0_0] + new_code[range_1_0:range_1_1 + 1] + new_code[range_0_1 + 1] - + new_code_lines = new_code_lines[:range_0_0] + formatted_code_lines[range_1_0:range_1_1 + 1] + new_code_lines[range_0_1 + 1:] + new_code = "\n".join(new_code_lines) + breakpoint() path.write_text(new_code, encoding="utf8") if should_sort_imports: From 59e36679eff653717bca57915cd8499368030472 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Fri, 16 May 2025 12:01:58 -0700 Subject: [PATCH 17/29] Got it functional --- codeflash/optimization/function_optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 1665825f4..e80f59f4c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -616,7 +616,6 @@ def reformat_code_and_helpers( range_1_0, range_1_1 = range_1 new_code_lines = new_code_lines[:range_0_0] + formatted_code_lines[range_1_0:range_1_1 + 1] + new_code_lines[range_0_1 + 1:] new_code = "\n".join(new_code_lines) - breakpoint() path.write_text(new_code, encoding="utf8") if should_sort_imports: From c3b80630290a0abad2835ef59e8ba0a34b60d290 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Fri, 16 May 2025 14:04:43 -0700 Subject: [PATCH 18/29] Do not recalculate code_context when reformatting --- codeflash/optimization/function_optimizer.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index e80f59f4c..0b6ca11f8 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -9,7 +9,7 @@ import uuid from collections import defaultdict, deque from pathlib import Path -from typing import cast, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING import isort import libcst as cst @@ -302,7 +302,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: ) new_code, new_helper_code = self.reformat_code_and_helpers( - code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code + code_context, + explanation.file_path, + self.function_to_optimize_source_code, ) existing_tests = existing_tests_source_for( @@ -581,21 +583,24 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, f.write(original_helper_code[module_abspath]) def reformat_code_and_helpers( - self, helper_functions: list[FunctionSource], fto_path: Path, original_code: str + self, + code_context: CodeOptimizationContext, + fto_path: Path, + original_code: str, ) -> tuple[str, dict[Path, str]]: should_sort_imports = not self.args.disable_imports_sorting if should_sort_imports and isort.code(original_code) != original_code: should_sort_imports = False + helper_functions = code_context.helper_functions + paths = [fto_path] + list({hf.file_path for hf in helper_functions}) new_target_code = None new_helper_code: dict[Path, str] = {} for i, path in enumerate(paths): unformatted_code = path.read_text(encoding="utf8") - code_context_result = self.get_code_optimization_context() - if code_context_result.is_failure(): - raise Exception("Unable to generate code context for formatting purposes") - code_context = cast(CodeOptimizationContext, code_context_result.unwrap()) + # TODO(zomglings): code_context.preexisting_objects doesn't read all functions in the old file. We should add that to context + # separately. That's a much bigger change. code_ranges_unformatted = get_modification_code_ranges(unformatted_code, self.function_to_optimize, code_context) formatted_code = format_code(self.args.formatter_cmds, path) From 6dd72cf9a33811faf6bf2607859dd12a28792a5a Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Fri, 16 May 2025 21:14:20 -0700 Subject: [PATCH 19/29] Correct calculation of all preexisting "function" symbols for formatting purposes This is done on a per-path basis. --- codeflash/code_utils/formatter.py | 9 +++--- codeflash/optimization/function_optimizer.py | 34 ++++++++++++++------ 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 71ef5773b..7632ecda1 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -11,7 +11,7 @@ from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.code_replacer import OptimFunctionCollector from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import CodeOptimizationContext +from codeflash.models.models import FunctionParent, FunctionSource if TYPE_CHECKING: from pathlib import Path @@ -64,14 +64,15 @@ def sort_imports(code: str) -> str: def get_modification_code_ranges( modified_code: str, fto: FunctionToOptimize, - code_context: CodeOptimizationContext, + preexisting_functions: set[tuple[str, tuple[FunctionParent,...]]], + helper_functions: list[FunctionSource], ) -> list[tuple[int, int]]: """ Returns the line number of modified and new functions in a string containing containing the code in a fully modified file. """ modified_functions = set() modified_functions.add(fto.qualified_name) - for helper_function in code_context.helper_functions: + for helper_function in helper_functions: if helper_function.jedi_definition.type != "class": modified_functions.add(helper_function.qualified_name) @@ -88,6 +89,6 @@ def get_modification_code_ranges( parsed_function_names.add((class_name, function_name)) module = cst.metadata.MetadataWrapper(cst.parse_module(modified_code)) - visitor = OptimFunctionCollector(code_context.preexisting_objects, parsed_function_names) + visitor = OptimFunctionCollector(preexisting_functions, parsed_function_names) module.visit(visitor) return visitor.modification_code_range_lines diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 0b6ca11f8..a7a0a9502 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -22,6 +22,7 @@ from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils +from codeflash.code_utils.code_extractor import find_preexisting_objects from codeflash.code_utils.code_replacer import replace_function_definitions_in_module from codeflash.code_utils.code_utils import ( cleanup_paths, @@ -49,6 +50,7 @@ BestOptimization, CodeOptimizationContext, FunctionCalledInTest, + FunctionParent, GeneratedTests, GeneratedTestsList, OptimizationSet, @@ -297,12 +299,20 @@ def optimize_function(self) -> Result[BestOptimization, str]: self.log_successful_optimization(explanation, generated_tests, exp_type) + # xylophone + preexisting_functions_by_filepath: dict[Path, list[str]] = {} + filepaths_to_inspect = [self.function_to_optimize.file_path] + list({helper.file_path for helper in code_context.helper_functions}) + for filepath in filepaths_to_inspect: + source_code = filepath.read_text(encoding="utf8") + preexisting_functions_by_filepath[filepath] = find_preexisting_objects(source_code) + self.replace_function_and_helpers_with_optimized_code( code_context=code_context, optimized_code=best_optimization.candidate.source_code ) new_code, new_helper_code = self.reformat_code_and_helpers( - code_context, + preexisting_functions_by_filepath, + code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code, ) @@ -584,7 +594,8 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, def reformat_code_and_helpers( self, - code_context: CodeOptimizationContext, + preexisting_functions_by_filepath: dict[Path, set[tuple[str, tuple[FunctionParent,...]]]], + helper_functions: list[FunctionSource], fto_path: Path, original_code: str, ) -> tuple[str, dict[Path, str]]: @@ -592,21 +603,26 @@ def reformat_code_and_helpers( if should_sort_imports and isort.code(original_code) != original_code: should_sort_imports = False - helper_functions = code_context.helper_functions - paths = [fto_path] + list({hf.file_path for hf in helper_functions}) new_target_code = None new_helper_code: dict[Path, str] = {} for i, path in enumerate(paths): unformatted_code = path.read_text(encoding="utf8") - # TODO(zomglings): code_context.preexisting_objects doesn't read all functions in the old file. We should add that to context - # separately. That's a much bigger change. - code_ranges_unformatted = get_modification_code_ranges(unformatted_code, self.function_to_optimize, code_context) - + code_ranges_unformatted = get_modification_code_ranges( + unformatted_code, + self.function_to_optimize, + preexisting_functions_by_filepath[path], + helper_functions, + ) formatted_code = format_code(self.args.formatter_cmds, path) # Note: We do not need to refresh the code_context because we only use it to refer to names of original # functions (even before optimization was applied) and filepaths, none of which is changing. - code_ranges_formatted = get_modification_code_ranges(formatted_code, self.function_to_optimize, code_context) + code_ranges_formatted = get_modification_code_ranges( + formatted_code, + self.function_to_optimize, + preexisting_functions_by_filepath[path], + helper_functions, + ) if len(code_ranges_formatted) != len(code_ranges_unformatted): raise Exception("Formatting had unexpected effects on code ranges") From eae756a954c276cc8827ad075ea6be2a90b1b36e Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Fri, 16 May 2025 21:32:34 -0700 Subject: [PATCH 20/29] Clarified docstring for get_modification_code_ranges ... and started adding tests for that function. --- codeflash/code_utils/formatter.py | 3 +-- tests/test_formatter.py | 15 ++++++++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 7632ecda1..1eda8c994 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -60,7 +60,6 @@ def sort_imports(code: str) -> str: return sorted_code -# TODO(zomglings): Write unit tests. def get_modification_code_ranges( modified_code: str, fto: FunctionToOptimize, @@ -68,7 +67,7 @@ def get_modification_code_ranges( helper_functions: list[FunctionSource], ) -> list[tuple[int, int]]: """ - Returns the line number of modified and new functions in a string containing containing the code in a fully modified file. + Returns the starting and ending line numbers of modified and new functions in a file containing edits. """ modified_functions = set() modified_functions.add(fto.qualified_name) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 5c0a91c38..a0ddadce2 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -5,7 +5,8 @@ import pytest from codeflash.code_utils.config_parser import parse_config_file -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, get_modification_code_ranges, sort_imports +from codeflash.discovery.functions_to_optimize import FunctionToOptimize def test_remove_duplicate_imports(): @@ -209,3 +210,15 @@ def foo(): tmp_path = tmp.name with pytest.raises(FileNotFoundError): format_code(formatter_cmds=["exit 1"], path=Path(tmp_path)) + +def test_get_modification_code_ranges_self_contained_fto(): + modified_code = """ +def hello(name): + print(f"Hello, {{name}}") +""" + + fto = FunctionToOptimize(function_name="hello", file_path=Path("hello.py"), parents=[]) + code_ranges = get_modification_code_ranges(modified_code, fto, set(), []) + + assert len(code_ranges) == 1 + assert code_ranges[0] == (2, 3) \ No newline at end of file From 05817f9c26e45cc87fd9f0198fc8fb3e66b382ff Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Sat, 17 May 2025 09:29:22 -0700 Subject: [PATCH 21/29] removed xylophone --- codeflash/optimization/function_optimizer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index a7a0a9502..1a04a6344 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -299,7 +299,6 @@ def optimize_function(self) -> Result[BestOptimization, str]: self.log_successful_optimization(explanation, generated_tests, exp_type) - # xylophone preexisting_functions_by_filepath: dict[Path, list[str]] = {} filepaths_to_inspect = [self.function_to_optimize.file_path] + list({helper.file_path for helper in code_context.helper_functions}) for filepath in filepaths_to_inspect: From 9efdc21ad8f1c35e2286650ce185b368b1ac1379 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Mon, 19 May 2025 09:56:33 -0700 Subject: [PATCH 22/29] Added test for get_modification_code_ranges. --- tests/test_formatter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_formatter.py b/tests/test_formatter.py index a0ddadce2..b1fcbd615 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -2,11 +2,13 @@ import tempfile from pathlib import Path +from jedi.api.classes import Name import pytest from codeflash.code_utils.config_parser import parse_config_file from codeflash.code_utils.formatter import format_code, get_modification_code_ranges, sort_imports from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import FunctionSource def test_remove_duplicate_imports(): @@ -221,4 +223,4 @@ def hello(name): code_ranges = get_modification_code_ranges(modified_code, fto, set(), []) assert len(code_ranges) == 1 - assert code_ranges[0] == (2, 3) \ No newline at end of file + assert code_ranges[0] == (2, 3) From 9face63e07112316f3f24fc23c2d5bb66400b6ad Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 21 May 2025 14:17:03 -0700 Subject: [PATCH 23/29] "ruff check --fix" --- codeflash/code_utils/formatter.py | 5 ++--- codeflash/optimization/function_optimizer.py | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 1eda8c994..dcecfd127 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -66,15 +66,14 @@ def get_modification_code_ranges( preexisting_functions: set[tuple[str, tuple[FunctionParent,...]]], helper_functions: list[FunctionSource], ) -> list[tuple[int, int]]: - """ - Returns the starting and ending line numbers of modified and new functions in a file containing edits. + """Returns the starting and ending line numbers of modified and new functions in a file containing edits. """ modified_functions = set() modified_functions.add(fto.qualified_name) for helper_function in helper_functions: if helper_function.jedi_definition.type != "class": modified_functions.add(helper_function.qualified_name) - + parsed_function_names = set() for original_function_name in modified_functions: if original_function_name.count(".") == 0: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 658a597de..04373c930 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -8,7 +8,7 @@ import uuid from collections import defaultdict, deque from pathlib import Path -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING import isort import libcst as cst @@ -607,7 +607,7 @@ def reformat_code_and_helpers( should_sort_imports = not self.args.disable_imports_sorting if should_sort_imports and isort.code(original_code) != original_code: should_sort_imports = False - + paths = [fto_path] + list({hf.file_path for hf in helper_functions}) new_target_code = None new_helper_code: dict[Path, str] = {} @@ -646,7 +646,7 @@ def reformat_code_and_helpers( if should_sort_imports: new_code = sort_imports(new_code) - + if i == 0: new_target_code = new_code else: From 0615f5525cafe75a0364bc53644157d57510d5bf Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 21 May 2025 14:20:49 -0700 Subject: [PATCH 24/29] Fixed some more ruff check issues --- codeflash/code_utils/formatter.py | 6 +++--- codeflash/optimization/function_optimizer.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index dcecfd127..fed28b2be 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -10,12 +10,12 @@ from codeflash.cli_cmds.console import console, logger from codeflash.code_utils.code_replacer import OptimFunctionCollector -from codeflash.discovery.functions_to_optimize import FunctionToOptimize -from codeflash.models.models import FunctionParent, FunctionSource if TYPE_CHECKING: from pathlib import Path + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.models.models import FunctionParent, FunctionSource def format_code(formatter_cmds: list[str], path: Path) -> str: # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution @@ -66,7 +66,7 @@ def get_modification_code_ranges( preexisting_functions: set[tuple[str, tuple[FunctionParent,...]]], helper_functions: list[FunctionSource], ) -> list[tuple[int, int]]: - """Returns the starting and ending line numbers of modified and new functions in a file containing edits. + """Returns the starting and ending line numbers of modified and new functions in a file with edits. """ modified_functions = set() modified_functions.add(fto.qualified_name) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 04373c930..38880ef9c 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -633,8 +633,8 @@ def reformat_code_and_helpers( raise Exception("Formatting had unexpected effects on code ranges") # It is important to sort in descending order so that the index arithmetic remains simple as we modify new_code - code_ranges_unformatted.sort(key=lambda range: range[0], reverse=True) - code_ranges_formatted.sort(key=lambda range: range[0], reverse=True) + code_ranges_unformatted.sort(key=lambda r: r[0], reverse=True) + code_ranges_formatted.sort(key=lambda r: r[0], reverse=True) formatted_code_lines = formatted_code.split("\n") new_code_lines = unformatted_code.split("\n") for range_0, range_1 in zip(code_ranges_unformatted, code_ranges_formatted): From af4df4a8fab530bb047f2be86cd6a12748622743 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 21 May 2025 14:23:29 -0700 Subject: [PATCH 25/29] "ruff format" --- codeflash/code_utils/code_replacer.py | 1 + codeflash/code_utils/formatter.py | 7 ++++--- codeflash/optimization/function_optimizer.py | 22 ++++++++++---------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 29412341a..1030608ba 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -164,6 +164,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c node = node.with_changes(body=(*self.new_functions, *node.body)) return node + def replace_functions_in_file( source_code: str, original_function_names: list[str], diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index fed28b2be..47c048106 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -17,6 +17,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import FunctionParent, FunctionSource + def format_code(formatter_cmds: list[str], path: Path) -> str: # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution formatter_name = formatter_cmds[0].lower() @@ -60,14 +61,14 @@ def sort_imports(code: str) -> str: return sorted_code + def get_modification_code_ranges( modified_code: str, fto: FunctionToOptimize, - preexisting_functions: set[tuple[str, tuple[FunctionParent,...]]], + preexisting_functions: set[tuple[str, tuple[FunctionParent, ...]]], helper_functions: list[FunctionSource], ) -> list[tuple[int, int]]: - """Returns the starting and ending line numbers of modified and new functions in a file with edits. - """ + """Returns the starting and ending line numbers of modified and new functions in a file with edits.""" modified_functions = set() modified_functions.add(fto.qualified_name) for helper_function in helper_functions: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 38880ef9c..d0f6ffe88 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -299,7 +299,9 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 self.log_successful_optimization(explanation, generated_tests, exp_type) preexisting_functions_by_filepath: dict[Path, list[str]] = {} - filepaths_to_inspect = [self.function_to_optimize.file_path] + list({helper.file_path for helper in code_context.helper_functions}) + filepaths_to_inspect = [self.function_to_optimize.file_path] + list( + {helper.file_path for helper in code_context.helper_functions} + ) for filepath in filepaths_to_inspect: source_code = filepath.read_text(encoding="utf8") preexisting_functions_by_filepath[filepath] = find_preexisting_objects(source_code) @@ -599,7 +601,7 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, def reformat_code_and_helpers( self, - preexisting_functions_by_filepath: dict[Path, set[tuple[str, tuple[FunctionParent,...]]]], + preexisting_functions_by_filepath: dict[Path, set[tuple[str, tuple[FunctionParent, ...]]]], helper_functions: list[FunctionSource], fto_path: Path, original_code: str, @@ -614,19 +616,13 @@ def reformat_code_and_helpers( for i, path in enumerate(paths): unformatted_code = path.read_text(encoding="utf8") code_ranges_unformatted = get_modification_code_ranges( - unformatted_code, - self.function_to_optimize, - preexisting_functions_by_filepath[path], - helper_functions, + unformatted_code, self.function_to_optimize, preexisting_functions_by_filepath[path], helper_functions ) formatted_code = format_code(self.args.formatter_cmds, path) # Note: We do not need to refresh the code_context because we only use it to refer to names of original # functions (even before optimization was applied) and filepaths, none of which is changing. code_ranges_formatted = get_modification_code_ranges( - formatted_code, - self.function_to_optimize, - preexisting_functions_by_filepath[path], - helper_functions, + formatted_code, self.function_to_optimize, preexisting_functions_by_filepath[path], helper_functions ) if len(code_ranges_formatted) != len(code_ranges_unformatted): @@ -640,7 +636,11 @@ def reformat_code_and_helpers( for range_0, range_1 in zip(code_ranges_unformatted, code_ranges_formatted): range_0_0, range_0_1 = range_0 range_1_0, range_1_1 = range_1 - new_code_lines = new_code_lines[:range_0_0] + formatted_code_lines[range_1_0:range_1_1 + 1] + new_code_lines[range_0_1 + 1:] + new_code_lines = ( + new_code_lines[:range_0_0] + + formatted_code_lines[range_1_0 : range_1_1 + 1] + + new_code_lines[range_0_1 + 1 :] + ) new_code = "\n".join(new_code_lines) path.write_text(new_code, encoding="utf8") From cf4a6659be1d76ba5bbbadf102451c2f44c14bc5 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 21 May 2025 14:24:35 -0700 Subject: [PATCH 26/29] more fixes for "ruff check"... with --unsafe-fixes --- codeflash/optimization/function_optimizer.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index d0f6ffe88..444d80351 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -49,7 +49,6 @@ BestOptimization, CodeOptimizationContext, FunctionCalledInTest, - FunctionParent, GeneratedTests, GeneratedTestsList, OptimizationSet, @@ -79,7 +78,7 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import Result - from codeflash.models.models import BenchmarkKey, CoverageData, FunctionSource, OptimizedCandidate + from codeflash.models.models import BenchmarkKey, CoverageData, FunctionParent, FunctionSource, OptimizedCandidate from codeflash.verification.verification_utils import TestConfig @@ -299,9 +298,7 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 self.log_successful_optimization(explanation, generated_tests, exp_type) preexisting_functions_by_filepath: dict[Path, list[str]] = {} - filepaths_to_inspect = [self.function_to_optimize.file_path] + list( - {helper.file_path for helper in code_context.helper_functions} - ) + filepaths_to_inspect = [self.function_to_optimize.file_path, *list({helper.file_path for helper in code_context.helper_functions})] for filepath in filepaths_to_inspect: source_code = filepath.read_text(encoding="utf8") preexisting_functions_by_filepath[filepath] = find_preexisting_objects(source_code) @@ -610,7 +607,7 @@ def reformat_code_and_helpers( if should_sort_imports and isort.code(original_code) != original_code: should_sort_imports = False - paths = [fto_path] + list({hf.file_path for hf in helper_functions}) + paths = [fto_path, *list({hf.file_path for hf in helper_functions})] new_target_code = None new_helper_code: dict[Path, str] = {} for i, path in enumerate(paths): From d43862ec44bcddc7c223e3a9af0bc421dd4660ae Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 21 May 2025 14:25:08 -0700 Subject: [PATCH 27/29] ruff format... ... AGAIN --- codeflash/optimization/function_optimizer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 444d80351..4a6297772 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -298,7 +298,10 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 self.log_successful_optimization(explanation, generated_tests, exp_type) preexisting_functions_by_filepath: dict[Path, list[str]] = {} - filepaths_to_inspect = [self.function_to_optimize.file_path, *list({helper.file_path for helper in code_context.helper_functions})] + filepaths_to_inspect = [ + self.function_to_optimize.file_path, + *list({helper.file_path for helper in code_context.helper_functions}), + ] for filepath in filepaths_to_inspect: source_code = filepath.read_text(encoding="utf8") preexisting_functions_by_filepath[filepath] = find_preexisting_objects(source_code) From 6ec5ef0fccde77a7bcaab3c87139ed1e739462aa Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 21 May 2025 14:28:55 -0700 Subject: [PATCH 28/29] That should be hte last of the ruff stuff --- codeflash/code_utils/formatter.py | 2 +- codeflash/optimization/function_optimizer.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 47c048106..1ae04d08b 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -68,7 +68,7 @@ def get_modification_code_ranges( preexisting_functions: set[tuple[str, tuple[FunctionParent, ...]]], helper_functions: list[FunctionSource], ) -> list[tuple[int, int]]: - """Returns the starting and ending line numbers of modified and new functions in a file with edits.""" + """Return the starting and ending line numbers of modified and new functions in a file with edits.""" modified_functions = set() modified_functions.add(fto.qualified_name) for helper_function in helper_functions: diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 4a6297772..f97caf309 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -82,6 +82,10 @@ from codeflash.verification.verification_utils import TestConfig +class FunctionOptimizerError(Exception): + pass + + class FunctionOptimizer: def __init__( self, @@ -626,7 +630,7 @@ def reformat_code_and_helpers( ) if len(code_ranges_formatted) != len(code_ranges_unformatted): - raise Exception("Formatting had unexpected effects on code ranges") + raise FunctionOptimizerError("Formatting had unexpected effects on code ranges") # It is important to sort in descending order so that the index arithmetic remains simple as we modify new_code code_ranges_unformatted.sort(key=lambda r: r[0], reverse=True) From a468ba7dbab6e716ccb23ef451a1aa6c29bf5936 Mon Sep 17 00:00:00 2001 From: Neeraj Kashyap Date: Wed, 21 May 2025 15:44:05 -0700 Subject: [PATCH 29/29] Added a test for FunctionOptimizer.reformat_code_and_helpers --- tests/test_function_optimizer.py | 68 ++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 tests/test_function_optimizer.py diff --git a/tests/test_function_optimizer.py b/tests/test_function_optimizer.py new file mode 100644 index 000000000..a74327c31 --- /dev/null +++ b/tests/test_function_optimizer.py @@ -0,0 +1,68 @@ +import argparse +from pathlib import Path +import shutil +import tempfile + +import pytest + +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.optimization.function_optimizer import FunctionOptimizer +from codeflash.verification.verification_utils import TestConfig + +def test_bubble_sort_preserve_bad_formatting(): + """ + Test the bubble sort implementation in code_to_optimize/bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py. + + This test sets the rubric for all other tests of formatting functionality. + """ + with tempfile.TemporaryDirectory() as test_dir_str: + test_dir = Path(test_dir_str) + target_path = test_dir / "target.py" + this_file = Path(__file__).resolve() + repo_root_dir = this_file.parent.parent + source_file = repo_root_dir / "code_to_optimize" / "bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py" + shutil.copy2(source_file, target_path) + + original_content = source_file.read_text() + + function_to_optimize = FunctionToOptimize( + function_name="sorter", + file_path=target_path, + parents=[], + starting_line=None, + ending_line=None, + ) + test_cfg = TestConfig( + tests_root=test_dir, + project_root_path=test_dir, + test_framework="pytest", + tests_project_rootdir=test_dir, + ) + args = argparse.Namespace( + disable_imports_sorting=False, + formatter_cmds=["uvx ruff check --exit-zero --fix $file", "uvx ruff format $file"], + ) + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + args=args, + ) + + preexisting_functions_by_filepath = { + target_path: {("lol", tuple())}, + } + + # add a newline after the function definition + target_content = target_path.read_text() + target_content = target_content.replace("def sorter(arr):", "def sorter(arr):\n") + assert target_content != original_content + target_path.write_text(target_content) + + optimizer.reformat_code_and_helpers( + preexisting_functions_by_filepath=preexisting_functions_by_filepath, + helper_functions=[], + fto_path=target_path, + original_code=optimizer.function_to_optimize_source_code, + ) + content = target_path.read_text() + assert content == original_content