diff --git a/.gitignore b/.gitignore index 535acfb3e..b4a99e8c2 100644 --- a/.gitignore +++ b/.gitignore @@ -254,3 +254,5 @@ fabric.properties # Mac .DS_Store + +scratch/ diff --git a/code_to_optimize/bubble_sort_method_preserve_bad_formatting_for_nonoptimized_code.py b/code_to_optimize/bubble_sort_method_preserve_bad_formatting_for_nonoptimized_code.py new file mode 100644 index 000000000..29a00a922 --- /dev/null +++ b/code_to_optimize/bubble_sort_method_preserve_bad_formatting_for_nonoptimized_code.py @@ -0,0 +1,38 @@ +import sys + + +def lol(): + print( "lol" ) + + + + + + + + + +class BubbleSorter: + def __init__(self, x=0): + self.x = x + + def lol(self): + print( "lol" ) + + + + + + + + + def sorter(self, arr): + print("codeflash stdout : BubbleSorter.sorter() called") + for i in range(len(arr)): + for j in range(len(arr) - 1): + if arr[j] > arr[j + 1]: + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp + print("stderr test", file=sys.stderr) + return arr diff --git a/code_to_optimize/bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py b/code_to_optimize/bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py new file mode 100644 index 000000000..b506ddfbb --- /dev/null +++ b/code_to_optimize/bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py @@ -0,0 +1,19 @@ +def lol(): + print( "lol" ) + + + + + + + +def sorter(arr): + print("codeflash stdout: Sorting list") + for i in range(len(arr)): + for j in range(len(arr) - 1): + if arr[j] > arr[j + 1]: + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp + print(f"result: {arr}") + return arr diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index eb367bdfa..1030608ba 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -34,7 +34,7 @@ def normalize_code(code: str) -> str: class OptimFunctionCollector(cst.CSTVisitor): - METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,) + METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider, cst.metadata.PositionProvider) def __init__( self, @@ -52,8 +52,11 @@ def __init__( self.new_class_functions: dict[str, list[cst.FunctionDef]] = defaultdict(list) self.current_class = None self.modified_init_functions: dict[str, cst.FunctionDef] = {} + self.modification_code_range_lines: list[tuple[int, int]] = [] def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + modification = True + if (self.current_class, node.name.value) in self.function_names: self.modified_functions[(self.current_class, node.name.value)] = node elif self.current_class and node.name.value == "__init__": @@ -64,6 +67,13 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: and self.current_class is None ): self.new_functions.append(node) + else: + modification = False + + if modification: + pos = self.get_metadata(cst.metadata.PositionProvider, node) + self.modification_code_range_lines.append((pos.start.line, pos.end.line)) + return False def visit_ClassDef(self, node: cst.ClassDef) -> bool: diff --git a/codeflash/code_utils/formatter.py b/codeflash/code_utils/formatter.py index 875fd0a1f..1ae04d08b 100644 --- a/codeflash/code_utils/formatter.py +++ b/codeflash/code_utils/formatter.py @@ -6,12 +6,17 @@ from typing import TYPE_CHECKING import isort +import libcst as cst from codeflash.cli_cmds.console import console, logger +from codeflash.code_utils.code_replacer import OptimFunctionCollector if TYPE_CHECKING: from pathlib import Path + from codeflash.discovery.functions_to_optimize import FunctionToOptimize + from codeflash.models.models import FunctionParent, FunctionSource + def format_code(formatter_cmds: list[str], path: Path) -> str: # TODO: Only allow a particular whitelist of formatters here to prevent arbitrary code execution @@ -55,3 +60,34 @@ def sort_imports(code: str) -> str: return code # Fall back to original code if isort fails return sorted_code + + +def get_modification_code_ranges( + modified_code: str, + fto: FunctionToOptimize, + preexisting_functions: set[tuple[str, tuple[FunctionParent, ...]]], + helper_functions: list[FunctionSource], +) -> list[tuple[int, int]]: + """Return the starting and ending line numbers of modified and new functions in a file with edits.""" + modified_functions = set() + modified_functions.add(fto.qualified_name) + for helper_function in helper_functions: + if helper_function.jedi_definition.type != "class": + modified_functions.add(helper_function.qualified_name) + + parsed_function_names = set() + for original_function_name in modified_functions: + if original_function_name.count(".") == 0: + class_name, function_name = None, original_function_name + elif original_function_name.count(".") == 1: + class_name, function_name = original_function_name.split(".") + else: + msg = f"Unable to find {original_function_name}. Returning unchanged source code." + logger.error(msg) + continue + parsed_function_names.add((class_name, function_name)) + + module = cst.metadata.MetadataWrapper(cst.parse_module(modified_code)) + visitor = OptimFunctionCollector(preexisting_functions, parsed_function_names) + module.visit(visitor) + return visitor.modification_code_range_lines diff --git a/codeflash/main.py b/codeflash/main.py index 9eb22dde1..47072d510 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -4,6 +4,7 @@ solved problem, please reach out to us at careers@codeflash.ai. We're hiring! """ +import os from pathlib import Path from codeflash.cli_cmds.cli import parse_args, process_pyproject_config @@ -22,25 +23,29 @@ def main() -> None: CODEFLASH_LOGO, panel_args={"title": "https://codeflash.ai", "expand": False}, text_args={"style": "bold gold3"} ) args = parse_args() + + disable_telemetry_env = os.environ.get("CODEFLASH_DISABLE_TELEMETRY", "").lower() in {"true", "t", "1", "yes", "y"} + if args.command: - if args.config_file and Path.exists(args.config_file): + disable_telemetry = disable_telemetry_env + if (not disable_telemetry) and args.config_file and Path.exists(args.config_file): pyproject_config, _ = parse_config_file(args.config_file) disable_telemetry = pyproject_config.get("disable_telemetry", False) - else: - disable_telemetry = False init_sentry(not disable_telemetry, exclude_errors=True) posthog_cf.initialize_posthog(not disable_telemetry) args.func() elif args.verify_setup: args = process_pyproject_config(args) - init_sentry(not args.disable_telemetry, exclude_errors=True) - posthog_cf.initialize_posthog(not args.disable_telemetry) + disable_telemetry = args.disable_telemetry or disable_telemetry_env + init_sentry(not disable_telemetry, exclude_errors=True) + posthog_cf.initialize_posthog(not disable_telemetry) ask_run_end_to_end_test(args) else: args = process_pyproject_config(args) args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args) - init_sentry(not args.disable_telemetry, exclude_errors=True) - posthog_cf.initialize_posthog(not args.disable_telemetry) + disable_telemetry = args.disable_telemetry or disable_telemetry_env + init_sentry(not disable_telemetry, exclude_errors=True) + posthog_cf.initialize_posthog(not disable_telemetry) optimizer.run_with_args(args) diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index d74f59ecc..f97caf309 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -21,6 +21,7 @@ from codeflash.benchmarking.utils import process_benchmark_data from codeflash.cli_cmds.console import code_print, console, logger, progress_bar from codeflash.code_utils import env_utils +from codeflash.code_utils.code_extractor import find_preexisting_objects from codeflash.code_utils.code_replacer import replace_function_definitions_in_module from codeflash.code_utils.code_utils import ( cleanup_paths, @@ -35,7 +36,7 @@ N_TESTS_TO_GENERATE, TOTAL_LOOPING_TIME, ) -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, get_modification_code_ranges, sort_imports from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test from codeflash.code_utils.line_profile_utils import add_decorator_imports from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests @@ -77,10 +78,14 @@ from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.either import Result - from codeflash.models.models import BenchmarkKey, CoverageData, FunctionSource, OptimizedCandidate + from codeflash.models.models import BenchmarkKey, CoverageData, FunctionParent, FunctionSource, OptimizedCandidate from codeflash.verification.verification_utils import TestConfig +class FunctionOptimizerError(Exception): + pass + + class FunctionOptimizer: def __init__( self, @@ -296,12 +301,24 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911 self.log_successful_optimization(explanation, generated_tests, exp_type) + preexisting_functions_by_filepath: dict[Path, list[str]] = {} + filepaths_to_inspect = [ + self.function_to_optimize.file_path, + *list({helper.file_path for helper in code_context.helper_functions}), + ] + for filepath in filepaths_to_inspect: + source_code = filepath.read_text(encoding="utf8") + preexisting_functions_by_filepath[filepath] = find_preexisting_objects(source_code) + self.replace_function_and_helpers_with_optimized_code( code_context=code_context, optimized_code=best_optimization.candidate.source_code ) new_code, new_helper_code = self.reformat_code_and_helpers( - code_context.helper_functions, explanation.file_path, self.function_to_optimize_source_code + preexisting_functions_by_filepath, + code_context.helper_functions, + explanation.file_path, + self.function_to_optimize_source_code, ) existing_tests = existing_tests_source_for( @@ -587,25 +604,59 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, f.write(helper_code) def reformat_code_and_helpers( - self, helper_functions: list[FunctionSource], path: Path, original_code: str + self, + preexisting_functions_by_filepath: dict[Path, set[tuple[str, tuple[FunctionParent, ...]]]], + helper_functions: list[FunctionSource], + fto_path: Path, + original_code: str, ) -> tuple[str, dict[Path, str]]: should_sort_imports = not self.args.disable_imports_sorting if should_sort_imports and isort.code(original_code) != original_code: should_sort_imports = False - new_code = format_code(self.args.formatter_cmds, path) - if should_sort_imports: - new_code = sort_imports(new_code) - + paths = [fto_path, *list({hf.file_path for hf in helper_functions})] + new_target_code = None new_helper_code: dict[Path, str] = {} - helper_functions_paths = {hf.file_path for hf in helper_functions} - for module_abspath in helper_functions_paths: - formatted_helper_code = format_code(self.args.formatter_cmds, module_abspath) + for i, path in enumerate(paths): + unformatted_code = path.read_text(encoding="utf8") + code_ranges_unformatted = get_modification_code_ranges( + unformatted_code, self.function_to_optimize, preexisting_functions_by_filepath[path], helper_functions + ) + formatted_code = format_code(self.args.formatter_cmds, path) + # Note: We do not need to refresh the code_context because we only use it to refer to names of original + # functions (even before optimization was applied) and filepaths, none of which is changing. + code_ranges_formatted = get_modification_code_ranges( + formatted_code, self.function_to_optimize, preexisting_functions_by_filepath[path], helper_functions + ) + + if len(code_ranges_formatted) != len(code_ranges_unformatted): + raise FunctionOptimizerError("Formatting had unexpected effects on code ranges") + + # It is important to sort in descending order so that the index arithmetic remains simple as we modify new_code + code_ranges_unformatted.sort(key=lambda r: r[0], reverse=True) + code_ranges_formatted.sort(key=lambda r: r[0], reverse=True) + formatted_code_lines = formatted_code.split("\n") + new_code_lines = unformatted_code.split("\n") + for range_0, range_1 in zip(code_ranges_unformatted, code_ranges_formatted): + range_0_0, range_0_1 = range_0 + range_1_0, range_1_1 = range_1 + new_code_lines = ( + new_code_lines[:range_0_0] + + formatted_code_lines[range_1_0 : range_1_1 + 1] + + new_code_lines[range_0_1 + 1 :] + ) + new_code = "\n".join(new_code_lines) + path.write_text(new_code, encoding="utf8") + if should_sort_imports: - formatted_helper_code = sort_imports(formatted_helper_code) - new_helper_code[module_abspath] = formatted_helper_code + new_code = sort_imports(new_code) + + if i == 0: + new_target_code = new_code + else: + new_helper_code[path] = new_code - return new_code, new_helper_code + return new_target_code, new_helper_code def replace_function_and_helpers_with_optimized_code( self, code_context: CodeOptimizationContext, optimized_code: str diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 5c0a91c38..b1fcbd615 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -2,10 +2,13 @@ import tempfile from pathlib import Path +from jedi.api.classes import Name import pytest from codeflash.code_utils.config_parser import parse_config_file -from codeflash.code_utils.formatter import format_code, sort_imports +from codeflash.code_utils.formatter import format_code, get_modification_code_ranges, sort_imports +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import FunctionSource def test_remove_duplicate_imports(): @@ -209,3 +212,15 @@ def foo(): tmp_path = tmp.name with pytest.raises(FileNotFoundError): format_code(formatter_cmds=["exit 1"], path=Path(tmp_path)) + +def test_get_modification_code_ranges_self_contained_fto(): + modified_code = """ +def hello(name): + print(f"Hello, {{name}}") +""" + + fto = FunctionToOptimize(function_name="hello", file_path=Path("hello.py"), parents=[]) + code_ranges = get_modification_code_ranges(modified_code, fto, set(), []) + + assert len(code_ranges) == 1 + assert code_ranges[0] == (2, 3) diff --git a/tests/test_function_optimizer.py b/tests/test_function_optimizer.py new file mode 100644 index 000000000..a74327c31 --- /dev/null +++ b/tests/test_function_optimizer.py @@ -0,0 +1,68 @@ +import argparse +from pathlib import Path +import shutil +import tempfile + +import pytest + +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.optimization.function_optimizer import FunctionOptimizer +from codeflash.verification.verification_utils import TestConfig + +def test_bubble_sort_preserve_bad_formatting(): + """ + Test the bubble sort implementation in code_to_optimize/bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py. + + This test sets the rubric for all other tests of formatting functionality. + """ + with tempfile.TemporaryDirectory() as test_dir_str: + test_dir = Path(test_dir_str) + target_path = test_dir / "target.py" + this_file = Path(__file__).resolve() + repo_root_dir = this_file.parent.parent + source_file = repo_root_dir / "code_to_optimize" / "bubble_sort_preserve_bad_formatting_for_nonoptimized_code.py" + shutil.copy2(source_file, target_path) + + original_content = source_file.read_text() + + function_to_optimize = FunctionToOptimize( + function_name="sorter", + file_path=target_path, + parents=[], + starting_line=None, + ending_line=None, + ) + test_cfg = TestConfig( + tests_root=test_dir, + project_root_path=test_dir, + test_framework="pytest", + tests_project_rootdir=test_dir, + ) + args = argparse.Namespace( + disable_imports_sorting=False, + formatter_cmds=["uvx ruff check --exit-zero --fix $file", "uvx ruff format $file"], + ) + optimizer = FunctionOptimizer( + function_to_optimize=function_to_optimize, + test_cfg=test_cfg, + args=args, + ) + + preexisting_functions_by_filepath = { + target_path: {("lol", tuple())}, + } + + # add a newline after the function definition + target_content = target_path.read_text() + target_content = target_content.replace("def sorter(arr):", "def sorter(arr):\n") + assert target_content != original_content + target_path.write_text(target_content) + + optimizer.reformat_code_and_helpers( + preexisting_functions_by_filepath=preexisting_functions_by_filepath, + helper_functions=[], + fto_path=target_path, + original_code=optimizer.function_to_optimize_source_code, + ) + content = target_path.read_text() + assert content == original_content