diff --git a/codeflash/code_utils/git_utils.py b/codeflash/code_utils/git_utils.py
index c324b2b05..05056828b 100644
--- a/codeflash/code_utils/git_utils.py
+++ b/codeflash/code_utils/git_utils.py
@@ -9,14 +9,13 @@
 from functools import cache
 from io import StringIO
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 import git
 from rich.prompt import Confirm
 from unidiff import PatchSet
 
 from codeflash.cli_cmds.console import logger
-from codeflash.code_utils.compat import codeflash_cache_dir
 from codeflash.code_utils.config_consts import N_CANDIDATES
 
 if TYPE_CHECKING:
@@ -199,84 +198,3 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:
         return None
     else:
         return last_commit.author.name
-
-
-worktree_dirs = codeflash_cache_dir / "worktrees"
-patches_dir = codeflash_cache_dir / "patches"
-
-
-def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
-    repository = git.Repo(worktree_dir, search_parent_directories=True)
-    repository.git.add(".")
-    repository.git.commit("-m", commit_message, "--no-verify")
-
-
-def create_detached_worktree(module_root: Path) -> Optional[Path]:
-    if not check_running_in_git_repo(module_root):
-        logger.warning("Module is not in a git repository. Skipping worktree creation.")
-        return None
-    git_root = git_root_dir()
-    current_time_str = time.strftime("%Y%m%d-%H%M%S")
-    worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"
-
-    repository = git.Repo(git_root, search_parent_directories=True)
-
-    repository.git.worktree("add", "-d", str(worktree_dir))
-
-    # Get uncommitted diff from the original repo
-    repository.git.add("-N", ".")  # add the index for untracked files to be included in the diff
-    exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"]  # fmt: off
-    uni_diff_text = repository.git.diff(
-        None, "HEAD", "--", *exclude_binary_files, ignore_blank_lines=True, ignore_space_at_eol=True
-    )
-
-    if not uni_diff_text.strip():
-        logger.info("No uncommitted changes to copy to worktree.")
-        return worktree_dir
-
-    # Write the diff to a temporary file
-    with tempfile.NamedTemporaryFile(mode="w", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
-        tmp_patch_file.write(uni_diff_text + "\n")  # the new line here is a must otherwise the last hunk won't be valid
-        tmp_patch_file.flush()
-
-    patch_path = Path(tmp_patch_file.name).resolve()
-
-    # Apply the patch inside the worktree
-    try:
-        subprocess.run(
-            ["git", "apply", "--ignore-space-change", "--ignore-whitespace", "--whitespace=nowarn", patch_path],
-            cwd=worktree_dir,
-            check=True,
-        )
-        create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot")
-    except subprocess.CalledProcessError as e:
-        logger.error(f"Failed to apply patch to worktree: {e}")
-
-    return worktree_dir
-
-
-def remove_worktree(worktree_dir: Path) -> None:
-    try:
-        repository = git.Repo(worktree_dir, search_parent_directories=True)
-        repository.git.worktree("remove", "--force", worktree_dir)
-    except Exception:
-        logger.exception(f"Failed to remove worktree: {worktree_dir}")
-
-
-def create_diff_patch_from_worktree(worktree_dir: Path, files: list[str], fto_name: str) -> Path:
-    repository = git.Repo(worktree_dir, search_parent_directories=True)
-    uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)
-
-    if not uni_diff_text:
-        logger.warning("No changes found in worktree.")
-        return None
-
-    if not uni_diff_text.endswith("\n"):
-        uni_diff_text += "\n"
-
-    # write to patches_dir
-    patches_dir.mkdir(parents=True, exist_ok=True)
-    patch_path = patches_dir / f"{worktree_dir.name}.{fto_name}.patch"
-    with patch_path.open("w", encoding="utf8") as f:
-        f.write(uni_diff_text)
-    return patch_path
diff --git a/codeflash/code_utils/git_worktree_utils.py b/codeflash/code_utils/git_worktree_utils.py
new file mode 100644
index 000000000..17768ff01
--- /dev/null
+++ b/codeflash/code_utils/git_worktree_utils.py
@@ -0,0 +1,189 @@
+from __future__ import annotations
+
+import json
+import subprocess
+import tempfile
+import time
+import uuid
+from functools import lru_cache
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional
+
+import git
+from filelock import FileLock
+
+from codeflash.cli_cmds.console import logger
+from codeflash.code_utils.compat import codeflash_cache_dir
+from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
+
+if TYPE_CHECKING:
+    from typing import Any
+
+    from git import Repo
+
+
+worktree_dirs = codeflash_cache_dir / "worktrees"
+patches_dir = codeflash_cache_dir / "patches"
+
+
+@lru_cache(maxsize=1)
+def get_git_project_id() -> str:
+    """Return the first commit sha of the repo."""
+    repo: Repo = git.Repo(search_parent_directories=True)
+    root_commits = list(repo.iter_commits(rev="HEAD", max_parents=0))
+    return root_commits[0].hexsha
+
+
+def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
+    """Stage everything in the worktree and commit it, skipping git hooks."""
+    repository = git.Repo(worktree_dir, search_parent_directories=True)
+    repository.git.add(".")
+    repository.git.commit("-m", commit_message, "--no-verify")
+
+
+def create_detached_worktree(module_root: Path) -> Optional[Path]:
+    """Create a detached worktree under the cache dir and copy uncommitted changes into it.
+
+    Returns the worktree path, or None when module_root is not inside a git repository.
+    """
+    if not check_running_in_git_repo(module_root):
+        logger.warning("Module is not in a git repository. Skipping worktree creation.")
+        return None
+    git_root = git_root_dir()
+    current_time_str = time.strftime("%Y%m%d-%H%M%S")
+    worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"
+
+    repository = git.Repo(git_root, search_parent_directories=True)
+
+    repository.git.worktree("add", "-d", str(worktree_dir))
+
+    # Get uncommitted diff from the original repo
+    repository.git.add("-N", ".")  # add the index for untracked files to be included in the diff
+    exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"]  # fmt: off
+    uni_diff_text = repository.git.diff(
+        None, "HEAD", "--", *exclude_binary_files, ignore_blank_lines=True, ignore_space_at_eol=True
+    )
+
+    if not uni_diff_text.strip():
+        logger.info("No uncommitted changes to copy to worktree.")
+        return worktree_dir
+
+    # Write the diff to a temporary file
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
+        tmp_patch_file.write(uni_diff_text + "\n")  # the new line here is a must otherwise the last hunk won't be valid
+        tmp_patch_file.flush()
+
+    patch_path = Path(tmp_patch_file.name).resolve()
+
+    # Apply the patch inside the worktree
+    try:
+        subprocess.run(
+            ["git", "apply", "--ignore-space-change", "--ignore-whitespace", "--whitespace=nowarn", patch_path],
+            cwd=worktree_dir,
+            check=True,
+        )
+        create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot")
+    except subprocess.CalledProcessError as e:
+        logger.error(f"Failed to apply patch to worktree: {e}")
+    finally:
+        # the temp file was created with delete=False, so it must be removed explicitly
+        patch_path.unlink(missing_ok=True)
+
+    return worktree_dir
+
+
+def remove_worktree(worktree_dir: Path) -> None:
+    """Force-remove the worktree; failures are logged rather than raised."""
+    try:
+        repository = git.Repo(worktree_dir, search_parent_directories=True)
+        repository.git.worktree("remove", "--force", worktree_dir)
+    except Exception:
+        logger.exception(f"Failed to remove worktree: {worktree_dir}")
+
+
+@lru_cache(maxsize=1)
+def get_patches_dir_for_project() -> Path:
+    """Return the patches directory of the current project, keyed by its root commit sha."""
+    project_id = get_git_project_id() or ""
+    return Path(patches_dir / project_id)
+
+
+def get_patches_metadata() -> dict[str, Any]:
+    """Load the project's metadata.json, or return an empty record when it does not exist."""
+    project_patches_dir = get_patches_dir_for_project()
+    meta_file = project_patches_dir / "metadata.json"
+    if meta_file.exists():
+        with meta_file.open("r", encoding="utf-8") as f:
+            return json.load(f)
+    return {"id": get_git_project_id() or "", "patches": []}
+
+
+def save_patches_metadata(patch_metadata: dict) -> dict:
+    """Assign an id to patch_metadata, append it to the project's metadata.json and return it."""
+    project_patches_dir = get_patches_dir_for_project()
+    meta_file = project_patches_dir / "metadata.json"
+    lock_file = project_patches_dir / "metadata.json.lock"
+
+    # the lock file lives in the project dir, so the dir has to exist before locking
+    project_patches_dir.mkdir(parents=True, exist_ok=True)
+
+    # we are not supporting multiple concurrent optimizations within the same process, but keep that in case we decide to do so in the future.
+    with FileLock(lock_file, timeout=10):
+        metadata = get_patches_metadata()
+
+        # a timestamp alone collides for patches saved within the same second; the id is used to look patches up
+        patch_metadata["id"] = f"{time.strftime('%Y%m%d-%H%M%S')}-{uuid.uuid4().hex[:8]}"
+        metadata["patches"].append(patch_metadata)
+
+        meta_file.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
+
+    return patch_metadata
+
+
+def overwrite_patch_metadata(patches: list[dict]) -> bool:
+    """Replace the stored patch list with patches, keeping the rest of the metadata."""
+    project_patches_dir = get_patches_dir_for_project()
+    meta_file = project_patches_dir / "metadata.json"
+    lock_file = project_patches_dir / "metadata.json.lock"
+
+    project_patches_dir.mkdir(parents=True, exist_ok=True)
+    with FileLock(lock_file, timeout=10):
+        metadata = get_patches_metadata()
+        metadata["patches"] = patches
+        meta_file.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
+    return True
+
+
+def create_diff_patch_from_worktree(
+    worktree_dir: Path,
+    files: list[str],
+    fto_name: Optional[str] = None,
+    metadata_input: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    """Write the worktree's diff against HEAD for files to a patch file and record it in the metadata.
+
+    Returns the saved metadata (including "patch_path" and "id"), or {} when there is no diff.
+    """
+    repository = git.Repo(worktree_dir, search_parent_directories=True)
+    uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)
+
+    if not uni_diff_text:
+        logger.warning("No changes found in worktree.")
+        return {}
+
+    if not uni_diff_text.endswith("\n"):
+        uni_diff_text += "\n"
+
+    project_patches_dir = get_patches_dir_for_project()
+    project_patches_dir.mkdir(parents=True, exist_ok=True)
+
+    # metadata_input is optional; without this guard a call with neither fto_name nor metadata_input crashes
+    extra_metadata = metadata_input or {}
+    final_function_name = fto_name or extra_metadata.get("fto_name", "unknown")
+    patch_path = project_patches_dir / f"{worktree_dir.name}.{final_function_name}.patch"
+    with patch_path.open("w", encoding="utf8") as f:
+        f.write(uni_diff_text)
+
+    final_metadata = save_patches_metadata({"patch_path": str(patch_path), **extra_metadata})
+
+    return final_metadata
diff --git a/codeflash/lsp/beta.py b/codeflash/lsp/beta.py
index a68688ed6..1b8edcec7 100644
--- a/codeflash/lsp/beta.py
+++ b/codeflash/lsp/beta.py
@@ -11,7 +11,11 @@
 
 from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
 from codeflash.cli_cmds.cli import process_pyproject_config
-from codeflash.code_utils.git_utils import create_diff_patch_from_worktree
+from codeflash.code_utils.git_worktree_utils import (
+    create_diff_patch_from_worktree,
+    get_patches_metadata,
+    overwrite_patch_metadata,
+)
 from codeflash.code_utils.shell_utils import save_api_key_to_rc
 from codeflash.discovery.functions_to_optimize import (
     filter_functions,
@@ -45,6 +49,10 @@ class ProvideApiKeyParams:
     api_key: str
 
 
+@dataclass
+class OnPatchAppliedParams:
+    patch_id: str
+
 @dataclass
 class OptimizableFunctionsInCommitParams:
     commit_hash: str
@@ -245,6 +253,34 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
         return {"status": "error", "message": "something went wrong while saving the api key"}
 
 
+@server.feature("retrieveSuccessfulOptimizations")
+def retrieve_successful_optimizations(_server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
+    metadata = get_patches_metadata()
+    return {"status": "success", "patches": metadata["patches"]}
+
+
+@server.feature("onPatchApplied")
+def on_patch_applied(_server: CodeflashLanguageServer, params: OnPatchAppliedParams) -> dict[str, str]:
+    # first remove the patch from the metadata
+    metadata = get_patches_metadata()
+
+    deleted_patch_file = None
+    new_patches = []
+    for patch in metadata["patches"]:
+        if patch["id"] == params.patch_id:
+            deleted_patch_file = patch["patch_path"]
+            continue
+        new_patches.append(patch)
+
+    # then remove the patch file
+    if deleted_patch_file:
+        overwrite_patch_metadata(new_patches)
+        patch_path = Path(deleted_patch_file)
+        patch_path.unlink(missing_ok=True)
+        return {"status": "success"}
+    return {"status": "error", "message": "Patch not found"}
+
+
 @server.feature("performFunctionOptimization")
 @server.thread()
 def perform_function_optimization(  # noqa: PLR0911
@@ -346,15 +382,25 @@ def perform_function_optimization(  # noqa: PLR0911
 
         # generate a patch for the optimization
         relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings]
-        patch_file = create_diff_patch_from_worktree(
+
+        speedup = original_code_baseline.runtime / best_optimization.runtime
+
+        # get the original file path in the actual project (not in the worktree)
+        original_args, _ = server.optimizer.original_args_and_test_cfg
+        relative_file_path = current_function.file_path.relative_to(server.optimizer.current_worktree)
+        original_file_path = Path(original_args.project_root / relative_file_path).resolve()
+
+        metadata = create_diff_patch_from_worktree(
             server.optimizer.current_worktree,
             relative_file_paths,
-            server.optimizer.current_function_optimizer.function_to_optimize.qualified_name,
+            metadata_input={
+                "fto_name": function_to_optimize_qualified_name,
+                "explanation": best_optimization.explanation_v2,
+                "file_path": str(original_file_path),
+                "speedup": speedup,
+            },
         )
 
-        optimized_source = best_optimization.candidate.source_code.markdown
-        speedup = original_code_baseline.runtime / best_optimization.runtime
-
         server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
 
         return {
@@ -362,8 +408,8 @@ def perform_function_optimization(  # noqa: PLR0911
             "status": "success",
             "message": "Optimization completed successfully",
             "extra": f"Speedup: {speedup:.2f}x faster",
-            "optimization": optimized_source,
-            "patch_file": str(patch_file),
+            "patch_file": metadata["patch_path"],
+            "patch_id": metadata["id"],
             "explanation": best_optimization.explanation_v2,
         }
     finally:
diff --git a/codeflash/optimization/optimizer.py b/codeflash/optimization/optimizer.py
index 941705cfd..e1a0c4186 100644
--- a/codeflash/optimization/optimizer.py
+++ b/codeflash/optimization/optimizer.py
@@ -15,8 +15,8 @@
 from codeflash.code_utils import env_utils
 from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
 from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
-from codeflash.code_utils.git_utils import (
-    check_running_in_git_repo,
+from codeflash.code_utils.git_utils import check_running_in_git_repo
+from codeflash.code_utils.git_worktree_utils import (
     create_detached_worktree,
     create_diff_patch_from_worktree,
     create_worktree_snapshot_commit,
@@ -343,16 +343,18 @@ def run(self) -> None:
                         optimizations_found += 1
                         # create a diff patch for successful optimization
                         if self.current_worktree:
-                            read_writable_code = best_optimization.unwrap().code_context.read_writable_code
+                            best_opt = best_optimization.unwrap()
+                            read_writable_code = best_opt.code_context.read_writable_code
                             relative_file_paths = [
                                 code_string.file_path for code_string in read_writable_code.code_strings
                             ]
-                            patch_path = create_diff_patch_from_worktree(
+                            metadata = create_diff_patch_from_worktree(
                                 self.current_worktree,
                                 relative_file_paths,
-                                self.current_function_optimizer.function_to_optimize.qualified_name,
+                                fto_name=function_to_optimize.qualified_name,
+                                metadata_input={},
                             )
-                            self.patch_files.append(patch_path)
+                            self.patch_files.append(metadata["patch_path"])
                             if i < len(functions_to_optimize) - 1:
                                 create_worktree_snapshot_commit(
                                     self.current_worktree,