Skip to content

Commit 32099b4

Browse files
fixes and getting the worktree to work with PR and staging
1 parent bdf1770 commit 32099b4

File tree

5 files changed

+18
-35
lines changed

5 files changed

+18
-35
lines changed

codeflash/api/cfapi.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from codeflash.cli_cmds.console import console, logger
1616
from codeflash.code_utils.env_utils import ensure_codeflash_api_key, get_codeflash_api_key, get_pr_number
17-
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name, git_root_dir
17+
from codeflash.code_utils.git_utils import get_current_branch, get_repo_owner_and_name
1818
from codeflash.github.PrComment import FileDiffContent, PrComment
1919
from codeflash.lsp.helpers import is_LSP_enabled
2020
from codeflash.version import __version__
@@ -206,6 +206,7 @@ def create_staging(
206206
coverage_message: str,
207207
replay_tests: str = "",
208208
concolic_tests: str = "",
209+
root_dir: Optional[Path] = None,
209210
) -> Response:
210211
"""Create a staging pull request, targeting the specified branch. (usually 'staging').
211212
@@ -218,12 +219,10 @@ def create_staging(
218219
:param coverage_message: Coverage report or summary.
219220
:return: The response object from the backend.
220221
"""
221-
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
222+
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
222223

223224
build_file_changes = {
224-
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
225-
oldContent=original_code[p], newContent=new_code[p]
226-
)
225+
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(oldContent=original_code[p], newContent=new_code[p])
227226
for p in original_code
228227
}
229228

codeflash/code_utils/git_utils.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from codeflash.cli_cmds.console import logger
1919
from codeflash.code_utils.compat import codeflash_cache_dir
2020
from codeflash.code_utils.config_consts import N_CANDIDATES
21-
from codeflash.lsp.helpers import is_LSP_enabled
2221

2322
if TYPE_CHECKING:
2423
from git import Repo
@@ -213,19 +212,11 @@ def create_detached_worktree(module_root: Path) -> Optional[Path]:
213212
current_time_str = time.strftime("%Y%m%d-%H%M%S")
214213
worktree_dir = worktree_dirs / f"{git_root.name}-{current_time_str}"
215214

216-
result = subprocess.run(
217-
["git", "worktree", "add", "-d", str(worktree_dir)],
218-
cwd=git_root,
219-
check=True,
220-
stdout=subprocess.DEVNULL if is_LSP_enabled() else None,
221-
stderr=subprocess.DEVNULL if is_LSP_enabled() else None,
222-
)
223-
if result.returncode != 0:
224-
logger.error(f"Failed to create worktree: {result.stderr}")
225-
return None
215+
repository = git.Repo(git_root, search_parent_directories=True)
216+
217+
repository.git.worktree("add", "-d", str(worktree_dir))
226218

227219
# Get uncommitted diff from the original repo
228-
repository = git.Repo(module_root, search_parent_directories=True)
229220
repository.git.add("-N", ".") # add the index for untracked files to be included in the diff
230221
uni_diff_text = repository.git.diff(None, "HEAD", ignore_blank_lines=True, ignore_space_at_eol=True)
231222

@@ -234,7 +225,7 @@ def create_detached_worktree(module_root: Path) -> Optional[Path]:
234225
return worktree_dir
235226

236227
# Write the diff to a temporary file
237-
with tempfile.NamedTemporaryFile(mode="w+", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
228+
with tempfile.NamedTemporaryFile(mode="w", suffix=".codeflash.patch", delete=False) as tmp_patch_file:
238229
tmp_patch_file.write(uni_diff_text + "\n") # the new line here is a must otherwise the last hunk won't be valid
239230
tmp_patch_file.flush()
240231

codeflash/models/models.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -343,15 +343,12 @@ class TestsInFile:
343343
test_type: TestType
344344

345345

346-
@dataclass(frozen=True)
346+
@dataclass
347347
class OptimizedCandidate:
348348
source_code: CodeStringsMarkdown
349349
explanation: str
350350
optimization_id: str
351351

352-
def set_explanation(self, new_explanation: str) -> None:
353-
object.__setattr__(self, "explanation", new_explanation)
354-
355352

356353
@dataclass(frozen=True)
357354
class FunctionCalledInTest:

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,7 +1228,7 @@ def process_review(
12281228
benchmark_details=explanation.benchmark_details,
12291229
)
12301230

1231-
best_optimization.candidate.set_explanation(new_explanation)
1231+
best_optimization.candidate.explanation = new_explanation
12321232

12331233
console.print(Panel(new_explanation_raw_str, title="Best Candidate Explanation", border_style="blue"))
12341234

@@ -1250,9 +1250,9 @@ def process_review(
12501250

12511251
if raise_pr and not self.args.staging_review:
12521252
data["git_remote"] = self.args.git_remote
1253-
check_create_pr(**data)
1253+
check_create_pr(**data, root_dir=self.project_root)
12541254
elif self.args.staging_review:
1255-
create_staging(**data)
1255+
create_staging(**data, root_dir=self.project_root)
12561256
else:
12571257
# Mark optimization success since no PR will be created
12581258
mark_optimization_success(

codeflash/result/create_pr.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,7 @@
1010
from codeflash.cli_cmds.console import console, logger
1111
from codeflash.code_utils import env_utils
1212
from codeflash.code_utils.code_replacer import is_zero_diff
13-
from codeflash.code_utils.git_utils import (
14-
check_and_push_branch,
15-
get_current_branch,
16-
get_repo_owner_and_name,
17-
git_root_dir,
18-
)
13+
from codeflash.code_utils.git_utils import check_and_push_branch, get_current_branch, get_repo_owner_and_name
1914
from codeflash.code_utils.github_utils import github_pr_url
2015
from codeflash.code_utils.tabulate import tabulate
2116
from codeflash.code_utils.time_utils import format_perf, format_time
@@ -189,16 +184,17 @@ def check_create_pr(
189184
replay_tests: str,
190185
concolic_tests: str,
191186
git_remote: Optional[str] = None,
187+
root_dir: Optional[Path] = None,
192188
) -> None:
193189
pr_number: Optional[int] = env_utils.get_pr_number()
194190
git_repo = git.Repo(search_parent_directories=True)
195191

196192
if pr_number is not None:
197193
logger.info(f"Suggesting changes to PR #{pr_number} ...")
198194
owner, repo = get_repo_owner_and_name(git_repo)
199-
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
195+
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
200196
build_file_changes = {
201-
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
197+
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(
202198
oldContent=original_code[p], newContent=new_code[p]
203199
)
204200
for p in original_code
@@ -247,10 +243,10 @@ def check_create_pr(
247243
if not check_and_push_branch(git_repo, git_remote, wait_for_push=True):
248244
logger.warning("⏭️ Branch is not pushed, skipping PR creation...")
249245
return
250-
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
246+
relative_path = explanation.file_path.relative_to(root_dir).as_posix()
251247
base_branch = get_current_branch()
252248
build_file_changes = {
253-
Path(p).relative_to(git_root_dir()).as_posix(): FileDiffContent(
249+
Path(p).relative_to(root_dir).as_posix(): FileDiffContent(
254250
oldContent=original_code[p], newContent=new_code[p]
255251
)
256252
for p in original_code

0 commit comments

Comments
 (0)