diff --git a/mergify_cli/ci/detector.py b/mergify_cli/ci/detector.py index cea90202..e27ca6d2 100644 --- a/mergify_cli/ci/detector.py +++ b/mergify_cli/ci/detector.py @@ -178,7 +178,7 @@ def get_github_pull_request_number() -> int | None: match get_ci_provider(): case "github_actions": try: - event = utils.get_github_event() + _, event = utils.get_github_event() except utils.GitHubEventNotFoundError: return None pr = event.get("pull_request") diff --git a/mergify_cli/ci/scopes/base_detector.py b/mergify_cli/ci/scopes/base_detector.py index ac3316bf..8ee4e14b 100644 --- a/mergify_cli/ci/scopes/base_detector.py +++ b/mergify_cli/ci/scopes/base_detector.py @@ -1,7 +1,6 @@ from __future__ import annotations import dataclasses -import os import typing import yaml @@ -69,32 +68,70 @@ def _detect_base_from_event(ev: dict[str, typing.Any]) -> str | None: return None +def _detect_default_branch_from_event(ev: dict[str, typing.Any]) -> str | None: + repo = ev.get("repository") + if isinstance(repo, dict): + sha = repo.get("default_branch") + if isinstance(sha, str) and sha: + return sha + return None + + +def _detect_base_from_push_event(ev: dict[str, typing.Any]) -> str | None: + sha = ev.get("before") + if isinstance(sha, str) and sha: + return sha + return None + + @dataclasses.dataclass class Base: ref: str is_merge_queue: bool +PULL_REQUEST_EVENTS = { + "pull_request", + "pull_request_review", + "pull_request_review_comment", + "pull_request_target", +} + + def detect() -> Base: try: - event = utils.get_github_event() + event_name, event = utils.get_github_event() except utils.GitHubEventNotFoundError: - pass + # fallback to last commit + return Base("HEAD^", is_merge_queue=False) else: - # 0) merge-queue PR override - mq_sha = _detect_base_from_merge_queue_payload(event) - if mq_sha: - return Base(mq_sha, is_merge_queue=True) - - # 1) standard event payload - event_sha = _detect_base_from_event(event) - if event_sha: - return Base(event_sha, is_merge_queue=False) - - # 2) base ref (e.g., PR target branch) - base_ref = os.environ.get("GITHUB_BASE_REF") - if base_ref: - return Base(base_ref, is_merge_queue=False) - - msg = "Could not detect base SHA. Provide GITHUB_EVENT_PATH / GITHUB_BASE_REF." + if event_name in PULL_REQUEST_EVENTS: + # 0) merge-queue PR override + mq_sha = _detect_base_from_merge_queue_payload(event) + if mq_sha: + return Base(mq_sha, is_merge_queue=True) + + # 1) standard event payload + event_sha = _detect_base_from_event(event) + if event_sha: + return Base(event_sha, is_merge_queue=False) + + # 2) standard event payload + event_sha = _detect_default_branch_from_event(event) + if event_sha: + return Base(event_sha, is_merge_queue=False) + + elif event_name == "push": + event_sha = _detect_base_from_push_event(event) + if event_sha: + return Base(event_sha, is_merge_queue=False) + + event_sha = _detect_default_branch_from_event(event) + if event_sha: + return Base(event_sha, is_merge_queue=False) + else: + msg = "Unhandled GITHUB_EVENT_NAME" + raise BaseNotFoundError(msg) + + msg = "Could not detect base SHA. Provide GITHUB_EVENT_NAME / GITHUB_EVENT_PATH." raise BaseNotFoundError(msg) diff --git a/mergify_cli/tests/ci/scopes/test_base_detector.py b/mergify_cli/tests/ci/scopes/test_base_detector.py index 4fa6a23a..ddb384fa 100644 --- a/mergify_cli/tests/ci/scopes/test_base_detector.py +++ b/mergify_cli/tests/ci/scopes/test_base_detector.py @@ -6,18 +6,41 @@ from mergify_cli.ci.scopes import base_detector -def test_detect_base_github_base_ref( +@pytest.mark.parametrize("event_name", ["pull_request", "pull_request_review", "push"]) +def test_detect_base_from_repository_default_branch( + event_name: str, monkeypatch: pytest.MonkeyPatch, + tmp_path: pathlib.Path, ) -> None: - monkeypatch.setenv("GITHUB_BASE_REF", "main") - monkeypatch.delenv("GITHUB_EVENT_PATH", raising=False) + event_data = {"repository": {"default_branch": "main"}} + event_file = tmp_path / "event.json" + event_file.write_text(json.dumps(event_data)) + + monkeypatch.setenv("GITHUB_EVENT_NAME", event_name) + monkeypatch.setenv("GITHUB_EVENT_PATH", str(event_file)) result = base_detector.detect() assert result == base_detector.Base("main", is_merge_queue=False) -def test_detect_base_from_event_path( +def test_detect_base_from_push_event( + monkeypatch: pytest.MonkeyPatch, + tmp_path: pathlib.Path, +) -> None: + event_data = {"before": "abc123"} + event_file = tmp_path / "event.json" + event_file.write_text(json.dumps(event_data)) + + monkeypatch.setenv("GITHUB_EVENT_NAME", "push") + monkeypatch.setenv("GITHUB_EVENT_PATH", str(event_file)) + + result = base_detector.detect() + + assert result == base_detector.Base("abc123", is_merge_queue=False) + + +def test_detect_base_from_pull_request_event_path( monkeypatch: pytest.MonkeyPatch, tmp_path: pathlib.Path, ) -> None: @@ -29,8 +52,8 @@ def test_detect_base_from_event_path( event_file = tmp_path / "event.json" event_file.write_text(json.dumps(event_data)) + monkeypatch.setenv("GITHUB_EVENT_NAME", "pull_request") monkeypatch.setenv("GITHUB_EVENT_PATH", str(event_file)) - monkeypatch.delenv("GITHUB_BASE_REF", raising=False) result = base_detector.detect() @@ -51,6 +74,7 @@ def test_detect_base_merge_queue_override( event_file = tmp_path / "event.json" event_file.write_text(json.dumps(event_data)) + monkeypatch.setenv("GITHUB_EVENT_NAME", "pull_request") monkeypatch.setenv("GITHUB_EVENT_PATH", str(event_file)) result = base_detector.detect() @@ -58,9 +82,16 @@ def test_detect_base_merge_queue_override( assert result == base_detector.Base("xyz789", is_merge_queue=True) -def test_detect_base_no_info(monkeypatch: pytest.MonkeyPatch) -> None: - monkeypatch.delenv("GITHUB_EVENT_PATH", raising=False) - monkeypatch.delenv("GITHUB_BASE_REF", raising=False) +def test_detect_base_no_info( + tmp_path: pathlib.Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + event_data: dict[str, str] = {} + event_file = tmp_path / "event.json" + event_file.write_text(json.dumps(event_data)) + + monkeypatch.setenv("GITHUB_EVENT_NAME", "pull_request") + monkeypatch.setenv("GITHUB_EVENT_PATH", str(event_file)) with pytest.raises( base_detector.BaseNotFoundError, diff --git a/mergify_cli/tests/test_utils.py b/mergify_cli/tests/test_utils.py index 3f33a2b5..a075e6e0 100644 --- a/mergify_cli/tests/test_utils.py +++ b/mergify_cli/tests/test_utils.py @@ -117,9 +117,11 @@ def test_get_github_event_success( event_file = tmp_path / "event.json" event_file.write_text(json.dumps(event_data)) + monkeypatch.setenv("GITHUB_EVENT_NAME", "pull_request") monkeypatch.setenv("GITHUB_EVENT_PATH", str(event_file)) - result = utils.get_github_event() - assert result == event_data + name, event = utils.get_github_event() + assert name == "pull_request" + assert event == event_data def test_get_github_event_not_found(monkeypatch: pytest.MonkeyPatch) -> None: diff --git a/mergify_cli/utils.py b/mergify_cli/utils.py index 55a207f5..726de3fe 100644 --- a/mergify_cli/utils.py +++ b/mergify_cli/utils.py @@ -321,12 +321,15 @@ class GitHubEventNotFoundError(Exception): pass -def get_github_event() -> typing.Any: # noqa: ANN401 +def get_github_event() -> tuple[str, typing.Any]: + event_name = os.environ.get("GITHUB_EVENT_NAME") + if not event_name: + raise GitHubEventNotFoundError event_path = os.environ.get("GITHUB_EVENT_PATH") if event_path and pathlib.Path(event_path).is_file(): try: with pathlib.Path(event_path).open("r", encoding="utf-8") as f: - return json.load(f) + return event_name, json.load(f) except FileNotFoundError: pass raise GitHubEventNotFoundError