| 
 | 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates.  | 
 | 2 | +# All rights reserved.  | 
 | 3 | +#  | 
 | 4 | +# This source code is licensed under the BSD-style license found in the  | 
 | 5 | +# LICENSE file in the root directory of this source tree.  | 
 | 6 | + | 
 | 7 | +import argparse  | 
 | 8 | +import os  | 
 | 9 | +import re  | 
 | 10 | + | 
 | 11 | +from typing import List  | 
 | 12 | + | 
 | 13 | +# Provided by the PyGithub pip package.  | 
 | 14 | +from github import Auth, Github  | 
 | 15 | +from github.Repository import Repository  | 
 | 16 | + | 
 | 17 | + | 
 | 18 | +def parse_args():  | 
 | 19 | +    parser = argparse.ArgumentParser(  | 
 | 20 | +        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter  | 
 | 21 | +    )  | 
 | 22 | +    parser.add_argument(  | 
 | 23 | +        "--repo",  | 
 | 24 | +        type=str,  | 
 | 25 | +        help='The github repo to modify: e.g. "pytorch/executorch".',  | 
 | 26 | +        required=True,  | 
 | 27 | +    )  | 
 | 28 | +    parser.add_argument(  | 
 | 29 | +        "--pr",  | 
 | 30 | +        type=int,  | 
 | 31 | +        help="Number of the PR in the stack to check and create corresponding PR",  | 
 | 32 | +        required=True,  | 
 | 33 | +    )  | 
 | 34 | +    return parser.parse_args()  | 
 | 35 | + | 
 | 36 | + | 
 | 37 | +def extract_stack_from_body(pr_body: str) -> List[int]:  | 
 | 38 | +    """Extracts a list of PR numbers from a ghexport-generated PR body.  | 
 | 39 | +
  | 
 | 40 | +    The base of the stack is in index 0.  | 
 | 41 | +    """  | 
 | 42 | + | 
 | 43 | +    # Expected format. The `__->__` could appear on any line. Stop parsing  | 
 | 44 | +    # after the blank line. This would return [1, 2, 3].  | 
 | 45 | +    """  | 
 | 46 | +    Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):  | 
 | 47 | +    * #3  | 
 | 48 | +    * __->__ #2  | 
 | 49 | +    * #1  | 
 | 50 | +
  | 
 | 51 | +    <PR description details>  | 
 | 52 | +    """  | 
 | 53 | + | 
 | 54 | +    prs = []  | 
 | 55 | +    ghstack_begin = (  | 
 | 56 | +        "Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):"  | 
 | 57 | +    )  | 
 | 58 | +    ghstack_begin_seen = False  | 
 | 59 | +    for line in pr_body.splitlines():  | 
 | 60 | +        if ghstack_begin in line:  | 
 | 61 | +            ghstack_begin_seen = True  | 
 | 62 | +        if not ghstack_begin_seen:  | 
 | 63 | +            continue  | 
 | 64 | +        match = re.match(r"\*(?:.*?)? #(\d+)", line)  | 
 | 65 | +        if match:  | 
 | 66 | +            # It's a bullet followed by an integer.  | 
 | 67 | +            prs.append(int(match.group(1)))  | 
 | 68 | +    return list(reversed(prs))  | 
 | 69 | + | 
 | 70 | + | 
 | 71 | +def get_pr_stack_from_number(pr_number: int, repo: Repository) -> List[int]:  | 
 | 72 | +    pr_stack = extract_stack_from_body(repo.get_pull(pr_number).body)  | 
 | 73 | + | 
 | 74 | +    if not pr_stack:  | 
 | 75 | +        raise Exception(  | 
 | 76 | +            f"Could not find PR stack in body of #{pr_number}. "  | 
 | 77 | +            + "Please make sure that the PR was created with ghstack."  | 
 | 78 | +        )  | 
 | 79 | + | 
 | 80 | +    return pr_stack  | 
 | 81 | + | 
 | 82 | + | 
 | 83 | +def create_prs_for_orig_branch(pr_stack: List[int], repo: Repository):  | 
 | 84 | +    # For the first PR, we want to merge to `main` branch, and we will update  | 
 | 85 | +    # as we go through the stack  | 
 | 86 | +    orig_branch_merge_base = "main"  | 
 | 87 | +    for i in range(len(pr_stack)):  | 
 | 88 | +        pr = repo.get_pull(pr_stack[i])  | 
 | 89 | +        if not pr.is_merged():  | 
 | 90 | +            print("The PR (and stack above) is not merged yet, skipping")  | 
 | 91 | +            return  | 
 | 92 | +        # Check for invariant: For the current PR, it must be gh/user/x/base <- gh/user/x/head  | 
 | 93 | +        assert pr.base.ref.replace("base", "head") == pr.head.ref  | 
 | 94 | +        # The PR we want to create is then "branch_to_merge" <- gh/user/x/orig  | 
 | 95 | +        # gh/user/x/orig is the clean diff between gh/user/x/base <- gh/user/x/head  | 
 | 96 | +        orig_branch_merge_head = pr.base.ref.replace("base", "orig")  | 
 | 97 | +        bot_metadata = f"""This PR was created by the merge bot to help merge the original PR into the main branch.  | 
 | 98 | +ghstack PR number: https://github.com/pytorch/executorch/pull/{pr.number}  | 
 | 99 | +^ Please use this as the source of truth for the PR details, comments, and reviews  | 
 | 100 | +ghstack PR base: https://github.com/pytorch/executorch/tree/{pr.base.ref}  | 
 | 101 | +ghstack PR head: https://github.com/pytorch/executorch/tree/{pr.head.ref}  | 
 | 102 | +Merge bot PR base: https://github.com/pytorch/executorch/tree/{orig_branch_merge_base}  | 
 | 103 | +Merge bot PR head: https://github.com/pytorch/executorch/tree/{orig_branch_merge_head}"""  | 
 | 104 | + | 
 | 105 | +        existing_orig_pr = repo.get_pulls(  | 
 | 106 | +            head="pytorch:" + orig_branch_merge_head,  | 
 | 107 | +            base=orig_branch_merge_base,  | 
 | 108 | +            state="open",  | 
 | 109 | +        )  | 
 | 110 | +        if existing_orig_pr.totalCount > 0:  | 
 | 111 | +            print(  | 
 | 112 | +                f"PR for {orig_branch_merge_head} already exists {existing_orig_pr[0]}"  | 
 | 113 | +            )  | 
 | 114 | +            # We don't need to create/edit because the head PR is merged and orig is finalized.  | 
 | 115 | +        else:  | 
 | 116 | +            repo.create_pull(  | 
 | 117 | +                base=orig_branch_merge_base,  | 
 | 118 | +                head=orig_branch_merge_head,  | 
 | 119 | +                title=pr.title,  | 
 | 120 | +                body=bot_metadata,  | 
 | 121 | +            )  | 
 | 122 | +        # Advance the base for the next PR  | 
 | 123 | +        orig_branch_merge_base = orig_branch_merge_head  | 
 | 124 | + | 
 | 125 | + | 
 | 126 | +def main():  | 
 | 127 | +    args = parse_args()  | 
 | 128 | + | 
 | 129 | +    with Github(auth=Auth.Token(os.environ["GITHUB_TOKEN"])) as gh:  | 
 | 130 | +        repo = gh.get_repo(args.repo)  | 
 | 131 | +        create_prs_for_orig_branch(get_pr_stack_from_number(args.pr, repo), repo)  | 
 | 132 | + | 
 | 133 | + | 
 | 134 | +if __name__ == "__main__":  | 
 | 135 | +    main()  | 
0 commit comments