5555import logging
5656import os
5757import re
58+ import subprocess
5859import sys
5960from dataclasses import dataclass
6061from functools import cache
@@ -399,19 +400,40 @@ def last(ref: str, sep: str = "/") -> str:
399400 return ref .rsplit (sep , 1 )[- 1 ]
400401
401402
402- # TODO: Move to 'modular.utils.git'
403+ class GitAncestryError (RuntimeError ):
404+ """Error raised when git ancestry check fails."""
405+
406+
403407def is_ancestor (commit1 : str , commit2 : str , * , verbose : bool ) -> bool :
404408 """
405409 Returns true if 'commit1' is an ancestor of 'commit2'.
410+
411+ Raises:
412+ GitAncestryError: If git command fails for reasons other than ancestry check.
406413 """
407- # TODO: We need to check returncode of this command more carefully, as the
408- # command simply might fail (rc != 0 and rc != 1).
409- p = run_shell_command (
410- ["git" , "merge-base" , "--is-ancestor" , commit1 , commit2 ],
411- check = False ,
412- quiet = not verbose ,
413- )
414- return p .returncode == 0
414+ error_code = None
415+
416+ try :
417+ p = run_shell_command (
418+ ["git" , "merge-base" , "--is-ancestor" , commit1 , commit2 ],
419+ check = False ,
420+ quiet = not verbose ,
421+ )
422+ if p .returncode == 0 :
423+ return True
424+ if p .returncode == 1 :
425+ return False
426+
427+ # Store error code for later
428+ error_code = p .returncode
429+ except subprocess .SubprocessError as e :
430+ raise GitAncestryError (f"Failed to determine ancestry relationship: { e } " ) from e
431+
432+ # Handle error code outside the try block
433+ if error_code is None :
434+ # This should never happen, but just in case
435+ raise GitAncestryError ("Unexpected error in git ancestry check" )
436+ raise GitAncestryError (f"Git ancestry check failed with code { error_code } " )
415437
416438
417439def is_repo_clean () -> bool :
@@ -456,59 +478,74 @@ def set_base_branches(st: list[StackEntry], target: str) -> None:
456478
457479def verify (st : list [StackEntry ], * , check_base : bool = False ) -> None :
458480 log (h ("Verifying stack info" ))
459- for index , e in enumerate (st ):
460- if e .has_missing_info ():
461- error (ERROR_STACKINFO_MISSING .format (** locals () ))
481+ for index , entry in enumerate (st ):
482+ if entry .has_missing_info ():
483+ error (ERROR_STACKINFO_MISSING .format (e = entry ))
462484 raise RuntimeError
463485
464- if len (e .pr .split ("/" )) == 0 or not last (e .pr ).isnumeric ():
465- error (ERROR_STACKINFO_BAD_LINK .format (** locals () ))
486+ if len (entry .pr .split ("/" )) == 0 or not last (entry .pr ).isnumeric ():
487+ error (ERROR_STACKINFO_BAD_LINK .format (e = entry ))
466488 raise RuntimeError
467489
468- ghinfo = get_command_output (
469- [
470- "gh" ,
471- "pr" ,
472- "view" ,
473- e .pr ,
474- "--json" ,
475- "baseRefName,headRefName,number,state,body,title,url,mergeStateStatus" ,
476- ]
477- )
478- d = json .loads (ghinfo )
479- for required_field in ["state" , "number" , "baseRefName" , "headRefName" ]:
480- if required_field not in d :
481- error (ERROR_STACKINFO_MALFORMED_RESPONSE .format (** locals ()))
490+ try :
491+ ghinfo = get_command_output (
492+ [
493+ "gh" ,
494+ "pr" ,
495+ "view" ,
496+ entry .pr ,
497+ "--json" ,
498+ "baseRefName,headRefName,number,state,body,title,url,mergeStateStatus" ,
499+ ]
500+ )
501+
502+ try :
503+ d = json .loads (ghinfo )
504+ except json .JSONDecodeError as e :
505+ error (f"Failed to parse JSON response from GitHub: { ghinfo } " )
506+ raise RuntimeError ("Invalid JSON response from GitHub" ) from e
507+
508+ for required_field in ["state" , "number" , "baseRefName" , "headRefName" ]:
509+ if required_field not in d :
510+ error (
511+ ERROR_STACKINFO_MALFORMED_RESPONSE .format (
512+ e = entry , required_field = required_field , d = d
513+ )
514+ )
515+ raise RuntimeError
516+
517+ if d ["state" ] != "OPEN" :
518+ error (ERROR_STACKINFO_PR_NOT_OPEN .format (e = entry , d = d ))
482519 raise RuntimeError
483520
484- if d [ "state" ] != "OPEN" :
485- error (ERROR_STACKINFO_PR_NOT_OPEN .format (** locals () ))
486- raise RuntimeError
521+ if int ( last ( entry . pr )) != d [ "number" ] :
522+ error (ERROR_STACKINFO_PR_NUMBER_MISMATCH .format (e = entry , d = d ))
523+ raise RuntimeError
487524
488- if int ( last ( e . pr )) != d ["number " ]:
489- error (ERROR_STACKINFO_PR_NUMBER_MISMATCH .format (** locals () ))
490- raise RuntimeError
525+ if entry . head != d ["headRefName " ]:
526+ error (ERROR_STACKINFO_PR_HEAD_MISMATCH .format (e = entry , d = d ))
527+ raise RuntimeError
491528
492- if e .head != d ["headRefName" ]:
493- error (ERROR_STACKINFO_PR_HEAD_MISMATCH .format (** locals ()))
494- raise RuntimeError
529+ # 'Base' branch might diverge when the stack is modified (e.g. when a
530+ # new commit is added to the middle of the stack). It is not an issue
531+ # if we're updating the stack (i.e. in 'submit'), but it is an issue if
532+ # we are trying to land it.
533+ if check_base and entry .base != d ["baseRefName" ]:
534+ error (ERROR_STACKINFO_PR_BASE_MISMATCH .format (e = entry , d = d ))
535+ raise RuntimeError
495536
496- # 'Base' branch might diverge when the stack is modified (e.g. when a
497- # new commit is added to the middle of the stack). It is not an issue
498- # if we're updating the stack (i.e. in 'submit'), but it is an issue if
499- # we are trying to land it.
500- if check_base and e .base != d ["baseRefName" ]:
501- error (ERROR_STACKINFO_PR_BASE_MISMATCH .format (** locals ()))
502- raise RuntimeError
537+ # The first entry on the stack needs to be actually mergeable on GitHub.
538+ if (
539+ check_base
540+ and index == 0
541+ and d ["mergeStateStatus" ] not in ["CLEAN" , "UNKNOWN" , "UNSTABLE" ]
542+ ):
543+ error (ERROR_STACKINFO_PR_NOT_MERGEABLE .format (e = entry , d = d ))
544+ raise RuntimeError
503545
504- # The first entry on the stack needs to be actually mergeable on GitHub.
505- if (
506- check_base
507- and index == 0
508- and d ["mergeStateStatus" ] not in ["CLEAN" , "UNKNOWN" , "UNSTABLE" ]
509- ):
510- error (ERROR_STACKINFO_PR_NOT_MERGEABLE .format (** locals ()))
511- raise RuntimeError
546+ except subprocess .CalledProcessError as exc :
547+ error (f"Failed to get PR information from GitHub: { exc } " )
548+ raise RuntimeError ("GitHub API request failed" ) from exc
512549
513550
514551def print_stack (st : list [StackEntry ], * , links : bool , level : int = 1 ) -> None :
@@ -603,10 +640,39 @@ def get_taken_branch_ids(refs: list[str], branch_name_template: str) -> list[int
603640
604641
605642def generate_available_branch_name (refs : list [str ], branch_name_template : str ) -> str :
643+ """Generate an available branch name that doesn't conflict with existing branches.
644+
645+ This function handles potential race conditions by using an ID higher than
646+ the current maximum.
647+
648+ Args:
649+ refs: List of existing branch references
650+ branch_name_template: Template for the branch name
651+
652+ Returns:
653+ A branch name that doesn't conflict with existing branches
654+ """
655+ max_attempts = 100
606656 branch_ids = get_taken_branch_ids (refs , branch_name_template )
607657 max_ref_num = max (branch_ids ) if branch_ids else 0
608658 new_branch_id = max_ref_num + 1
609- return generate_branch_name (branch_name_template , new_branch_id )
659+
660+ # Safety check: verify the new branch name doesn't already exist
661+ new_branch_name = generate_branch_name (branch_name_template , new_branch_id )
662+ attempts = 0
663+ while any (
664+ ref .endswith (f"/{ new_branch_name } " ) or ref == new_branch_name for ref in refs
665+ ):
666+ # Increment and try again if there's a conflict
667+ new_branch_id += 1
668+ new_branch_name = generate_branch_name (branch_name_template , new_branch_id )
669+ attempts += 1
670+ if attempts > max_attempts : # Prevent infinite loops
671+ raise RuntimeError (
672+ "Unable to generate a unique branch name after 100 attempts"
673+ )
674+
675+ return new_branch_name
610676
611677
612678def get_available_branch_name (remote : str , branch_name_template : str ) -> str :
@@ -955,7 +1021,7 @@ def command_submit(
9551021 return
9561022
9571023 if (draft_bitmask is not None ) and (len (draft_bitmask ) != len (st )):
958- log ( h ( "Draft bitmask passed to 'submit' doesn't match number of PRs!" ) )
1024+ error ( "Draft bitmask passed to 'submit' doesn't match number of PRs!" )
9591025 return
9601026
9611027 # Create local branches and initialize base and head fields in the stack
@@ -1121,6 +1187,19 @@ def delete_remote_branches(
11211187 cmd .extend ([f":{ branch } " for branch in remote_branches_to_delete ])
11221188 run_shell_command (cmd , check = False , quiet = not verbose )
11231189
1190+ # Close associated PRs as mentioned in the docstring
1191+ for e in st :
1192+ if e .has_pr ():
1193+ try :
1194+ run_shell_command (
1195+ ["gh" , "pr" , "close" , e .pr , "--delete-branch=false" ],
1196+ check = False ,
1197+ quiet = not verbose ,
1198+ )
1199+ log (f"Closed PR { e .pr } " , level = 1 )
1200+ except Exception as exc : # noqa: BLE001
1201+ log (f"Failed to close PR { e .pr } : { exc } " , level = 1 )
1202+
11241203
11251204# ===----------------------------------------------------------------------=== #
11261205# Entry point for 'land' command
@@ -1467,7 +1546,7 @@ def load_config(config_file: str) -> configparser.ConfigParser:
14671546 return config
14681547
14691548
1470- def main () -> None : # noqa: PLR0912
1549+ def main () -> None : # noqa: PLR0912, PLR0915
14711550 config_file = os .getenv ("STACKPR_CONFIG" , ".stack-pr.cfg" )
14721551 config = load_config (config_file )
14731552
@@ -1490,9 +1569,17 @@ def main() -> None: # noqa: PLR0912
14901569
14911570 current_branch = get_current_branch_name ()
14921571 get_branch_name_base (common_args .branch_name_template )
1572+ stashed = False
14931573 try :
14941574 if args .command in ["submit" , "export" ] and args .stash :
1495- run_shell_command (["git" , "stash" , "save" ], quiet = not common_args .verbose )
1575+ # Check if there's anything to stash first
1576+ if not is_repo_clean ():
1577+ run_shell_command (
1578+ ["git" , "stash" , "save" ], quiet = not common_args .verbose
1579+ )
1580+ stashed = True
1581+ else :
1582+ log ("No changes to stash" , level = 1 )
14961583
14971584 if args .command != "view" and not is_repo_clean ():
14981585 error (ERROR_REPO_DIRTY )
@@ -1518,15 +1605,27 @@ def main() -> None: # noqa: PLR0912
15181605 return
15191606 except Exception as exc :
15201607 # If something failed, checkout the original branch
1521- run_shell_command (
1522- ["git" , "checkout" , current_branch ], quiet = not common_args .verbose
1523- )
1608+ try :
1609+ run_shell_command (
1610+ ["git" , "checkout" , current_branch ], quiet = not common_args .verbose
1611+ )
1612+ except Exception as checkout_error : # noqa: BLE001
1613+ error (f"Failed to checkout original branch: { checkout_error } " )
15241614 if isinstance (exc , SubprocessError ):
15251615 print_cmd_failure_details (exc )
15261616 raise
15271617 finally :
1528- if args .command in ["submit" , "export" ] and args .stash :
1529- run_shell_command (["git" , "stash" , "pop" ], quiet = not common_args .verbose )
1618+ # Only try to pop the stash if we actually stashed something
1619+ if stashed and args .command in ["submit" , "export" ]:
1620+ try :
1621+ run_shell_command (
1622+ ["git" , "stash" , "pop" ], quiet = not common_args .verbose
1623+ )
1624+ except Exception as stash_error : # noqa: BLE001
1625+ error (f"Failed to pop stashed changes: { stash_error } " )
1626+ error (
1627+ "Your changes are still in the stash. Run 'git stash pop' to retrieve them."
1628+ )
15301629
15311630
15321631if __name__ == "__main__" :
0 commit comments