diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index 265d32748..88dc96edc 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -234,7 +234,7 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace: "I need a git repository to run --all and open PRs for optimizations. Exiting..." ) apologize_and_exit() - if not args.no_pr and not check_and_push_branch(git_repo): + if not args.no_pr and not check_and_push_branch(git_repo, git_remote=args.git_remote): exit_with_message("Branch is not pushed...", error_on_exit=True) owner, repo = get_repo_owner_and_name(git_repo) if not args.no_pr: diff --git a/codeflash/code_utils/git_utils.py b/codeflash/code_utils/git_utils.py index ba3a19a2e..69575219a 100644 --- a/codeflash/code_utils/git_utils.py +++ b/codeflash/code_utils/git_utils.py @@ -117,12 +117,12 @@ def confirm_proceeding_with_no_git_repo() -> str | bool: return True -def check_and_push_branch(repo: git.Repo, wait_for_push: bool = False) -> bool: # noqa: FBT001, FBT002 +def check_and_push_branch(repo: git.Repo, git_remote: str | None = "origin", wait_for_push: bool = False) -> bool: # noqa: FBT001, FBT002 current_branch = repo.active_branch.name - origin = repo.remote(name="origin") + remote = repo.remote(name=git_remote) # Check if the branch is pushed - if f"origin/{current_branch}" not in repo.refs: + if f"{git_remote}/{current_branch}" not in repo.refs: logger.warning(f"⚠️ The branch '{current_branch}' is not pushed to the remote repository.") if not sys.__stdin__.isatty(): logger.warning("Non-interactive shell detected. Branch will not be pushed.") @@ -132,13 +132,13 @@ def check_and_push_branch(repo: git.Repo, wait_for_push: bool = False) -> bool: f"the branch '{current_branch}' to the remote repository?", default=False, ): - origin.push(current_branch) - logger.info(f"⬆️ Branch '{current_branch}' has been pushed to origin.") + remote.push(current_branch) + logger.info(f"⬆️ Branch '{current_branch}' has been pushed to {git_remote}.") if wait_for_push: time.sleep(3) # adding this to give time for the push to register with GitHub, # so that our modifications to it are not rejected return True - logger.info(f"🔘 Branch '{current_branch}' has not been pushed to origin.") + logger.info(f"🔘 Branch '{current_branch}' has not been pushed to {git_remote}.") return False logger.debug(f"The branch '{current_branch}' is present in the remote repository.") return True diff --git a/codeflash/result/create_pr.py b/codeflash/result/create_pr.py index 00ef1953c..2134cee09 100644 --- a/codeflash/result/create_pr.py +++ b/codeflash/result/create_pr.py @@ -185,7 +185,7 @@ def check_create_pr( owner, repo = get_repo_owner_and_name(git_repo, git_remote) logger.info(f"Pushing to {git_remote} - Owner: {owner}, Repo: {repo}") console.rule() - if not check_and_push_branch(git_repo, wait_for_push=True): + if not check_and_push_branch(git_repo, git_remote, wait_for_push=True): logger.warning("⏭️ Branch is not pushed, skipping PR creation...") return relative_path = explanation.file_path.relative_to(git_root_dir()).as_posix()