diff --git a/.github/scripts/check_labels.py b/.github/scripts/check_labels.py new file mode 100644 index 00000000000..10be42c3fd5 --- /dev/null +++ b/.github/scripts/check_labels.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +"""Check whether a PR has required labels.""" + +import sys +from typing import Any + +from github_utils import gh_delete_comment, gh_post_pr_comment +from gitutils import get_git_remote_name, get_git_repo_dir, GitRepo +from label_utils import has_required_labels, is_label_err_comment, LABEL_ERR_MSG +from trymerge import GitHubPR + + +def delete_all_label_err_comments(pr: "GitHubPR") -> None: + for comment in pr.get_comments(): + if is_label_err_comment(comment): + gh_delete_comment(pr.org, pr.project, comment.database_id) + + +def add_label_err_comment(pr: "GitHubPR") -> None: + # Only make a comment if one doesn't exist already + if not any(is_label_err_comment(comment) for comment in pr.get_comments()): + gh_post_pr_comment(pr.org, pr.project, pr.pr_num, LABEL_ERR_MSG) + + +def parse_args() -> Any: + from argparse import ArgumentParser + + parser = ArgumentParser("Check PR labels") + parser.add_argument("pr_num", type=int) + # add a flag to return a non-zero exit code if the PR does not have the required labels + parser.add_argument( + "--exit-non-zero", + action="store_true", + help="Return a non-zero exit code if the PR does not have the required labels", + ) + + return parser.parse_args() + + +def main() -> None: + args = parse_args() + repo = GitRepo(get_git_repo_dir(), get_git_remote_name()) + org, project = repo.gh_owner_and_name() + pr = GitHubPR(org, project, args.pr_num) + + try: + if not has_required_labels(pr): + print(LABEL_ERR_MSG) + add_label_err_comment(pr) + if args.exit_non_zero: + sys.exit(1) + else: + delete_all_label_err_comments(pr) + except Exception as e: + if args.exit_non_zero: + sys.exit(1) + + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/.github/workflows/check-labels.yml b/.github/workflows/check-labels.yml new file mode 100644 index 00000000000..19c70c820a8 --- /dev/null +++ b/.github/workflows/check-labels.yml @@ -0,0 +1,54 @@ +name: Check Labels + +on: + # We need pull_request_target to be able to post comments on PRs from forks. + # Only allow pull_request_target when merging to main, not some historical branch. + # + # Make sure to don't introduce explicit checking out and installing/running + # untrusted user code into this workflow! + pull_request_target: + types: [opened, synchronize, reopened, labeled, unlabeled] + branches: [main] + + # To check labels on ghstack PRs. + # Note: as pull_request doesn't trigger on PRs targeting main, + # to test changes to the workflow itself one needs to create + # a PR that targets a gh/**/base branch. + pull_request: + types: [opened, synchronize, reopened, labeled, unlabeled] + branches: [gh/**/base] + + workflow_dispatch: + inputs: + pr_number: + description: 'PR number to check labels for' + required: true + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }} + cancel-in-progress: true + +jobs: + check-labels: + permissions: + contents: read + pull-requests: write + name: Check labels + if: github.repository_owner == 'pytorch' + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - uses: actions/setup-python@v4 + with: + python-version: '3.10' + # Not the direct dependencies but the script uses trymerge + - run: pip install pyyaml==6.0 rockset==1.0.3 + - name: Check labels + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUM: ${{ github.event.number || github.event.inputs.pr_number }} + run: | + set -ex + python3 .github/scripts/check_labels.py --exit-non-zero "${PR_NUM}"