Skip to content

Commit a664330

Browse files
save optimization patches metadata
1 parent 674e69e commit a664330

File tree

4 files changed

+126
-29
lines changed

4 files changed

+126
-29
lines changed

codeflash/code_utils/git_utils.py

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
from __future__ import annotations
22

3+
import json
34
import os
45
import shutil
56
import subprocess
67
import sys
78
import tempfile
89
import time
9-
from functools import cache
10+
from functools import cache, lru_cache
1011
from io import StringIO
1112
from pathlib import Path
1213
from typing import TYPE_CHECKING, Optional
1314

1415
import git
16+
from filelock import FileLock
1517
from rich.prompt import Confirm
1618
from unidiff import PatchSet
1719

@@ -20,6 +22,8 @@
2022
from codeflash.code_utils.config_consts import N_CANDIDATES
2123

2224
if TYPE_CHECKING:
25+
from typing import Any
26+
2327
from git import Repo
2428

2529

@@ -199,6 +203,14 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:
199203
patches_dir = codeflash_cache_dir / "patches"
200204

201205

206+
@lru_cache(maxsize=1)
def get_git_project_id() -> str:
    """Return the hexsha of a root (parentless) commit reachable from HEAD.

    The repository is opened with ``search_parent_directories=True`` and the
    result is memoized by ``lru_cache``, so the lookup runs once per process.
    """
    repository: Repo = git.Repo(search_parent_directories=True)
    # max_parents=0 limits the walk to commits that have no parents.
    parentless_commits = list(repository.iter_commits(rev="HEAD", max_parents=0))
    first_root = parentless_commits[0]
    return first_root.hexsha
212+
213+
202214
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
203215
repository = git.Repo(worktree_dir, search_parent_directories=True)
204216
repository.git.add(".")
@@ -257,20 +269,70 @@ def remove_worktree(worktree_dir: Path) -> None:
257269
logger.exception(f"Failed to remove worktree: {worktree_dir}")
258270

259271

260-
def create_diff_patch_from_worktree(worktree_dir: Path, files: list[str], fto_name: str) -> Path:
272+
def get_patches_dir_for_project() -> Path:
    """Return the patches directory reserved for the current git project.

    The path is ``patches_dir`` joined with the project id. A falsy id is
    replaced by an empty string before joining.
    """
    project_subdir = get_git_project_id() or ""
    project_dir = patches_dir / project_subdir
    return Path(project_dir)
275+
276+
277+
def get_patches_metadata() -> dict[str, Any]:
    """Load the patch metadata recorded for the current git project.

    Returns:
        The parsed contents of ``metadata.json`` in the project's patches
        directory, or a fresh ``{"id": <project id>, "patches": []}`` mapping
        when that file does not exist.
    """
    metadata_path = get_patches_dir_for_project() / "metadata.json"
    if not metadata_path.exists():
        return {"id": get_git_project_id() or "", "patches": []}
    raw_text = metadata_path.read_text()
    return json.loads(raw_text)
283+
284+
285+
def save_patches_metadata(patch_metadata: dict) -> dict:
    """Append one patch record to the project's ``metadata.json``.

    The record is stamped with an ``"id"`` key (mutating ``patch_metadata`` in
    place), appended to the ``"patches"`` list and written back to disk while
    holding ``metadata.json.lock``.

    Args:
        patch_metadata: The record to store; it gains an ``"id"`` entry.

    Returns:
        The same ``patch_metadata`` object, now carrying its ``"id"``.

    Note:
        The lock is requested with ``timeout=10``; presumably ``FileLock``
        raises its timeout error if the lock cannot be acquired in that time
        (confirm against the filelock documentation).
    """
    project_patches_dir = get_patches_dir_for_project()
    # The lock file and the metadata file both live in this directory, so make
    # sure it exists even when no patch has been written for the project yet.
    project_patches_dir.mkdir(parents=True, exist_ok=True)
    meta_file = project_patches_dir / "metadata.json"
    lock_file = project_patches_dir / "metadata.json.lock"

    with FileLock(lock_file, timeout=10):
        metadata = get_patches_metadata()

        # The timestamp only has one-second resolution; add a numeric suffix
        # when it collides so that every stored record keeps a distinct id.
        base_id = time.strftime("%Y%m%d-%H%M%S")
        taken_ids = {entry.get("id") for entry in metadata["patches"] if isinstance(entry, dict)}
        patch_id = base_id
        suffix = 1
        while patch_id in taken_ids:
            patch_id = f"{base_id}-{suffix}"
            suffix += 1

        patch_metadata["id"] = patch_id
        metadata["patches"].append(patch_metadata)

        meta_file.write_text(json.dumps(metadata, indent=2))

    return patch_metadata
299+
300+
301+
def overwrite_patch_metadata(patches: list[dict]) -> bool:
    """Replace the stored list of patch records for the current project.

    The existing metadata is reloaded under ``metadata.json.lock``, its
    ``"patches"`` entry is replaced with ``patches`` and the result is written
    back to ``metadata.json``.

    Args:
        patches: The complete list of patch records to keep.

    Returns:
        ``True`` once the file has been written.
    """
    project_patches_dir = get_patches_dir_for_project()
    # This can be called before any patch was ever written for the project;
    # create the directory so the lock file and metadata file can be created.
    project_patches_dir.mkdir(parents=True, exist_ok=True)
    meta_file = project_patches_dir / "metadata.json"
    lock_file = project_patches_dir / "metadata.json.lock"

    with FileLock(lock_file, timeout=10):
        metadata = get_patches_metadata()
        metadata["patches"] = patches
        meta_file.write_text(json.dumps(metadata, indent=2))
    return True
311+
312+
313+
def create_diff_patch_from_worktree(
    worktree_dir: Path, files: list[str], metadata_input: dict[str, Any]
) -> dict[str, Any]:
    """Write a patch file for the worktree's changes and record its metadata.

    The diff of ``files`` against ``HEAD`` (ignoring blank-line and
    end-of-line whitespace changes) is written to the project's patches
    directory as ``<worktree name>.<fto_name>.patch``.

    Args:
        worktree_dir: Directory of the git worktree to diff.
        files: Paths passed to ``git diff`` to restrict the patch.
        metadata_input: Extra fields to store with the patch. ``"fto_name"``
            is used in the patch file name when present. A non-empty mapping
            is mutated: it gains ``"patch_path"`` and is then passed to
            ``save_patches_metadata``. May be empty.

    Returns:
        ``{}`` when the diff is empty. Otherwise a mapping that always
        contains ``"patch_path"``; when ``metadata_input`` is non-empty it is
        the record returned by ``save_patches_metadata``.
    """
    repository = git.Repo(worktree_dir, search_parent_directories=True)
    uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)

    if not uni_diff_text:
        logger.warning("No changes found in worktree.")
        return {}

    # Make sure the patch text ends with a newline.
    if not uni_diff_text.endswith("\n"):
        uni_diff_text += "\n"

    project_patches_dir = get_patches_dir_for_project()
    project_patches_dir.mkdir(parents=True, exist_ok=True)

    # Callers may pass an empty mapping (or one without "fto_name"); fall back
    # to a timestamp label instead of failing with KeyError.
    fto_name = metadata_input.get("fto_name") if metadata_input else None
    if not fto_name:
        fto_name = time.strftime("%Y%m%d-%H%M%S")

    patch_path = project_patches_dir / f"{worktree_dir.name}.{fto_name}.patch"
    with patch_path.open("w", encoding="utf8") as f:
        f.write(uni_diff_text)

    if not metadata_input:
        # Nothing to record, but callers still need to know where the patch is.
        return {"patch_path": str(patch_path)}

    metadata_input["patch_path"] = str(patch_path)
    return save_patches_metadata(metadata_input)

codeflash/lsp/beta.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111

1212
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
1313
from codeflash.cli_cmds.cli import process_pyproject_config
14-
from codeflash.code_utils.git_utils import create_diff_patch_from_worktree
14+
from codeflash.code_utils.git_utils import (
15+
create_diff_patch_from_worktree,
16+
get_patches_metadata,
17+
overwrite_patch_metadata,
18+
)
1519
from codeflash.code_utils.shell_utils import save_api_key_to_rc
1620
from codeflash.discovery.functions_to_optimize import filter_functions, get_functions_within_git_diff
1721
from codeflash.either import is_successful
@@ -216,6 +220,29 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
216220
return {"status": "error", "message": "something went wrong while saving the api key"}
217221

218222

223+
@server.feature("onPatchApplied")
def on_patch_applied(_server: CodeflashLanguageServer, params: dict[str, str]) -> dict[str, str]:
    """Drop an applied patch from the metadata and delete its patch file.

    Args:
        _server: The language server instance (unused).
        params: Request parameters; ``params["patch_id"]`` identifies the patch.

    Returns:
        ``{"status": "success"}`` when a record with that id was removed,
        otherwise an error payload with the message ``"Patch not found"``.
    """
    patch_id = params["patch_id"]
    metadata = get_patches_metadata()

    remaining_patches = []
    removed_patch_files = []
    for patch in metadata["patches"]:
        if patch["id"] == patch_id:
            removed_patch_files.append(patch["patch_path"])
        else:
            remaining_patches.append(patch)

    # Nothing matched: leave the metadata file untouched.
    if not removed_patch_files:
        return {"status": "error", "message": "Patch not found"}

    # First remove the record(s) from the metadata ...
    overwrite_patch_metadata(remaining_patches)
    # ... then remove every patch file that belonged to a removed record.
    for patch_file in removed_patch_files:
        if patch_file:
            Path(patch_file).unlink(missing_ok=True)
    return {"status": "success"}
244+
245+
219246
@server.feature("performFunctionOptimization")
220247
@server.thread()
221248
def perform_function_optimization( # noqa: PLR0911
@@ -317,24 +344,34 @@ def perform_function_optimization( # noqa: PLR0911
317344

318345
# generate a patch for the optimization
319346
relative_file_paths = [code_string.file_path for code_string in code_context.read_writable_code.code_strings]
320-
patch_file = create_diff_patch_from_worktree(
347+
348+
speedup = original_code_baseline.runtime / best_optimization.runtime
349+
350+
# get the original file path in the actual project (not in the worktree)
351+
original_args, _ = server.optimizer.original_args_and_test_cfg
352+
relative_file_path = current_function.file_path.relative_to(server.args.project_root)
353+
original_file_path = Path(original_args.project_root / relative_file_path).resolve()
354+
355+
metadata = create_diff_patch_from_worktree(
321356
server.optimizer.current_worktree,
322357
relative_file_paths,
323-
server.optimizer.current_function_optimizer.function_to_optimize.qualified_name,
358+
metadata_input={
359+
"fto_name": function_to_optimize_qualified_name,
360+
"explanation": best_optimization.explanation_v2,
361+
"file_path": str(original_file_path),
362+
"speedup": speedup,
363+
},
324364
)
325365

326-
optimized_source = best_optimization.candidate.source_code.markdown
327-
speedup = original_code_baseline.runtime / best_optimization.runtime
328-
329366
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
330367

331368
return {
332369
"functionName": params.functionName,
333370
"status": "success",
334371
"message": "Optimization completed successfully",
335372
"extra": f"Speedup: {speedup:.2f}x faster",
336-
"optimization": optimized_source,
337-
"patch_file": str(patch_file),
373+
"patch_file": metadata["patch_path"],
374+
"patch_id": metadata["id"],
338375
"explanation": best_optimization.explanation_v2,
339376
}
340377
finally:

codeflash/optimization/function_optimizer.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -631,14 +631,13 @@ def determine_best_candidate(
631631
executor=self.executor,
632632
)
633633
)
634-
else:
635-
tree.add(
636-
f"Summed runtime: {humanize_runtime(best_test_runtime)} "
637-
f"(measured over {candidate_result.max_loop_count} "
638-
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
639-
)
640-
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
641-
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
634+
tree.add(
635+
f"Summed runtime: {humanize_runtime(best_test_runtime)} "
636+
f"(measured over {candidate_result.max_loop_count} "
637+
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
638+
)
639+
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
640+
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
642641
console.print(tree)
643642
if self.args.benchmark and benchmark_tree:
644643
console.print(benchmark_tree)

codeflash/optimization/optimizer.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -343,16 +343,15 @@ def run(self) -> None:
343343
optimizations_found += 1
344344
# create a diff patch for successful optimization
345345
if self.current_worktree:
346-
read_writable_code = best_optimization.unwrap().code_context.read_writable_code
346+
best_opt = best_optimization.unwrap()
347+
read_writable_code = best_opt.code_context.read_writable_code
347348
relative_file_paths = [
348349
code_string.file_path for code_string in read_writable_code.code_strings
349350
]
350-
patch_path = create_diff_patch_from_worktree(
351-
self.current_worktree,
352-
relative_file_paths,
353-
self.current_function_optimizer.function_to_optimize.qualified_name,
351+
metadata = create_diff_patch_from_worktree(
352+
self.current_worktree, relative_file_paths, metadata_input={}
354353
)
355-
self.patch_files.append(patch_path)
354+
self.patch_files.append(metadata["patch_path"])
356355
if i < len(functions_to_optimize) - 1:
357356
create_worktree_snapshot_commit(
358357
self.current_worktree,

0 commit comments

Comments
 (0)