Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codeflash/cli_cmds/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace:
"I need a git repository to run --all and open PRs for optimizations. Exiting..."
)
apologize_and_exit()
if not args.no_pr and not check_and_push_branch(git_repo):
if not args.no_pr and not check_and_push_branch(git_repo, git_remote=args.git_remote):
exit_with_message("Branch is not pushed...", error_on_exit=True)
owner, repo = get_repo_owner_and_name(git_repo)
if not args.no_pr:
Expand Down
12 changes: 6 additions & 6 deletions codeflash/code_utils/git_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ def confirm_proceeding_with_no_git_repo() -> str | bool:
return True


def check_and_push_branch(repo: git.Repo, wait_for_push: bool = False) -> bool: # noqa: FBT001, FBT002
def check_and_push_branch(repo: git.Repo, git_remote: str | None = "origin", wait_for_push: bool = False) -> bool: # noqa: FBT001, FBT002
current_branch = repo.active_branch.name
origin = repo.remote(name="origin")
remote = repo.remote(name=git_remote)

# Check if the branch is pushed
if f"origin/{current_branch}" not in repo.refs:
if f"{git_remote}/{current_branch}" not in repo.refs:
logger.warning(f"⚠️ The branch '{current_branch}' is not pushed to the remote repository.")
if not sys.__stdin__.isatty():
logger.warning("Non-interactive shell detected. Branch will not be pushed.")
Expand All @@ -132,13 +132,13 @@ def check_and_push_branch(repo: git.Repo, wait_for_push: bool = False) -> bool:
f"the branch '{current_branch}' to the remote repository?",
default=False,
):
origin.push(current_branch)
logger.info(f"⬆️ Branch '{current_branch}' has been pushed to origin.")
remote.push(current_branch)
logger.info(f"⬆️ Branch '{current_branch}' has been pushed to {git_remote}.")
if wait_for_push:
time.sleep(3) # adding this to give time for the push to register with GitHub,
# so that our modifications to it are not rejected
return True
logger.info(f"🔘 Branch '{current_branch}' has not been pushed to origin.")
logger.info(f"🔘 Branch '{current_branch}' has not been pushed to {git_remote}.")
return False
logger.debug(f"The branch '{current_branch}' is present in the remote repository.")
return True
Expand Down
2 changes: 1 addition & 1 deletion codeflash/result/create_pr.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def check_create_pr(
owner, repo = get_repo_owner_and_name(git_repo, git_remote)
logger.info(f"Pushing to {git_remote} - Owner: {owner}, Repo: {repo}")
console.rule()
if not check_and_push_branch(git_repo, wait_for_push=True):
if not check_and_push_branch(git_repo, git_remote, wait_for_push=True):
logger.warning("⏭️ Branch is not pushed, skipping PR creation...")
return
relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()
Expand Down
Loading