Skip to content

Commit c03f256

Browse files
committed
Fix several bugs in stack-pr codebase
- Improve error handling in git ancestry check - Fix JSON handling in GitHub API responses - Add proper PR closing in delete_remote_branches function - Fix branch name generation race condition - Improve stash handling in main function - Update shell command type annotations for better typing
1 parent b569b3d commit c03f256

File tree

3 files changed

+189
-81
lines changed

3 files changed

+189
-81
lines changed

src/stack_pr/cli.py

Lines changed: 160 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import logging
5656
import os
5757
import re
58+
import subprocess
5859
import sys
5960
from dataclasses import dataclass
6061
from functools import cache
@@ -399,19 +400,40 @@ def last(ref: str, sep: str = "/") -> str:
399400
return ref.rsplit(sep, 1)[-1]
400401

401402

402-
# TODO: Move to 'modular.utils.git'
403+
class GitAncestryError(RuntimeError):
404+
"""Error raised when git ancestry check fails."""
405+
406+
403407
def is_ancestor(commit1: str, commit2: str, *, verbose: bool) -> bool:
404408
"""
405409
Returns true if 'commit1' is an ancestor of 'commit2'.
410+
411+
Raises:
412+
GitAncestryError: If git command fails for reasons other than ancestry check.
406413
"""
407-
# TODO: We need to check returncode of this command more carefully, as the
408-
# command simply might fail (rc != 0 and rc != 1).
409-
p = run_shell_command(
410-
["git", "merge-base", "--is-ancestor", commit1, commit2],
411-
check=False,
412-
quiet=not verbose,
413-
)
414-
return p.returncode == 0
414+
error_code = None
415+
416+
try:
417+
p = run_shell_command(
418+
["git", "merge-base", "--is-ancestor", commit1, commit2],
419+
check=False,
420+
quiet=not verbose,
421+
)
422+
if p.returncode == 0:
423+
return True
424+
if p.returncode == 1:
425+
return False
426+
427+
# Store error code for later
428+
error_code = p.returncode
429+
except subprocess.SubprocessError as e:
430+
raise GitAncestryError(f"Failed to determine ancestry relationship: {e}") from e
431+
432+
# Handle error code outside the try block
433+
if error_code is None:
434+
# This should never happen, but just in case
435+
raise GitAncestryError("Unexpected error in git ancestry check")
436+
raise GitAncestryError(f"Git ancestry check failed with code {error_code}")
415437

416438

417439
def is_repo_clean() -> bool:
@@ -456,59 +478,74 @@ def set_base_branches(st: list[StackEntry], target: str) -> None:
456478

457479
def verify(st: list[StackEntry], *, check_base: bool = False) -> None:
458480
log(h("Verifying stack info"))
459-
for index, e in enumerate(st):
460-
if e.has_missing_info():
461-
error(ERROR_STACKINFO_MISSING.format(**locals()))
481+
for index, entry in enumerate(st):
482+
if entry.has_missing_info():
483+
error(ERROR_STACKINFO_MISSING.format(e=entry))
462484
raise RuntimeError
463485

464-
if len(e.pr.split("/")) == 0 or not last(e.pr).isnumeric():
465-
error(ERROR_STACKINFO_BAD_LINK.format(**locals()))
486+
if len(entry.pr.split("/")) == 0 or not last(entry.pr).isnumeric():
487+
error(ERROR_STACKINFO_BAD_LINK.format(e=entry))
466488
raise RuntimeError
467489

468-
ghinfo = get_command_output(
469-
[
470-
"gh",
471-
"pr",
472-
"view",
473-
e.pr,
474-
"--json",
475-
"baseRefName,headRefName,number,state,body,title,url,mergeStateStatus",
476-
]
477-
)
478-
d = json.loads(ghinfo)
479-
for required_field in ["state", "number", "baseRefName", "headRefName"]:
480-
if required_field not in d:
481-
error(ERROR_STACKINFO_MALFORMED_RESPONSE.format(**locals()))
490+
try:
491+
ghinfo = get_command_output(
492+
[
493+
"gh",
494+
"pr",
495+
"view",
496+
entry.pr,
497+
"--json",
498+
"baseRefName,headRefName,number,state,body,title,url,mergeStateStatus",
499+
]
500+
)
501+
502+
try:
503+
d = json.loads(ghinfo)
504+
except json.JSONDecodeError as e:
505+
error(f"Failed to parse JSON response from GitHub: {ghinfo}")
506+
raise RuntimeError("Invalid JSON response from GitHub") from e
507+
508+
for required_field in ["state", "number", "baseRefName", "headRefName"]:
509+
if required_field not in d:
510+
error(
511+
ERROR_STACKINFO_MALFORMED_RESPONSE.format(
512+
e=entry, required_field=required_field, d=d
513+
)
514+
)
515+
raise RuntimeError
516+
517+
if d["state"] != "OPEN":
518+
error(ERROR_STACKINFO_PR_NOT_OPEN.format(e=entry, d=d))
482519
raise RuntimeError
483520

484-
if d["state"] != "OPEN":
485-
error(ERROR_STACKINFO_PR_NOT_OPEN.format(**locals()))
486-
raise RuntimeError
521+
if int(last(entry.pr)) != d["number"]:
522+
error(ERROR_STACKINFO_PR_NUMBER_MISMATCH.format(e=entry, d=d))
523+
raise RuntimeError
487524

488-
if int(last(e.pr)) != d["number"]:
489-
error(ERROR_STACKINFO_PR_NUMBER_MISMATCH.format(**locals()))
490-
raise RuntimeError
525+
if entry.head != d["headRefName"]:
526+
error(ERROR_STACKINFO_PR_HEAD_MISMATCH.format(e=entry, d=d))
527+
raise RuntimeError
491528

492-
if e.head != d["headRefName"]:
493-
error(ERROR_STACKINFO_PR_HEAD_MISMATCH.format(**locals()))
494-
raise RuntimeError
529+
# 'Base' branch might diverge when the stack is modified (e.g. when a
530+
# new commit is added to the middle of the stack). It is not an issue
531+
# if we're updating the stack (i.e. in 'submit'), but it is an issue if
532+
# we are trying to land it.
533+
if check_base and entry.base != d["baseRefName"]:
534+
error(ERROR_STACKINFO_PR_BASE_MISMATCH.format(e=entry, d=d))
535+
raise RuntimeError
495536

496-
# 'Base' branch might diverge when the stack is modified (e.g. when a
497-
# new commit is added to the middle of the stack). It is not an issue
498-
# if we're updating the stack (i.e. in 'submit'), but it is an issue if
499-
# we are trying to land it.
500-
if check_base and e.base != d["baseRefName"]:
501-
error(ERROR_STACKINFO_PR_BASE_MISMATCH.format(**locals()))
502-
raise RuntimeError
537+
# The first entry on the stack needs to be actually mergeable on GitHub.
538+
if (
539+
check_base
540+
and index == 0
541+
and d["mergeStateStatus"] not in ["CLEAN", "UNKNOWN", "UNSTABLE"]
542+
):
543+
error(ERROR_STACKINFO_PR_NOT_MERGEABLE.format(e=entry, d=d))
544+
raise RuntimeError
503545

504-
# The first entry on the stack needs to be actually mergeable on GitHub.
505-
if (
506-
check_base
507-
and index == 0
508-
and d["mergeStateStatus"] not in ["CLEAN", "UNKNOWN", "UNSTABLE"]
509-
):
510-
error(ERROR_STACKINFO_PR_NOT_MERGEABLE.format(**locals()))
511-
raise RuntimeError
546+
except subprocess.CalledProcessError as exc:
547+
error(f"Failed to get PR information from GitHub: {exc}")
548+
raise RuntimeError("GitHub API request failed") from exc
512549

513550

514551
def print_stack(st: list[StackEntry], *, links: bool, level: int = 1) -> None:
@@ -603,10 +640,39 @@ def get_taken_branch_ids(refs: list[str], branch_name_template: str) -> list[int
603640

604641

605642
def generate_available_branch_name(refs: list[str], branch_name_template: str) -> str:
643+
"""Generate an available branch name that doesn't conflict with existing branches.
644+
645+
This function handles potential race conditions by using an ID higher than
646+
the current maximum.
647+
648+
Args:
649+
refs: List of existing branch references
650+
branch_name_template: Template for the branch name
651+
652+
Returns:
653+
A branch name that doesn't conflict with existing branches
654+
"""
655+
max_attempts = 100
606656
branch_ids = get_taken_branch_ids(refs, branch_name_template)
607657
max_ref_num = max(branch_ids) if branch_ids else 0
608658
new_branch_id = max_ref_num + 1
609-
return generate_branch_name(branch_name_template, new_branch_id)
659+
660+
# Safety check: verify the new branch name doesn't already exist
661+
new_branch_name = generate_branch_name(branch_name_template, new_branch_id)
662+
attempts = 0
663+
while any(
664+
ref.endswith(f"/{new_branch_name}") or ref == new_branch_name for ref in refs
665+
):
666+
# Increment and try again if there's a conflict
667+
new_branch_id += 1
668+
new_branch_name = generate_branch_name(branch_name_template, new_branch_id)
669+
attempts += 1
670+
if attempts > max_attempts: # Prevent infinite loops
671+
raise RuntimeError(
672+
"Unable to generate a unique branch name after 100 attempts"
673+
)
674+
675+
return new_branch_name
610676

611677

612678
def get_available_branch_name(remote: str, branch_name_template: str) -> str:
@@ -955,7 +1021,7 @@ def command_submit(
9551021
return
9561022

9571023
if (draft_bitmask is not None) and (len(draft_bitmask) != len(st)):
958-
log(h("Draft bitmask passed to 'submit' doesn't match number of PRs!"))
1024+
error("Draft bitmask passed to 'submit' doesn't match number of PRs!")
9591025
return
9601026

9611027
# Create local branches and initialize base and head fields in the stack
@@ -1121,6 +1187,19 @@ def delete_remote_branches(
11211187
cmd.extend([f":{branch}" for branch in remote_branches_to_delete])
11221188
run_shell_command(cmd, check=False, quiet=not verbose)
11231189

1190+
# Close associated PRs as mentioned in the docstring
1191+
for e in st:
1192+
if e.has_pr():
1193+
try:
1194+
run_shell_command(
1195+
["gh", "pr", "close", e.pr, "--delete-branch=false"],
1196+
check=False,
1197+
quiet=not verbose,
1198+
)
1199+
log(f"Closed PR {e.pr}", level=1)
1200+
except Exception as exc: # noqa: BLE001
1201+
log(f"Failed to close PR {e.pr}: {exc}", level=1)
1202+
11241203

11251204
# ===----------------------------------------------------------------------=== #
11261205
# Entry point for 'land' command
@@ -1467,7 +1546,7 @@ def load_config(config_file: str) -> configparser.ConfigParser:
14671546
return config
14681547

14691548

1470-
def main() -> None: # noqa: PLR0912
1549+
def main() -> None: # noqa: PLR0912, PLR0915
14711550
config_file = os.getenv("STACKPR_CONFIG", ".stack-pr.cfg")
14721551
config = load_config(config_file)
14731552

@@ -1490,9 +1569,17 @@ def main() -> None: # noqa: PLR0912
14901569

14911570
current_branch = get_current_branch_name()
14921571
get_branch_name_base(common_args.branch_name_template)
1572+
stashed = False
14931573
try:
14941574
if args.command in ["submit", "export"] and args.stash:
1495-
run_shell_command(["git", "stash", "save"], quiet=not common_args.verbose)
1575+
# Check if there's anything to stash first
1576+
if not is_repo_clean():
1577+
run_shell_command(
1578+
["git", "stash", "save"], quiet=not common_args.verbose
1579+
)
1580+
stashed = True
1581+
else:
1582+
log("No changes to stash", level=1)
14961583

14971584
if args.command != "view" and not is_repo_clean():
14981585
error(ERROR_REPO_DIRTY)
@@ -1518,15 +1605,27 @@ def main() -> None: # noqa: PLR0912
15181605
return
15191606
except Exception as exc:
15201607
# If something failed, checkout the original branch
1521-
run_shell_command(
1522-
["git", "checkout", current_branch], quiet=not common_args.verbose
1523-
)
1608+
try:
1609+
run_shell_command(
1610+
["git", "checkout", current_branch], quiet=not common_args.verbose
1611+
)
1612+
except Exception as checkout_error: # noqa: BLE001
1613+
error(f"Failed to checkout original branch: {checkout_error}")
15241614
if isinstance(exc, SubprocessError):
15251615
print_cmd_failure_details(exc)
15261616
raise
15271617
finally:
1528-
if args.command in ["submit", "export"] and args.stash:
1529-
run_shell_command(["git", "stash", "pop"], quiet=not common_args.verbose)
1618+
# Only try to pop the stash if we actually stashed something
1619+
if stashed and args.command in ["submit", "export"]:
1620+
try:
1621+
run_shell_command(
1622+
["git", "stash", "pop"], quiet=not common_args.verbose
1623+
)
1624+
except Exception as stash_error: # noqa: BLE001
1625+
error(f"Failed to pop stashed changes: {stash_error}")
1626+
error(
1627+
"Your changes are still in the stash. Run 'git stash pop' to retrieve them."
1628+
)
15301629

15311630

15321631
if __name__ == "__main__":

src/stack_pr/git.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import re
43
import string
54
import subprocess
65
from collections.abc import Sequence
@@ -132,20 +131,26 @@ def get_uncommitted_changes(
132131
return changes
133132

134133

135-
# TODO: enforce this as a module dependency
136134
def check_gh_installed() -> None:
137-
"""Check if the gh tool is installed.
135+
"""Check if the gh tool is installed and authenticated.
138136
139137
Raises:
140-
GitError if gh is not available.
138+
GitError: If gh is not available or not authenticated.
141139
"""
142-
143140
try:
144-
run_shell_command(["gh"], capture_output=True, quiet=False)
141+
# Check if gh is installed
142+
run_shell_command(["gh", "--version"], capture_output=True, quiet=False)
143+
144+
# Check if gh is authenticated
145+
auth_status = get_command_output(["gh", "auth", "status"], check=False)
146+
if "You are not logged into any GitHub hosts" in auth_status:
147+
raise GitError(
148+
"'gh' is not authenticated. Please run 'gh auth login' to authenticate."
149+
)
145150
except subprocess.CalledProcessError as err:
146151
raise GitError(
147-
"'gh' is not installed. Please visit https://cli.github.com/ for"
148-
" installation instuctions."
152+
"'gh' is not installed or not accessible. Please visit https://cli.github.com/ for"
153+
" installation instructions."
149154
) from err
150155

151156

@@ -175,12 +180,17 @@ def get_gh_username() -> str:
175180
]
176181
)
177182

178-
# Extract the login name.
179-
m = re.search(r"\"login\":\"(.*?)\"", user_query)
180-
if not m:
181-
raise GitError("Unable to find current github user name")
183+
# Parse JSON response properly
184+
import json
182185

183-
return m.group(1)
186+
try:
187+
response = json.loads(user_query)
188+
login = response.get("data", {}).get("viewer", {}).get("login")
189+
if not login:
190+
raise GitError("Unable to find current github user name")
191+
return str(login) # Ensure we return a string
192+
except json.JSONDecodeError as e:
193+
raise GitError("Invalid response from GitHub API") from e
184194

185195

186196
def get_changed_files(

src/stack_pr/shell_commands.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,17 @@
55
from collections.abc import Iterable
66
from logging import getLogger
77
from pathlib import Path
8+
from typing import Any, Union
89

9-
if sys.version_info >= (3, 13):
10-
# Unpack moved to typing
11-
from typing import Any, Union
12-
else:
13-
from typing import Union
14-
15-
from typing_extensions import Any
10+
# For Python versions that don't have typing.Unpack yet (pre-3.13),
11+
# we import from typing_extensions instead
12+
if sys.version_info < (3, 13):
13+
from typing_extensions import Unpack # noqa: F401
1614

1715

1816
logger = getLogger(__name__)
1917

18+
# Define type for shell commands, using Iterable for improved compatibility
2019
ShellCommand = Iterable[Union[str, Path]]
2120

2221

0 commit comments

Comments
 (0)