Skip to content

Commit e806c5d

Browse files
authored
Merge pull request #598 from codeflash-ai/git-branch-fixes
fix crash on unpushed branch push
2 parents d5ec766 + ac859b5 commit e806c5d

File tree

5 files changed

+404
-346
lines changed

5 files changed

+404
-346
lines changed

codeflash/code_utils/checkpoint.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
from rich.prompt import Confirm
1212

13+
from codeflash.cli_cmds.console import console
14+
1315
if TYPE_CHECKING:
1416
import argparse
1517

@@ -142,8 +144,11 @@ def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optiona
142144
if previous_checkpoint_functions and Confirm.ask(
143145
"Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
144146
default=True,
147+
console=console,
145148
):
146-
pass
149+
console.rule()
147150
else:
148151
previous_checkpoint_functions = None
152+
153+
console.rule()
149154
return previous_checkpoint_functions

codeflash/code_utils/git_utils.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
from git import Repo
2323

2424

25-
def get_git_diff(repo_directory: Path = Path.cwd(), uncommitted_changes: bool = False) -> dict[str, list[int]]: # noqa: B008, FBT001, FBT002
25+
def get_git_diff(repo_directory: Path | None = None, *, uncommitted_changes: bool = False) -> dict[str, list[int]]:
26+
if repo_directory is None:
27+
repo_directory = Path.cwd()
2628
repository = git.Repo(repo_directory, search_parent_directories=True)
2729
commit = repository.head.commit
2830
if uncommitted_changes:
@@ -117,30 +119,31 @@ def confirm_proceeding_with_no_git_repo() -> str | bool:
117119
return True
118120

119121

120-
def check_and_push_branch(repo: git.Repo, git_remote: str | None = "origin", wait_for_push: bool = False) -> bool: # noqa: FBT001, FBT002
121-
current_branch = repo.active_branch.name
122+
def check_and_push_branch(repo: git.Repo, git_remote: str | None = "origin", *, wait_for_push: bool = False) -> bool:
123+
current_branch = repo.active_branch
124+
current_branch_name = current_branch.name
122125
remote = repo.remote(name=git_remote)
123126

124127
# Check if the branch is pushed
125-
if f"{git_remote}/{current_branch}" not in repo.refs:
126-
logger.warning(f"⚠️ The branch '{current_branch}' is not pushed to the remote repository.")
128+
if f"{git_remote}/{current_branch_name}" not in repo.refs:
129+
logger.warning(f"⚠️ The branch '{current_branch_name}' is not pushed to the remote repository.")
127130
if not sys.__stdin__.isatty():
128131
logger.warning("Non-interactive shell detected. Branch will not be pushed.")
129132
return False
130133
if sys.__stdin__.isatty() and Confirm.ask(
131134
f"⚡️ In order for me to create PRs, your current branch needs to be pushed. Do you want to push "
132-
f"the branch '{current_branch}' to the remote repository?",
135+
f"the branch '{current_branch_name}' to the remote repository?",
133136
default=False,
134137
):
135138
remote.push(current_branch)
136-
logger.info(f"⬆️ Branch '{current_branch}' has been pushed to {git_remote}.")
139+
logger.info(f"⬆️ Branch '{current_branch_name}' has been pushed to {git_remote}.")
137140
if wait_for_push:
138141
time.sleep(3) # adding this to give time for the push to register with GitHub,
139142
# so that our modifications to it are not rejected
140143
return True
141-
logger.info(f"🔘 Branch '{current_branch}' has not been pushed to {git_remote}.")
144+
logger.info(f"🔘 Branch '{current_branch_name}' has not been pushed to {git_remote}.")
142145
return False
143-
logger.debug(f"The branch '{current_branch}' is present in the remote repository.")
146+
logger.debug(f"The branch '{current_branch_name}' is present in the remote repository.")
144147
return True
145148

146149

codeflash/discovery/functions_to_optimize.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import git
1414
import libcst as cst
1515
from pydantic.dataclasses import dataclass
16+
from rich.tree import Tree
1617

1718
from codeflash.api.cfapi import get_blocklisted_functions, is_function_being_optimized_again
1819
from codeflash.cli_cmds.console import DEBUG_MODE, console, logger
@@ -37,6 +38,7 @@
3738

3839
from codeflash.models.models import CodeOptimizationContext
3940
from codeflash.verification.verification_utils import TestConfig
41+
from rich.text import Text
4042

4143

4244
@dataclass(frozen=True)
@@ -594,20 +596,22 @@ def filter_functions(
594596

595597
if not disable_logs:
596598
log_info = {
597-
f"{test_functions_removed_count} test function{'s' if test_functions_removed_count != 1 else ''}": test_functions_removed_count,
598-
f"{site_packages_removed_count} site-package function{'s' if site_packages_removed_count != 1 else ''}": site_packages_removed_count,
599-
f"{malformed_paths_count} non-importable file path{'s' if malformed_paths_count != 1 else ''}": malformed_paths_count,
600-
f"{non_modules_removed_count} function{'s' if non_modules_removed_count != 1 else ''} outside module-root": non_modules_removed_count,
601-
f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count,
602-
f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count,
603-
f"{blocklist_funcs_removed_count} function{'s' if blocklist_funcs_removed_count != 1 else ''} as previously optimized": blocklist_funcs_removed_count,
604-
f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} skipped from checkpoint": previous_checkpoint_functions_removed_count,
599+
"Test functions removed": (test_functions_removed_count, "yellow"),
600+
"Site-package functions removed": (site_packages_removed_count, "magenta"),
601+
"Non-importable file paths": (malformed_paths_count, "red"),
602+
"Functions outside module-root": (non_modules_removed_count, "cyan"),
603+
"Files from ignored paths": (ignore_paths_removed_count, "blue"),
604+
"Files from ignored submodules": (submodule_ignored_paths_count, "bright_black"),
605+
"Blocklisted functions removed": (blocklist_funcs_removed_count, "bright_red"),
606+
"Functions skipped from checkpoint": (previous_checkpoint_functions_removed_count, "green"),
605607
}
606-
log_string = "\n".join([k for k, v in log_info.items() if v > 0])
607-
if log_string:
608-
logger.info(f"Ignoring: {log_string}")
608+
tree = Tree(Text("Ignored functions and files", style="bold"))
609+
for label, (count, color) in log_info.items():
610+
if count > 0:
611+
tree.add(Text(f"{label}: {count}", style=color))
612+
if len(tree.children) > 0:
613+
console.print(tree)
609614
console.rule()
610-
611615
return {Path(k): v for k, v in filtered_modified_functions.items() if v}, functions_count
612616

613617

tests/test_git_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_check_and_push_branch(self, mock_confirm, mock_isatty, mock_repo):
7474
mock_origin.push.return_value = None
7575

7676
assert check_and_push_branch(mock_repo_instance)
77-
mock_origin.push.assert_called_once_with("test-branch")
77+
mock_origin.push.assert_called_once_with(mock_repo_instance.active_branch)
7878
mock_origin.push.reset_mock()
7979

8080
# Test when branch is already pushed

0 commit comments

Comments
 (0)