diff --git a/codeflash/code_utils/git_worktree_utils.py b/codeflash/code_utils/git_worktree_utils.py index c960a8af1..c41637c8d 100644 --- a/codeflash/code_utils/git_worktree_utils.py +++ b/codeflash/code_utils/git_worktree_utils.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json import subprocess import tempfile import time @@ -9,15 +8,12 @@ from typing import TYPE_CHECKING, Optional import git -from filelock import FileLock from codeflash.cli_cmds.console import logger from codeflash.code_utils.compat import codeflash_cache_dir from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir if TYPE_CHECKING: - from typing import Any - from git import Repo @@ -100,56 +96,15 @@ def get_patches_dir_for_project() -> Path: return Path(patches_dir / project_id) -def get_patches_metadata() -> dict[str, Any]: - project_patches_dir = get_patches_dir_for_project() - meta_file = project_patches_dir / "metadata.json" - if meta_file.exists(): - with meta_file.open("r", encoding="utf-8") as f: - return json.load(f) - return {"id": get_git_project_id() or "", "patches": []} - - -def save_patches_metadata(patch_metadata: dict) -> dict: - project_patches_dir = get_patches_dir_for_project() - meta_file = project_patches_dir / "metadata.json" - lock_file = project_patches_dir / "metadata.json.lock" - - # we are not supporting multiple concurrent optimizations within the same process, but keep that in case we decide to do so in the future. - with FileLock(lock_file, timeout=10): - metadata = get_patches_metadata() - - patch_metadata["id"] = time.strftime("%Y%m%d-%H%M%S") - metadata["patches"].append(patch_metadata) - - meta_file.write_text(json.dumps(metadata, indent=2)) - - return patch_metadata - - -def overwrite_patch_metadata(patches: list[dict]) -> bool: - project_patches_dir = get_patches_dir_for_project() - meta_file = project_patches_dir / "metadata.json" - lock_file = project_patches_dir / "metadata.json.lock" - - with FileLock(lock_file, timeout=10): - metadata = get_patches_metadata() - metadata["patches"] = patches - meta_file.write_text(json.dumps(metadata, indent=2)) - return True - - def create_diff_patch_from_worktree( - worktree_dir: Path, - files: list[str], - fto_name: Optional[str] = None, - metadata_input: Optional[dict[str, Any]] = None, -) -> dict[str, Any]: + worktree_dir: Path, files: list[str], fto_name: Optional[str] = None +) -> Optional[Path]: repository = git.Repo(worktree_dir, search_parent_directories=True) uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True) if not uni_diff_text: logger.warning("No changes found in worktree.") - return {} + return None if not uni_diff_text.endswith("\n"): uni_diff_text += "\n" @@ -157,14 +112,8 @@ def create_diff_patch_from_worktree( project_patches_dir = get_patches_dir_for_project() project_patches_dir.mkdir(parents=True, exist_ok=True) - final_function_name = fto_name or metadata_input.get("fto_name", "unknown") - patch_path = project_patches_dir / f"{worktree_dir.name}.{final_function_name}.patch" + patch_path = project_patches_dir / f"{worktree_dir.name}.{fto_name}.patch" with patch_path.open("w", encoding="utf8") as f: f.write(uni_diff_text) - final_metadata = {"patch_path": str(patch_path)} - if metadata_input: - final_metadata.update(metadata_input) - final_metadata = save_patches_metadata(final_metadata) - - return final_metadata + return patch_path diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py index c78e551be..ee9e24286 100644 --- a/codeflash/lsp/beta.py +++ b/codeflash/lsp/beta.py @@ -12,11 +12,8 @@ from codeflash.api.cfapi import get_codeflash_api_key, get_user_id from codeflash.cli_cmds.cli import process_pyproject_config from codeflash.cli_cmds.console import code_print -from codeflash.code_utils.git_worktree_utils import ( - create_diff_patch_from_worktree, - get_patches_metadata, - overwrite_patch_metadata, -) +from codeflash.code_utils.git_utils import git_root_dir +from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree from codeflash.code_utils.shell_utils import save_api_key_to_rc from codeflash.discovery.functions_to_optimize import ( filter_functions, @@ -39,10 +36,17 @@ class OptimizableFunctionsParams: textDocument: types.TextDocumentIdentifier # noqa: N815 +@dataclass +class FunctionOptimizationInitParams: + textDocument: types.TextDocumentIdentifier # noqa: N815 + functionName: str # noqa: N815 + + @dataclass class FunctionOptimizationParams: textDocument: types.TextDocumentIdentifier # noqa: N815 functionName: str # noqa: N815 + task_id: str @dataclass @@ -59,7 +63,7 @@ class ValidateProjectParams: @dataclass class OnPatchAppliedParams: - patch_id: str + task_id: str @dataclass @@ -132,42 +136,6 @@ def get_optimizable_functions( return path_to_qualified_names -@server.feature("initializeFunctionOptimization") -def initialize_function_optimization( - server: CodeflashLanguageServer, params: FunctionOptimizationParams -) -> dict[str, str]: - file_path = Path(uris.to_fs_path(params.textDocument.uri)) - server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info") - - if server.optimizer is None: - _initialize_optimizer_if_api_key_is_valid(server) - - server.optimizer.worktree_mode() - - original_args, _ = server.optimizer.original_args_and_test_cfg - - server.optimizer.args.function = params.functionName - original_relative_file_path = file_path.relative_to(original_args.project_root) - server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path - server.optimizer.args.previous_checkpoint_functions = False - - server.show_message_log( - f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info" - ) - - optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions() - - if count == 0: - server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning") - server.cleanup_the_optimizer() - return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None} - - fto = optimizable_funcs.popitem()[1][0] - server.optimizer.current_function_being_optimized = fto - server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info") - return {"functionName": params.functionName, "status": "success"} - - def _find_pyproject_toml(workspace_path: str) -> Path | None: workspace_path_obj = Path(workspace_path) max_depth = 2 @@ -207,13 +175,18 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) if pyproject_toml_path: server.prepare_optimizer_arguments(pyproject_toml_path) else: - return { - "status": "error", - "message": "No pyproject.toml found in workspace.", - } # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth + return {"status": "error", "message": "No pyproject.toml found in workspace."} + + # since we are using worktrees, optimization diffs are generated with respect to the root of the repo, also the args.project_root is set to the root of the repo when creating a worktree + root = str(git_root_dir()) if getattr(params, "skip_validation", False): - return {"status": "success", "moduleRoot": server.args.module_root, "pyprojectPath": pyproject_toml_path} + return { + "status": "success", + "moduleRoot": server.args.module_root, + "pyprojectPath": pyproject_toml_path, + "root": root, + } server.show_message_log("Validating project...", "Info") config = is_valid_pyproject_toml(pyproject_toml_path) @@ -234,7 +207,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams) except Exception: return {"status": "error", "message": "Repository has no commits (unborn HEAD)"} - return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path} + return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root} def _initialize_optimizer_if_api_key_is_valid( @@ -296,78 +269,85 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams return {"status": "error", "message": "something went wrong while saving the api key"} -@server.feature("retrieveSuccessfulOptimizations") -def retrieve_successful_optimizations(_server: CodeflashLanguageServer, _params: any) -> dict[str, str]: - metadata = get_patches_metadata() - return {"status": "success", "patches": metadata["patches"]} +@server.feature("initializeFunctionOptimization") +def initialize_function_optimization( + server: CodeflashLanguageServer, params: FunctionOptimizationInitParams +) -> dict[str, str]: + file_path = Path(uris.to_fs_path(params.textDocument.uri)) + server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info") + if server.optimizer is None: + _initialize_optimizer_if_api_key_is_valid(server) -@server.feature("onPatchApplied") -def on_patch_applied(_server: CodeflashLanguageServer, params: OnPatchAppliedParams) -> dict[str, str]: - # first remove the patch from the metadata - metadata = get_patches_metadata() + server.optimizer.worktree_mode() - deleted_patch_file = None - new_patches = [] - for patch in metadata["patches"]: - if patch["id"] == params.patch_id: - deleted_patch_file = patch["patch_path"] - continue - new_patches.append(patch) + original_args, _ = server.optimizer.original_args_and_test_cfg - # then remove the patch file - if deleted_patch_file: - overwrite_patch_metadata(new_patches) - patch_path = Path(deleted_patch_file) - patch_path.unlink(missing_ok=True) - return {"status": "success"} - return {"status": "error", "message": "Patch not found"} + server.optimizer.args.function = params.functionName + original_relative_file_path = file_path.relative_to(original_args.project_root) + server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path + server.optimizer.args.previous_checkpoint_functions = False + server.show_message_log( + f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info" + ) -@server.feature("performFunctionOptimization") -@server.thread() -def perform_function_optimization( # noqa: PLR0911 - server: CodeflashLanguageServer, params: FunctionOptimizationParams -) -> dict[str, str]: - try: - server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info") - current_function = server.optimizer.current_function_being_optimized + optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions() - if not current_function: - server.show_message_log(f"No current function being optimized for {params.functionName}", "Error") - return { - "functionName": params.functionName, - "status": "error", - "message": "No function currently being optimized", - } + if count == 0: + server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning") + server.cleanup_the_optimizer() + return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None} - module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path) - if not module_prep_result: - return { - "functionName": params.functionName, - "status": "error", - "message": "Failed to prepare module for optimization", - } + fto = optimizable_funcs.popitem()[1][0] - validated_original_code, original_module_ast = module_prep_result + module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path) + if not module_prep_result: + return { + "functionName": params.functionName, + "status": "error", + "message": "Failed to prepare module for optimization", + } - function_optimizer = server.optimizer.create_function_optimizer( - current_function, - function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code, - original_module_ast=original_module_ast, - original_module_path=current_function.file_path, - function_to_tests={}, - ) + validated_original_code, original_module_ast = module_prep_result + + function_optimizer = server.optimizer.create_function_optimizer( + fto, + function_to_optimize_source_code=validated_original_code[fto.file_path].source_code, + original_module_ast=original_module_ast, + original_module_path=fto.file_path, + function_to_tests={}, + ) + + server.optimizer.current_function_optimizer = function_optimizer + if not function_optimizer: + return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} + + initialization_result = function_optimizer.can_be_optimized() + if not is_successful(initialization_result): + return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} + + server.current_optimization_init_result = initialization_result.unwrap() + server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info") - server.optimizer.current_function_optimizer = function_optimizer - if not function_optimizer: - return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"} + files = [function_optimizer.function_to_optimize.file_path] - initialization_result = function_optimizer.can_be_optimized() - if not is_successful(initialization_result): - return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()} + _, _, original_helpers = server.current_optimization_init_result + files.extend([str(helper_path) for helper_path in original_helpers]) - should_run_experiment, code_context, original_helper_code = initialization_result.unwrap() + return {"functionName": params.functionName, "status": "success", "files_inside_context": files} + + +@server.feature("performFunctionOptimization") +@server.thread() +def perform_function_optimization( + server: CodeflashLanguageServer, params: FunctionOptimizationParams +) -> dict[str, str]: + try: + server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info") + should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result + function_optimizer = server.optimizer.current_function_optimizer + current_function = function_optimizer.function_to_optimize code_print( code_context.read_writable_code.flat, @@ -447,22 +427,17 @@ def perform_function_optimization( # noqa: PLR0911 speedup = original_code_baseline.runtime / best_optimization.runtime - # get the original file path in the actual project (not in the worktree) - original_args, _ = server.optimizer.original_args_and_test_cfg - relative_file_path = current_function.file_path.relative_to(server.optimizer.current_worktree) - original_file_path = Path(original_args.project_root / relative_file_path).resolve() - - metadata = create_diff_patch_from_worktree( - server.optimizer.current_worktree, - relative_file_paths, - metadata_input={ - "fto_name": function_to_optimize_qualified_name, - "explanation": best_optimization.explanation_v2, - "file_path": str(original_file_path), - "speedup": speedup, - }, + patch_path = create_diff_patch_from_worktree( + server.optimizer.current_worktree, relative_file_paths, function_to_optimize_qualified_name ) + if not patch_path: + return { + "functionName": params.functionName, + "status": "error", + "message": "Failed to create a patch for optimization", + } + server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info") return { @@ -470,8 +445,8 @@ def perform_function_optimization( # noqa: PLR0911 "status": "success", "message": "Optimization completed successfully", "extra": f"Speedup: {speedup:.2f}x faster", - "patch_file": metadata["patch_path"], - "patch_id": metadata["id"], + "patch_file": str(patch_path), + "task_id": params.task_id, "explanation": best_optimization.explanation_v2, } finally: diff --git a/codeflash/lsp/server.py b/codeflash/lsp/server.py index daaa4efe5..b33fcc173 100644 --- a/codeflash/lsp/server.py +++ b/codeflash/lsp/server.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from pathlib import Path + from codeflash.models.models import CodeOptimizationContext from codeflash.optimization.optimizer import Optimizer @@ -22,6 +23,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401 self.optimizer: Optimizer | None = None self.args_processed_before: bool = False self.args = None + self.current_optimization_init_result: tuple[bool, CodeOptimizationContext, dict[Path, str]] | None = None def prepare_optimizer_arguments(self, config_file: Path) -> None: from codeflash.cli_cmds.cli import parse_args @@ -57,6 +59,7 @@ def show_message_log(self, message: str, message_type: str) -> None: self.lsp.notify("window/logMessage", log_params) def cleanup_the_optimizer(self) -> None: + self.current_optimization_init_result = None if not self.optimizer: return try: diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py index 50bc8341f..c5ab0e1e0 100644 --- a/codeflash/optimization/optimizer.py +++ b/codeflash/optimization/optimizer.py @@ -350,13 +350,12 @@ def run(self) -> None: relative_file_paths = [ code_string.file_path for code_string in read_writable_code.code_strings ] - metadata = create_diff_patch_from_worktree( + patch_path = create_diff_patch_from_worktree( self.current_worktree, relative_file_paths, fto_name=function_to_optimize.qualified_name, - metadata_input={}, ) - self.patch_files.append(metadata["patch_path"]) + self.patch_files.append(patch_path) if i < len(functions_to_optimize) - 1: create_worktree_snapshot_commit( self.current_worktree,