Merged
84 changes: 1 addition & 83 deletions codeflash/code_utils/git_utils.py
@@ -9,14 +9,13 @@
from functools import cache
from io import StringIO
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

import git
from rich.prompt import Confirm
from unidiff import PatchSet

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.compat import codeflash_cache_dir
from codeflash.code_utils.config_consts import N_CANDIDATES

if TYPE_CHECKING:
@@ -199,84 +198,3 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:
return None
else:
return last_commit.author.name


worktree_dirs = codeflash_cache_dir / "worktrees"
patches_dir = codeflash_cache_dir / "patches"


def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
repository = git.Repo(worktree_dir, search_parent_directories=True)
repository.git.add(".")
repository.git.commit("-m", commit_message, "--no-verify")


def create_detached_worktree(module_root: Path) -> Optional[Path]:
if not check_running_in_git_repo(module_root):
logger.warning("Module is not in a git repository. Skipping worktree creation.")
return None
git_root = git_root_dir()
current_time_str = time.strftime("%Y%m%d-%H%M%S")
worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"

repository = git.Repo(git_root, search_parent_directories=True)

repository.git.worktree("add", "-d", str(worktree_dir))

# Get uncommitted diff from the original repo
repository.git.add("-N", ".") # add the index for untracked files to be included in the diff
exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"] # fmt: off
uni_diff_text = repository.git.diff(
None, "HEAD", "--", *exclude_binary_files, ignore_blank_lines=True, ignore_space_at_eol=True
)

if not uni_diff_text.strip():
logger.info("No uncommitted changes to copy to worktree.")
return worktree_dir

# Write the diff to a temporary file
with tempfile.NamedTemporaryFile(mode="w", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
tmp_patch_file.write(uni_diff_text + "\n") # the new line here is a must otherwise the last hunk won't be valid
tmp_patch_file.flush()

patch_path = Path(tmp_patch_file.name).resolve()

# Apply the patch inside the worktree
try:
subprocess.run(
["git", "apply", "--ignore-space-change", "--ignore-whitespace", "--whitespace=nowarn", patch_path],
cwd=worktree_dir,
check=True,
)
create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot")
except subprocess.CalledProcessError as e:
logger.error(f"Failed to apply patch to worktree: {e}")

return worktree_dir


def remove_worktree(worktree_dir: Path) -> None:
try:
repository = git.Repo(worktree_dir, search_parent_directories=True)
repository.git.worktree("remove", "--force", worktree_dir)
except Exception:
logger.exception(f"Failed to remove worktree: {worktree_dir}")


def create_diff_patch_from_worktree(worktree_dir: Path, files: list[str], fto_name: str) -> Path:
repository = git.Repo(worktree_dir, search_parent_directories=True)
uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)

if not uni_diff_text:
logger.warning("No changes found in worktree.")
return None

if not uni_diff_text.endswith("\n"):
uni_diff_text += "\n"

# write to patches_dir
patches_dir.mkdir(parents=True, exist_ok=True)
patch_path = patches_dir / f"{worktree_dir.name}.{fto_name}.patch"
with patch_path.open("w", encoding="utf8") as f:
f.write(uni_diff_text)
return patch_path
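
Note: the worktree helpers deleted above (create_worktree_snapshot_commit, create_detached_worktree, remove_worktree, create_diff_patch_from_worktree) now live in the new git_worktree_utils.py module added below, so callers only need to update the import path. A minimal sketch of that change, mirroring the import update made in codeflash/lsp/beta.py later in this diff:

# before: the worktree helpers were defined in git_utils
# from codeflash.code_utils.git_utils import create_diff_patch_from_worktree

# after: import them from the dedicated worktree module
from codeflash.code_utils.git_worktree_utils import (
    create_detached_worktree,
    create_diff_patch_from_worktree,
    remove_worktree,
)
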
170 changes: 170 additions & 0 deletions codeflash/code_utils/git_worktree_utils.py
@@ -0,0 +1,170 @@
from __future__ import annotations

import json
import subprocess
import tempfile
import time
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Optional

import git
from filelock import FileLock

from codeflash.cli_cmds.console import logger
from codeflash.code_utils.compat import codeflash_cache_dir
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir

if TYPE_CHECKING:
from typing import Any

from git import Repo


worktree_dirs = codeflash_cache_dir / "worktrees"
patches_dir = codeflash_cache_dir / "patches"


@lru_cache(maxsize=1)
def get_git_project_id() -> str:
"""Return the first commit sha of the repo."""
repo: Repo = git.Repo(search_parent_directories=True)
root_commits = list(repo.iter_commits(rev="HEAD", max_parents=0))
return root_commits[0].hexsha


def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
repository = git.Repo(worktree_dir, search_parent_directories=True)
repository.git.add(".")
repository.git.commit("-m", commit_message, "--no-verify")


def create_detached_worktree(module_root: Path) -> Optional[Path]:
if not check_running_in_git_repo(module_root):
logger.warning("Module is not in a git repository. Skipping worktree creation.")
return None
git_root = git_root_dir()
current_time_str = time.strftime("%Y%m%d-%H%M%S")
worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"

repository = git.Repo(git_root, search_parent_directories=True)

repository.git.worktree("add", "-d", str(worktree_dir))

# Get uncommitted diff from the original repo
repository.git.add("-N", ".") # add the index for untracked files to be included in the diff
exclude_binary_files = [":!*.pyc", ":!*.pyo", ":!*.pyd", ":!*.so", ":!*.dll", ":!*.whl", ":!*.egg", ":!*.egg-info", ":!*.pyz", ":!*.pkl", ":!*.pickle", ":!*.joblib", ":!*.npy", ":!*.npz", ":!*.h5", ":!*.hdf5", ":!*.pth", ":!*.pt", ":!*.pb", ":!*.onnx", ":!*.db", ":!*.sqlite", ":!*.sqlite3", ":!*.feather", ":!*.parquet", ":!*.jpg", ":!*.jpeg", ":!*.png", ":!*.gif", ":!*.bmp", ":!*.tiff", ":!*.webp", ":!*.wav", ":!*.mp3", ":!*.ogg", ":!*.flac", ":!*.mp4", ":!*.avi", ":!*.mov", ":!*.mkv", ":!*.pdf", ":!*.doc", ":!*.docx", ":!*.xls", ":!*.xlsx", ":!*.ppt", ":!*.pptx", ":!*.zip", ":!*.rar", ":!*.tar", ":!*.tar.gz", ":!*.tgz", ":!*.bz2", ":!*.xz"] # fmt: off
uni_diff_text = repository.git.diff(
None, "HEAD", "--", *exclude_binary_files, ignore_blank_lines=True, ignore_space_at_eol=True
)

if not uni_diff_text.strip():
logger.info("No uncommitted changes to copy to worktree.")
return worktree_dir

# Write the diff to a temporary file
with tempfile.NamedTemporaryFile(mode="w", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
tmp_patch_file.write(uni_diff_text + "\n") # the new line here is a must otherwise the last hunk won't be valid
tmp_patch_file.flush()

patch_path = Path(tmp_patch_file.name).resolve()

# Apply the patch inside the worktree
try:
subprocess.run(
["git", "apply", "--ignore-space-change", "--ignore-whitespace", "--whitespace=nowarn", patch_path],
cwd=worktree_dir,
check=True,
)
create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot")
except subprocess.CalledProcessError as e:
logger.error(f"Failed to apply patch to worktree: {e}")

return worktree_dir


def remove_worktree(worktree_dir: Path) -> None:
try:
repository = git.Repo(worktree_dir, search_parent_directories=True)
repository.git.worktree("remove", "--force", worktree_dir)
except Exception:
logger.exception(f"Failed to remove worktree: {worktree_dir}")


@lru_cache(maxsize=1)
def get_patches_dir_for_project() -> Path:
project_id = get_git_project_id() or ""
return Path(patches_dir / project_id)


def get_patches_metadata() -> dict[str, Any]:
project_patches_dir = get_patches_dir_for_project()
meta_file = project_patches_dir / "metadata.json"
if meta_file.exists():
with meta_file.open("r", encoding="utf-8") as f:
return json.load(f)
return {"id": get_git_project_id() or "", "patches": []}


def save_patches_metadata(patch_metadata: dict) -> dict:
project_patches_dir = get_patches_dir_for_project()
meta_file = project_patches_dir / "metadata.json"
lock_file = project_patches_dir / "metadata.json.lock"

# concurrent optimizations within a single process are not supported yet; the lock is kept in case that changes later.
with FileLock(lock_file, timeout=10):
metadata = get_patches_metadata()

patch_metadata["id"] = time.strftime("%Y%m%d-%H%M%S")
metadata["patches"].append(patch_metadata)

meta_file.write_text(json.dumps(metadata, indent=2))

return patch_metadata


def overwrite_patch_metadata(patches: list[dict]) -> bool:
project_patches_dir = get_patches_dir_for_project()
meta_file = project_patches_dir / "metadata.json"
lock_file = project_patches_dir / "metadata.json.lock"

with FileLock(lock_file, timeout=10):
metadata = get_patches_metadata()
metadata["patches"] = patches
meta_file.write_text(json.dumps(metadata, indent=2))
return True


def create_diff_patch_from_worktree(
worktree_dir: Path,
files: list[str],
fto_name: Optional[str] = None,
metadata_input: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
repository = git.Repo(worktree_dir, search_parent_directories=True)
uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)

if not uni_diff_text:
logger.warning("No changes found in worktree.")
return {}

if not uni_diff_text.endswith("\n"):
uni_diff_text += "\n"

project_patches_dir = get_patches_dir_for_project()
project_patches_dir.mkdir(parents=True, exist_ok=True)

final_function_name = fto_name or (metadata_input or {}).get("fto_name", "unknown")  # guard against both arguments being None
patch_path = project_patches_dir / f"{worktree_dir.name}.{final_function_name}.patch"
with patch_path.open("w", encoding="utf8") as f:
f.write(uni_diff_text)

final_metadata = {"patch_path": str(patch_path)}
if metadata_input:
final_metadata.update(metadata_input)
final_metadata = save_patches_metadata(final_metadata)

return final_metadata
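
End to end, the new module turns an optimization run into a reviewable patch: create a detached worktree, let the optimizer edit files inside it, diff against the initial snapshot, and record the patch under the codeflash cache directory keyed by the repo's first commit sha (get_git_project_id). A minimal usage sketch, assuming a hypothetical project at my_project/ with an edited src/algo.py; the file path, function name, and speedup value are illustrative:

from pathlib import Path

from codeflash.code_utils.git_worktree_utils import (
    create_detached_worktree,
    create_diff_patch_from_worktree,
    get_patches_metadata,
    remove_worktree,
)

module_root = Path("my_project/src")          # hypothetical project layout
worktree = create_detached_worktree(module_root)
if worktree is not None:
    # ... the optimizer would edit files inside `worktree` here ...
    edited_files = ["src/algo.py"]            # illustrative, relative to the worktree root
    metadata = create_diff_patch_from_worktree(
        worktree,
        edited_files,
        metadata_input={"fto_name": "algo.slow_function", "speedup": 1.8},
    )
    if metadata:
        print(metadata["patch_path"], metadata["id"])   # patch file + timestamp id

    print(get_patches_metadata()["patches"])  # every patch recorded for this repo
    remove_worktree(worktree)

The backing metadata.json is simply {"id": <first commit sha>, "patches": [...]}, and writes go through a FileLock so separate processes can append entries without clobbering each other.
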
62 changes: 54 additions & 8 deletions codeflash/lsp/beta.py
@@ -11,7 +11,11 @@

from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
from codeflash.cli_cmds.cli import process_pyproject_config
from codeflash.code_utils.git_utils import create_diff_patch_from_worktree
from codeflash.code_utils.git_worktree_utils import (
create_diff_patch_from_worktree,
get_patches_metadata,
overwrite_patch_metadata,
)
from codeflash.code_utils.shell_utils import save_api_key_to_rc
from codeflash.discovery.functions_to_optimize import (
filter_functions,
@@ -45,6 +49,10 @@ class ProvideApiKeyParams:
api_key: str


@dataclass
class OnPatchAppliedParams:
patch_id: str

@dataclass
class OptimizableFunctionsInCommitParams:
commit_hash: str
@@ -245,6 +253,34 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
return {"status": "error", "message": "something went wrong while saving the api key"}


@server.feature("retrieveSuccessfulOptimizations")
def retrieve_successful_optimizations(_server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
metadata = get_patches_metadata()
return {"status": "success", "patches": metadata["patches"]}


@server.feature("onPatchApplied")
def on_patch_applied(_server: CodeflashLanguageServer, params: OnPatchAppliedParams) -> dict[str, str]:
# collect the remaining patch entries, noting which patch file to delete
metadata = get_patches_metadata()

deleted_patch_file = None
new_patches = []
for patch in metadata["patches"]:
if patch["id"] == params.patch_id:
deleted_patch_file = patch["patch_path"]
continue
new_patches.append(patch)

# persist the filtered list, then delete the patch file from disk
if deleted_patch_file:
overwrite_patch_metadata(new_patches)
patch_path = Path(deleted_patch_file)
patch_path.unlink(missing_ok=True)
return {"status": "success"}
return {"status": "error", "message": "Patch not found"}


@server.feature("performFunctionOptimization")
@server.thread()
def perform_function_optimization( # noqa: PLR0911
@@ -346,24 +382,34 @@ def perform_function_optimization( # noqa: PLR0911

# generate a patch for the optimization
relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings]
patch_file = create_diff_patch_from_worktree(

speedup = original_code_baseline.runtime / best_optimization.runtime

# get the original file path in the actual project (not in the worktree)
original_args, _ = server.optimizer.original_args_and_test_cfg
relative_file_path = current_function.file_path.relative_to(server.optimizer.current_worktree)
original_file_path = Path(original_args.project_root / relative_file_path).resolve()

metadata = create_diff_patch_from_worktree(
server.optimizer.current_worktree,
relative_file_paths,
server.optimizer.current_function_optimizer.function_to_optimize.qualified_name,
metadata_input={
"fto_name": function_to_optimize_qualified_name,
"explanation": best_optimization.explanation_v2,
"file_path": str(original_file_path),
"speedup": speedup,
},
)

optimized_source = best_optimization.candidate.source_code.markdown
speedup = original_code_baseline.runtime / best_optimization.runtime

server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")

return {
"functionName": params.functionName,
"status": "success",
"message": "Optimization completed successfully",
"extra": f"Speedup: {speedup:.2f}x faster",
"optimization": optimized_source,
"patch_file": str(patch_file),
"patch_file": metadata["patch_path"],
"patch_id": metadata["id"],
"explanation": best_optimization.explanation_v2,
}
finally:
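
Taken together, the new LSP features give the editor a simple apply-and-acknowledge loop: performFunctionOptimization now returns patch_file and patch_id, retrieveSuccessfulOptimizations lists the stored patches, and onPatchApplied lets the server prune its metadata once a patch has been applied. A minimal client-side sketch, assuming a hypothetical lsp_client wrapper exposing send_request(method, params); the request names and payload fields mirror the handlers above:

import subprocess
from pathlib import Path


def apply_first_stored_patch(lsp_client, project_root: Path) -> None:
    # `lsp_client` is a hypothetical editor-side helper, not part of this PR.
    patches = lsp_client.send_request("retrieveSuccessfulOptimizations", {})["patches"]
    if not patches:
        return

    chosen = patches[0]  # entries carry "id", "patch_path", "fto_name", "speedup", ...

    # apply the stored patch to the real checkout (not the worktree)
    subprocess.run(
        ["git", "apply", "--ignore-space-change", "--ignore-whitespace", chosen["patch_path"]],
        cwd=project_root,
        check=True,
    )

    # let the server drop the metadata entry and delete the patch file
    lsp_client.send_request("onPatchApplied", {"patch_id": chosen["id"]})
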