Skip to content

Commit c595a7a

Browse files
getting worktree to work in both lsp and normal cli mode
1 parent 681fa17 commit c595a7a

File tree

9 files changed

+191
-19
lines changed

9 files changed

+191
-19
lines changed

codeflash/cli_cmds/cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def parse_args() -> Namespace:
9494
help="Path to the directory of the project, where all the pytest-benchmark tests are located.",
9595
)
9696
parser.add_argument("--no-draft", default=False, action="store_true", help="Skip optimization for draft PRs")
97+
parser.add_argument("--worktree", default=False, action="store_true", help="Use worktree for optimization")
9798

9899
args, unknown_args = parser.parse_known_args()
99100
sys.argv[:] = [sys.argv[0], *unknown_args]

codeflash/code_utils/coverage_utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,20 @@ def build_fully_qualified_name(function_name: str, code_context: CodeOptimizatio
3939
return full_name
4040

4141

42-
def generate_candidates(source_code_path: Path) -> list[str]:
42+
def generate_candidates(source_code_path: Path) -> set[str]:
4343
"""Generate all the possible candidates for coverage data based on the source code path."""
44-
candidates = [source_code_path.name]
44+
candidates = set()
45+
candidates.add(source_code_path.name)
4546
current_path = source_code_path.parent
4647

48+
last_added = source_code_path.name
4749
while current_path != current_path.parent:
48-
candidate_path = str(Path(current_path.name) / candidates[-1])
49-
candidates.append(candidate_path)
50+
candidate_path = str(Path(current_path.name) / last_added)
51+
candidates.add(candidate_path)
52+
last_added = candidate_path
5053
current_path = current_path.parent
5154

55+
candidates.add(str(source_code_path))
5256
return candidates
5357

5458

codeflash/code_utils/git_utils.py

Lines changed: 74 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
from functools import cache
1010
from io import StringIO
1111
from pathlib import Path
12-
from typing import TYPE_CHECKING
12+
from typing import TYPE_CHECKING, Optional
1313

1414
import git
1515
from rich.prompt import Confirm
1616
from unidiff import PatchSet
1717

1818
from codeflash.cli_cmds.console import logger
19+
from codeflash.code_utils.compat import codeflash_cache_dir
1920
from codeflash.code_utils.config_consts import N_CANDIDATES
2021

2122
if TYPE_CHECKING:
@@ -192,3 +193,75 @@ def get_last_commit_author_if_pr_exists(repo: Repo | None = None) -> str | None:
192193
return None
193194
else:
194195
return last_commit.author.name
196+
197+
198+
worktree_dirs = codeflash_cache_dir / "worktrees"
199+
patches_dir = codeflash_cache_dir / "patches"
200+
201+
202+
def create_worktree_snapshot_commit(worktree_dir: Path, commit_message: str) -> None:
203+
repository = git.Repo(worktree_dir, search_parent_directories=True)
204+
repository.git.commit("-am", commit_message, "--no-verify")
205+
206+
207+
def create_detached_worktree(module_root: Path) -> Optional[Path]:
208+
if not check_running_in_git_repo(module_root):
209+
logger.warning("Module is not in a git repository. Skipping worktree creation.")
210+
return None
211+
git_root = git_root_dir()
212+
current_time_str = time.strftime("%Y%m%d-%H%M%S")
213+
worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"
214+
215+
result = subprocess.run(["git", "worktree", "add", "-d", str(worktree_dir)], cwd=git_root, check=True)
216+
if result.returncode != 0:
217+
logger.error(f"Failed to create worktree: {result.stderr}")
218+
return None
219+
220+
print(result.stdout)
221+
222+
# Get uncommitted diff from the original repo
223+
repository = git.Repo(module_root, search_parent_directories=True)
224+
uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
225+
226+
if not uni_diff_text.strip():
227+
logger.info("No uncommitted changes to copy to worktree.")
228+
return worktree_dir
229+
230+
# Write the diff to a temporary file
231+
with tempfile.NamedTemporaryFile(mode="w+", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
232+
tmp_patch_file.write(uni_diff_text + "\n") # the new line here is a must otherwise the last hunk won't be valid
233+
tmp_patch_file.flush()
234+
235+
patch_path = Path(tmp_patch_file.name).resolve()
236+
237+
# Apply the patch inside the worktree
238+
try:
239+
subprocess.run(["git", "apply", patch_path], cwd=worktree_dir, check=True)
240+
create_worktree_snapshot_commit(worktree_dir, "Initial Snapshot")
241+
except subprocess.CalledProcessError as e:
242+
logger.error(f"Failed to apply patch to worktree: {e}")
243+
244+
return worktree_dir
245+
246+
247+
def remove_worktree(worktree_dir: Path) -> None:
248+
try:
249+
repository = git.Repo(worktree_dir, search_parent_directories=True)
250+
repository.git.worktree("remove", "--force", worktree_dir)
251+
except Exception:
252+
logger.exception(f"Failed to remove worktree: {worktree_dir}")
253+
254+
255+
def create_diff_from_worktree(worktree_dir: Path, files: list[str], fto_name: str) -> Path:
256+
repository = git.Repo(worktree_dir, search_parent_directories=True)
257+
uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)
258+
uni_diff_text = uni_diff_text.strip()
259+
if not uni_diff_text:
260+
logger.warning("No changes found in worktree.")
261+
return None
262+
# write to patches_dir
263+
patches_dir.mkdir(parents=True, exist_ok=True)
264+
patch_path = patches_dir / f"{worktree_dir.name}.{fto_name}.patch"
265+
with patch_path.open("w", encoding="utf8") as f:
266+
f.write(uni_diff_text + "\n")
267+
return patch_path

codeflash/lsp/beta.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ def initialize_function_optimization(
100100
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
101101
)
102102

103+
server.optimizer.worktree_mode()
103104
optimizable_funcs, _, _ = server.optimizer.get_optimizable_functions()
104105
if not optimizable_funcs:
105106
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
@@ -334,9 +335,13 @@ def perform_function_optimization( # noqa: PLR0911
334335
speedup = original_code_baseline.runtime / best_optimization.runtime
335336

336337
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
338+
diff_patch_files = server.optimizer.patch_files
337339

338340
# CRITICAL: Clear the function filter after optimization to prevent state corruption
339341
server.optimizer.args.function = None
342+
server.optimizer.patch_files = []
343+
server.optimizer.current_worktree = None
344+
server.optimizer.cleanup_temporary_paths()
340345
server.show_message_log("Cleared function filter to prevent state corruption", "Info")
341346

342347
return {
@@ -345,4 +350,5 @@ def perform_function_optimization( # noqa: PLR0911
345350
"message": "Optimization completed successfully",
346351
"extra": f"Speedup: {speedup:.2f}x faster",
347352
"optimization": optimized_source,
353+
"diff_patch_files": diff_patch_files,
348354
}

codeflash/lsp/server.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
if TYPE_CHECKING:
1212
from lsprotocol.types import InitializeParams, InitializeResult
1313

14+
from codeflash.optimization.optimizer import Optimizer
15+
1416

1517
class CodeflashLanguageServerProtocol(LanguageServerProtocol):
1618
_server: CodeflashLanguageServer
@@ -44,7 +46,7 @@ def _find_pyproject_toml(self, workspace_path: str) -> Path | None:
4446
class CodeflashLanguageServer(LanguageServer):
4547
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
4648
super().__init__(*args, **kwargs)
47-
self.optimizer = None
49+
self.optimizer: Optimizer | None = None
4850
self.args = None
4951

5052
def prepare_optimizer_arguments(self, config_file: Path) -> None:

codeflash/optimization/function_optimizer.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,7 +1239,9 @@ def process_review(
12391239
"coverage_message": coverage_message,
12401240
}
12411241

1242-
if not self.args.no_pr and not self.args.staging_review:
1242+
raise_pr = not self.args.no_pr
1243+
1244+
if raise_pr and not self.args.staging_review:
12431245
data["git_remote"] = self.args.git_remote
12441246
check_create_pr(**data)
12451247
elif self.args.staging_review:
@@ -1250,12 +1252,24 @@ def process_review(
12501252
trace_id=self.function_trace_id, is_optimization_found=best_optimization is not None
12511253
)
12521254

1253-
if ((not self.args.no_pr) or not self.args.staging_review) and (
1254-
self.args.all or env_utils.get_pr_number() or (self.args.file and not self.args.function)
1255+
if raise_pr and (
1256+
self.args.all
1257+
or env_utils.get_pr_number()
1258+
or self.args.replay_test
1259+
or (self.args.file and not self.args.function)
12551260
):
1256-
self.write_code_and_helpers(
1257-
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
1258-
)
1261+
self.revert_code_and_helpers(original_helper_code)
1262+
return
1263+
1264+
if self.args.staging_review:
1265+
# always revert code and helpers when staging review
1266+
self.revert_code_and_helpers(original_helper_code)
1267+
return
1268+
1269+
def revert_code_and_helpers(self, original_helper_code: dict[Path, str]) -> None:
1270+
self.write_code_and_helpers(
1271+
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
1272+
)
12591273

12601274
def establish_original_code_baseline(
12611275
self,

codeflash/optimization/optimizer.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414
from codeflash.code_utils import env_utils
1515
from codeflash.code_utils.code_utils import cleanup_paths, get_run_tmp_file
1616
from codeflash.code_utils.env_utils import get_pr_number, is_pr_draft
17+
from codeflash.code_utils.git_utils import (
18+
check_running_in_git_repo,
19+
create_detached_worktree,
20+
create_diff_from_worktree,
21+
create_worktree_snapshot_commit,
22+
remove_worktree,
23+
)
1724
from codeflash.either import is_successful
1825
from codeflash.models.models import ValidCode
1926
from codeflash.telemetry.posthog_cf import ph
@@ -48,6 +55,8 @@ def __init__(self, args: Namespace) -> None:
4855
self.functions_checkpoint: CodeflashRunCheckpoint | None = None
4956
self.current_function_being_optimized: FunctionToOptimize | None = None # current only for the LSP
5057
self.current_function_optimizer: FunctionOptimizer | None = None
58+
self.current_worktree: Path | None = None
59+
self.patch_files: list[Path] = []
5160

5261
def run_benchmarks(
5362
self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
@@ -252,6 +261,10 @@ def run(self) -> None:
252261
if self.args.no_draft and is_pr_draft():
253262
logger.warning("PR is in draft mode, skipping optimization")
254263
return
264+
265+
if self.args.worktree:
266+
self.worktree_mode()
267+
255268
cleanup_paths(Optimizer.find_leftover_instrumented_test_files(self.test_cfg.tests_root))
256269

257270
function_optimizer = None
@@ -260,7 +273,6 @@ def run(self) -> None:
260273
file_to_funcs_to_optimize, num_optimizable_functions
261274
)
262275
optimizations_found: int = 0
263-
function_iterator_count: int = 0
264276
if self.args.test_framework == "pytest":
265277
self.test_cfg.concolic_test_root_dir = Path(
266278
tempfile.mkdtemp(dir=self.args.tests_root, prefix="codeflash_concolic_")
@@ -296,8 +308,8 @@ def run(self) -> None:
296308
except Exception as e:
297309
logger.debug(f"Could not rank functions in {original_module_path}: {e}")
298310

299-
for function_to_optimize in functions_to_optimize:
300-
function_iterator_count += 1
311+
for i, function_to_optimize in enumerate(functions_to_optimize):
312+
function_iterator_count = i + 1
301313
logger.info(
302314
f"Optimizing function {function_iterator_count} of {num_optimizable_functions}: "
303315
f"{function_to_optimize.qualified_name}"
@@ -327,6 +339,22 @@ def run(self) -> None:
327339
)
328340
if is_successful(best_optimization):
329341
optimizations_found += 1
342+
if self.current_worktree:
343+
read_writable_code = best_optimization.unwrap().code_context.read_writable_code
344+
relative_file_paths = [
345+
code_string.file_path for code_string in read_writable_code.code_strings
346+
]
347+
patch_path = create_diff_from_worktree(
348+
self.current_worktree,
349+
relative_file_paths,
350+
self.current_function_optimizer.function_to_optimize.qualified_name,
351+
)
352+
self.patch_files.append(patch_path)
353+
if i < len(functions_to_optimize) - 1:
354+
create_worktree_snapshot_commit(
355+
self.current_worktree,
356+
f"Optimizing {functions_to_optimize[i + 1].qualified_name}",
357+
)
330358
else:
331359
logger.warning(best_optimization.failure())
332360
console.rule()
@@ -337,6 +365,10 @@ def run(self) -> None:
337365
function_optimizer.cleanup_generated_files()
338366

339367
ph("cli-optimize-run-finished", {"optimizations_found": optimizations_found})
368+
if len(self.patch_files) > 0:
369+
logger.info(
370+
f"Created {len(self.patch_files)} patch(es) ({[str(patch_path) for patch_path in self.patch_files]})"
371+
)
340372
if self.functions_checkpoint:
341373
self.functions_checkpoint.cleanup()
342374
if hasattr(self.args, "command") and self.args.command == "optimize":
@@ -388,7 +420,46 @@ def cleanup_temporary_paths(self) -> None:
388420
if hasattr(get_run_tmp_file, "tmpdir"):
389421
get_run_tmp_file.tmpdir.cleanup()
390422
del get_run_tmp_file.tmpdir
391-
cleanup_paths([self.test_cfg.concolic_test_root_dir, self.replay_tests_dir])
423+
paths_to_clean = [self.test_cfg.concolic_test_root_dir, self.replay_tests_dir]
424+
if self.current_worktree:
425+
remove_worktree(self.current_worktree)
426+
cleanup_paths(paths_to_clean)
427+
428+
def worktree_mode(self) -> None:
429+
if self.current_worktree:
430+
return
431+
project_root = self.args.project_root
432+
module_root = self.args.module_root
433+
434+
if check_running_in_git_repo(module_root):
435+
relative_module_root = module_root.relative_to(project_root)
436+
relative_optimized_file = self.args.file.relative_to(project_root) if self.args.file else None
437+
relative_tests_root = self.test_cfg.tests_root.relative_to(project_root)
438+
relative_benchmarks_root = (
439+
self.args.benchmarks_root.relative_to(project_root) if self.args.benchmarks_root else None
440+
)
441+
442+
worktree_dir = create_detached_worktree(module_root)
443+
if worktree_dir is None:
444+
logger.warning("Failed to create worktree. Skipping optimization.")
445+
return
446+
self.current_worktree = worktree_dir
447+
# TODO: use a helper function to mutate self.args and self.test_cfg
448+
self.args.module_root = worktree_dir / relative_module_root
449+
self.args.project_root = worktree_dir
450+
self.args.test_project_root = worktree_dir
451+
self.args.tests_root = worktree_dir / relative_tests_root
452+
if relative_benchmarks_root:
453+
self.args.benchmarks_root = worktree_dir / relative_benchmarks_root
454+
455+
self.test_cfg.project_root_path = worktree_dir
456+
self.test_cfg.tests_project_rootdir = worktree_dir
457+
self.test_cfg.tests_root = worktree_dir / relative_tests_root
458+
if relative_benchmarks_root:
459+
self.test_cfg.benchmark_tests_root = worktree_dir / relative_benchmarks_root
460+
461+
if relative_optimized_file is not None:
462+
self.args.file = worktree_dir / relative_optimized_file
392463

393464

394465
def run_with_args(args: Namespace) -> None:

codeflash/telemetry/sentry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def init_sentry(enabled: bool = False, exclude_errors: bool = False) -> None: #
1414
)
1515

1616
sentry_sdk.init(
17-
dsn="https://4b9a1902f9361b48c04376df6483bc96@o4506833230561280.ingest.sentry.io/4506833262477312",
17+
dsn="",
1818
integrations=[sentry_logging],
1919
# Set traces_sample_rate to 1.0 to capture 100%
2020
# of transactions for performance monitoring.

tests/test_code_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def test_cleanup_paths(multiple_existing_and_non_existing_files: list[Path]) ->
357357

358358
def test_generate_candidates() -> None:
359359
source_code_path = Path("/Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py")
360-
expected_candidates = [
360+
expected_candidates = {
361361
"coverage_utils.py",
362362
"code_utils/coverage_utils.py",
363363
"codeflash/code_utils/coverage_utils.py",
@@ -367,7 +367,8 @@ def test_generate_candidates() -> None:
367367
"Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
368368
"krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
369369
"Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py",
370-
]
370+
"/Users/krrt7/Desktop/work/codeflash/cli/codeflash/code_utils/coverage_utils.py"
371+
}
371372
assert generate_candidates(source_code_path) == expected_candidates
372373

373374

0 commit comments

Comments
 (0)