Skip to content

Commit a716c9b

Browse files
Merge pull request #649 from codeflash-ai/feat/detached-worktrees
[FEAT][LSP] Worktrees
2 parents cbce47d + 8addeb3 commit a716c9b

File tree

20 files changed

+396
-189
lines changed

20 files changed

+396
-189
lines changed

codeflash/api/aiservice.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
from pydantic.json import pydantic_encoder
1111

1212
from codeflash.cli_cmds.console import console, logger
13-
from codeflash.code_utils.env_utils import get_codeflash_api_key, is_LSP_enabled
13+
from codeflash.code_utils.env_utils import get_codeflash_api_key
1414
from codeflash.code_utils.git_utils import get_last_commit_author_if_pr_exists, get_repo_owner_and_name
15+
from codeflash.lsp.helpers import is_LSP_enabled
1516
from codeflash.models.ExperimentMetadata import ExperimentMetadata
1617
from codeflash.models.models import AIServiceRefinerRequest, CodeStringsMarkdown, OptimizedCandidate
1718
from codeflash.telemetry.posthog_cf import ph

codeflash/api/cfapi.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
from codeflash.cli_cmds.console import console, logger
1616
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
17-
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name, git_root_dir
17+
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name
1818
from codeflash.github.PrComment import FileDiffContent, PrComment
19+
from codeflash.lsp.helpers import is_LSP_enabled
1920
from codeflash.version import __version__
2021

2122
if TYPE_CHECKING:
@@ -101,7 +102,7 @@ def get_user_id() -> Optional[str]:
101102
if min_version and version.parse(min_version) > version.parse(__version__):
102103
msg = "Your Codeflash CLI version is outdated. Please update to the latest version using `pip install --upgrade codeflash`."
103104
console.print(f"[bold red]{msg}[/bold red]")
104-
if console.quiet: # lsp
105+
if is_LSP_enabled():
105106
logger.debug(msg)
106107
return f"Error: {msg}"
107108
sys.exit(1)
@@ -203,8 +204,9 @@ def create_staging(
203204
generated_original_test_source: str,
204205
function_trace_id: str,
205206
coverage_message: str,
206-
replay_tests: str = "",
207-
concolic_tests: str = "",
207+
replay_tests: str,
208+
concolic_tests: str,
209+
root_dir: Path,
208210
) -> Response:
209211
"""Create a staging pull request, targeting the specified branch. (usually 'staging').
210212
@@ -217,12 +219,10 @@ def create_staging(
217219
:param coverage_message: Coverage report or summary.
218220
:return: The response object from the backend.
219221
"""
220-
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
222+
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
221223

222224
build_file_changes = {
223-
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
224-
oldContent=original_code[p], newContent=new_code[p]
225-
)
225+
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(oldContent=original_code[p], newContent=new_code[p])
226226
for p in original_code
227227
}
228228

codeflash/cli_cmds/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def parse_args() -> Namespace:
9494
help="Path to the directory of the project, where all the pytest-benchmark tests are located.",
9595
)
9696
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
97+
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
9798

9899
args, unknown_args = parser.parse_known_args()
99100
sys.argv[:] = [sys.argv[0], *unknown_args]

codeflash/cli_cmds/console.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import os
45
from contextlib import contextmanager
56
from itertools import cycle
67
from typing import TYPE_CHECKING
@@ -28,6 +29,10 @@
2829
DEBUG_MODE = logging.getLogger().getEffectiveLevel() == logging.DEBUG
2930

3031
console = Console()
32+
33+
if os.getenv("CODEFLASH_LSP"):
34+
console.quiet = True
35+
3136
logging.basicConfig(
3237
level=logging.INFO,
3338
handlers=[RichHandler(rich_tracebacks=True, markup=False, console=console, show_path=False, show_time=False)],

codeflash/code_utils/coverage_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,20 @@ def build_fully_qualified_name(function_name: str, code_context: CodeOptimizatio
3939
return full_name
4040

4141

42-
def generate_candidates(source_code_path: Path) -> list[str]:
42+
def generate_candidates(source_code_path: Path) -> set[str]:
4343
"""Generate all the possible candidates for coverage data based on the source code path."""
44-
candidates = [source_code_path.name]
44+
candidates = set()
45+
candidates.add(source_code_path.name)
4546
current_path = source_code_path.parent
4647

48+
last_added = source_code_path.name
4749
while current_path != current_path.parent:
48-
candidate_path = str(Path(current_path.name) / candidates[-1])
49-
candidates.append(candidate_path)
50+
candidate_path = str(Path(current_path.name) / last_added)
51+
candidates.add(candidate_path)
52+
last_added = candidate_path
5053
current_path = current_path.parent
5154

55+
candidates.add(str(source_code_path))
5256
return candidates
5357

5458

codeflash/code_utils/env_utils.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from pathlib import Path
88
from typing import Any, Optional
99

10-
from codeflash.cli_cmds.console import console, logger
10+
from codeflash.cli_cmds.console import logger
1111
from codeflash.code_utils.code_utils import exit_with_message
1212
from codeflash.code_utils.formatter import format_code
1313
from codeflash.code_utils.shell_utils import read_api_key_from_shell_config
14+
from codeflash.lsp.helpers import is_LSP_enabled
1415

1516

1617
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
@@ -34,11 +35,12 @@ def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool =
3435

3536
@lru_cache(maxsize=1)
3637
def get_codeflash_api_key() -> str:
37-
if console.quiet: # lsp
38-
# prefer shell config over env var in lsp mode
39-
api_key = read_api_key_from_shell_config()
40-
else:
41-
api_key = os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
38+
# prefer shell config over env var in lsp mode
39+
api_key = (
40+
read_api_key_from_shell_config()
41+
if is_LSP_enabled()
42+
else os.environ.get("CODEFLASH_API_KEY") or read_api_key_from_shell_config()
43+
)
4244

4345
api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/getting-started/codeflash-github-actions#add-your-api-key-to-your-repository-secrets]." # noqa
4446
if not api_key:
@@ -125,11 +127,6 @@ def is_ci() -> bool:
125127
return bool(os.environ.get("CI") or os.environ.get("GITHUB_ACTIONS"))
126128

127129

128-
@lru_cache(maxsize=1)
129-
def is_LSP_enabled() -> bool:
130-
return console.quiet
131-
132-
133130
def is_pr_draft() -> bool:
134131
"""Check if the PR is draft. in the github action context."""
135132
event = get_cached_gh_event_data()

codeflash/code_utils/formatter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import isort
1414

1515
from codeflash.cli_cmds.console import console, logger
16+
from codeflash.lsp.helpers import is_LSP_enabled
1617

1718

1819
def generate_unified_diff(original: str, modified: str, from_file: str, to_file: str) -> str:
@@ -106,8 +107,7 @@ def format_code(
106107
print_status: bool = True, # noqa
107108
exit_on_failure: bool = True, # noqa
108109
) -> str:
109-
if console.quiet:
110-
# lsp mode
110+
if is_LSP_enabled():
111111
exit_on_failure = False
112112

113113
if isinstance(path, str):

codeflash/code_utils/git_utils.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
from functools import cache
1010
from io import StringIO
1111
from pathlib import Path
12-
from typing import TYPE_CHECKING
12+
from typing import TYPE_CHECKING, Optional
1313

1414
import git
1515
from rich.prompt import Confirm
1616
from unidiff import PatchSet
1717

1818
from codeflash.cli_cmds.console import logger
19+
from codeflash.code_utils.compat import codeflash_cache_dir
1920
from codeflash.code_utils.config_consts import N_CANDIDATES
2021

2122
if TYPE_CHECKING:
@@ -192,3 +193,80 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:
192193
return None
193194
else:
194195
return last_commit.author.name
196+
197+
198+
worktree_dirs = codeflash_cache_dir / "worktrees"
199+
patches_dir = codeflash_cache_dir / "patches"
200+
201+
202+
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
    """Record the worktree's current tracked changes as a single commit.

    :param worktree_dir: Directory inside the worktree; the repository is
        located by searching this directory and its parents.
    :param commit_message: Message used for the snapshot commit.
    """
    worktree_repo = git.Repo(worktree_dir, search_parent_directories=True)
    git_cmd = worktree_repo.git
    # "-a" stages modified/deleted tracked files; "--no-verify" skips the commit hooks.
    git_cmd.commit("-am", commit_message, "--no-verify")
205+
206+
207+
def create_detached_worktree(module_root: Path) -> Optional[Path]:
    """Create a detached git worktree that mirrors the current repository state.

    A new worktree is added under ``worktree_dirs`` (named after the git root
    directory plus a timestamp). Uncommitted changes of the original
    repository are then exported as a patch, applied inside the worktree and
    committed there as "Initial Snapshot".

    :param module_root: Path checked with ``check_running_in_git_repo``.
    :return: The worktree directory, or ``None`` when ``module_root`` is not
        inside a git repository. The directory is returned even when applying
        the uncommitted changes fails; that failure is only logged.
    """
    if not check_running_in_git_repo(module_root):
        logger.warning("Module is not in a git repository. Skipping worktree creation.")
        return None

    git_root = git_root_dir()
    current_time_str = time.strftime("%Y%m%d-%H%M%S")
    worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"

    repository = git.Repo(git_root, search_parent_directories=True)

    # "-d" asks git for a worktree with a detached HEAD (no new branch).
    repository.git.worktree("add", "-d", str(worktree_dir))

    # Get uncommitted diff from the original repo.
    # NOTE(review): "add -N" records intent-to-add entries in the ORIGINAL
    # repository's index and they are not reverted here - confirm this is intended.
    repository.git.add("-N", ".")  # add the index for untracked files to be included in the diff
    uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)

    if not uni_diff_text.strip():
        logger.info("No uncommitted changes to copy to worktree.")
        return worktree_dir

    # Write the diff to a temporary file. It is written as UTF-8 (same encoding
    # the patch files in this module use) and closed before git reads it.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".codeflash.patch", delete=False, encoding="utf-8"
    ) as tmp_patch_file:
        # the trailing newline is required, otherwise the last hunk won't be valid
        tmp_patch_file.write(uni_diff_text + "\n")
        tmp_patch_file.flush()
        patch_path = Path(tmp_patch_file.name).resolve()

    # Apply the patch inside the worktree
    try:
        subprocess.run(
            ["git", "apply", "--ignore-space-change", "--ignore-whitespace", str(patch_path)],
            cwd=worktree_dir,
            check=True,
        )
        create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot")
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to apply patch to worktree: {e}")
    finally:
        # The file was created with delete=False, so it must be removed explicitly.
        patch_path.unlink(missing_ok=True)

    return worktree_dir
246+
247+
248+
def remove_worktree(worktree_dir: Path) -> None:
    """Force-remove the git worktree at ``worktree_dir``.

    Any exception raised while opening the repository or running
    ``git worktree remove --force`` is logged with its traceback and suppressed.

    :param worktree_dir: Directory of the worktree to remove.
    """
    try:
        worktree_repo = git.Repo(worktree_dir, search_parent_directories=True)
        git_cmd = worktree_repo.git
        git_cmd.worktree("remove", "--force", worktree_dir)
    except Exception:
        logger.exception(f"Failed to remove worktree: {worktree_dir}")
254+
255+
256+
def create_diff_patch_from_worktree(worktree_dir: Path, files: list[str], fto_name: str) -> Optional[Path]:
    """Write the worktree's changes against ``HEAD`` for ``files`` to a patch file.

    The patch is stored in ``patches_dir`` as
    ``<worktree directory name>.<fto_name>.patch``.

    :param worktree_dir: Directory inside the worktree whose changes are exported.
    :param files: Paths passed to ``git diff`` to restrict the diff.
    :param fto_name: Name embedded in the patch file name.
    :return: Path of the written patch file, or ``None`` when the diff is empty.
    """
    repository = git.Repo(worktree_dir, search_parent_directories=True)
    uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)

    if not uni_diff_text:
        logger.warning("No changes found in worktree.")
        return None

    # A patch must end with a newline, otherwise its last hunk is not valid.
    if not uni_diff_text.endswith("\n"):
        uni_diff_text += "\n"

    # write to patches_dir
    patches_dir.mkdir(parents=True, exist_ok=True)
    patch_path = patches_dir / f"{worktree_dir.name}.{fto_name}.patch"
    patch_path.write_text(uni_diff_text, encoding="utf8")
    return patch_path

codeflash/discovery/functions_to_optimize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def get_functions_to_optimize(
208208
logger.info("Finding all functions modified in the current git diff ...")
209209
console.rule()
210210
ph("cli-optimizing-git-diff")
211-
functions = get_functions_within_git_diff()
211+
functions = get_functions_within_git_diff(uncommitted_changes=False)
212212
filtered_modified_functions, functions_count = filter_functions(
213213
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
214214
)
@@ -224,8 +224,8 @@ def get_functions_to_optimize(
224224
return filtered_modified_functions, functions_count, trace_file_path
225225

226226

227-
def get_functions_within_git_diff() -> dict[str, list[FunctionToOptimize]]:
228-
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=False)
227+
def get_functions_within_git_diff(uncommitted_changes: bool) -> dict[str, list[FunctionToOptimize]]: # noqa: FBT001
228+
modified_lines: dict[str, list[int]] = get_git_diff(uncommitted_changes=uncommitted_changes)
229229
modified_functions: dict[str, list[FunctionToOptimize]] = {}
230230
for path_str, lines_in_file in modified_lines.items():
231231
path = Path(path_str)

codeflash/lsp/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +0,0 @@
1-
# Silence the console module to prevent stdout pollution
2-
from codeflash.cli_cmds.console import console
3-
4-
console.quiet = True

0 commit comments

Comments
 (0)