Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions mergify_cli/ci/scopes/base_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from __future__ import annotations

import json
import os
import pathlib
import typing

import click
import yaml


class MergeQueuePullRequest(typing.TypedDict):
number: int


class MergeQueueBatchFailed(typing.TypedDict):
draft_pr_number: int
checked_pull_request: list[int]


class MergeQueueMetadata(typing.TypedDict):
checking_base_sha: str
pull_requests: list[MergeQueuePullRequest]
previous_failed_batches: list[MergeQueueBatchFailed]


def _yaml_docs_from_fenced_blocks(body: str) -> MergeQueueMetadata | None:
lines = []
found = False
for line in body.splitlines():
if line.startswith("```yaml"):
found = True
elif found:
if line.startswith("```"):
break
lines.append(line)
if lines:
return typing.cast("MergeQueueMetadata", yaml.safe_load("\n".join(lines)))
return None


def _detect_base_from_merge_queue_payload(ev: dict[str, typing.Any]) -> str | None:
pr = ev.get("pull_request")
if not isinstance(pr, dict):
return None
title = pr.get("title") or ""
if not isinstance(title, str):
return None
if not title.startswith("merge-queue: "):
return None
body = pr.get("body") or ""
content = _yaml_docs_from_fenced_blocks(body)
if content:
return content["checking_base_sha"]
return None


def _detect_base_from_event(ev: dict[str, typing.Any]) -> str | None:
pr = ev.get("pull_request")
if isinstance(pr, dict):
sha = pr.get("base", {}).get("sha")
if isinstance(sha, str) and sha:
return sha
return None


def detect() -> str:
event_path = os.environ.get("GITHUB_EVENT_PATH")
event: dict[str, typing.Any] | None = None
if event_path and pathlib.Path(event_path).is_file():
try:
with pathlib.Path(event_path).open("r", encoding="utf-8") as f:
event = json.load(f)
except FileNotFoundError:
event = None

if event is not None:
# 0) merge-queue PR override
mq_sha = _detect_base_from_merge_queue_payload(event)
if mq_sha:
return mq_sha

# 1) standard event payload
event_sha = _detect_base_from_event(event)
if event_sha:
return event_sha

# 2) base ref (e.g., PR target branch)
base_ref = os.environ.get("GITHUB_BASE_REF")
if base_ref:
return base_ref

msg = (
"Could not detect base SHA. Ensure checkout has sufficient history "
"(e.g., actions/checkout with fetch-depth: 0) or provide GITHUB_EVENT_PATH / GITHUB_BASE_REF."
)
raise click.ClickException(
msg,
)
23 changes: 23 additions & 0 deletions mergify_cli/ci/scopes/changed_files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

import subprocess

import click


def _run(cmd: list[str]) -> str:
try:
return subprocess.check_output(cmd, text=True, encoding="utf-8").strip()
except subprocess.CalledProcessError as e:
msg = f"Command failed: {' '.join(cmd)}\n{e}"
raise click.ClickException(msg) from e


def git_changed_files(base: str) -> list[str]:
# Committed changes only between base_sha and HEAD.
# Includes: Added (A), Copied (C), Modified (M), Renamed (R), Type-changed (T), Deleted (D)
# Excludes: Unmerged (U), Unknown (X), Broken (B)
out = _run(
["git", "diff", "--name-only", "--diff-filter=ACMRTD", f"{base}...HEAD"],
)
return [line for line in out.splitlines() if line]
117 changes: 5 additions & 112 deletions mergify_cli/ci/scopes/cli.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from __future__ import annotations

import json
import os
import pathlib
import subprocess
import typing

import click
import pydantic
import yaml

from mergify_cli.ci.scopes import base_detector
from mergify_cli.ci.scopes import changed_files


if typing.TYPE_CHECKING:
from collections import abc
Expand Down Expand Up @@ -55,114 +56,6 @@ def from_yaml(cls, path: str) -> Config:
return cls.from_dict(data)


def _run(cmd: list[str]) -> str:
try:
return subprocess.check_output(cmd, text=True, encoding="utf-8").strip()
except subprocess.CalledProcessError as e:
msg = f"Command failed: {' '.join(cmd)}\n{e}"
raise click.ClickException(msg) from e


class MergeQueuePullRequest(typing.TypedDict):
number: int


class MergeQueueBatchFailed(typing.TypedDict):
draft_pr_number: int
checked_pull_request: list[int]


class MergeQueueMetadata(typing.TypedDict):
checking_base_sha: str
pull_requests: list[MergeQueuePullRequest]
previous_failed_batches: list[MergeQueueBatchFailed]


def _yaml_docs_from_fenced_blocks(body: str) -> MergeQueueMetadata | None:
lines = []
found = False
for line in body.splitlines():
if line.startswith("```yaml"):
found = True
elif found:
if line.startswith("```"):
break
lines.append(line)
if lines:
return typing.cast("MergeQueueMetadata", yaml.safe_load("\n".join(lines)))
return None


def _detect_base_from_merge_queue_payload(ev: dict[str, typing.Any]) -> str | None:
pr = ev.get("pull_request")
if not isinstance(pr, dict):
return None
title = pr.get("title") or ""
if not isinstance(title, str):
return None
if not title.startswith("merge-queue: "):
return None
body = pr.get("body") or ""
content = _yaml_docs_from_fenced_blocks(body)
if content:
return content["checking_base_sha"]
return None


def _detect_base_from_event(ev: dict[str, typing.Any]) -> str | None:
pr = ev.get("pull_request")
if isinstance(pr, dict):
sha = pr.get("base", {}).get("sha")
if isinstance(sha, str) and sha:
return sha
return None


def detect_base() -> str:
event_path = os.environ.get("GITHUB_EVENT_PATH")
event: dict[str, typing.Any] | None = None
if event_path and pathlib.Path(event_path).is_file():
try:
with pathlib.Path(event_path).open("r", encoding="utf-8") as f:
event = json.load(f)
except FileNotFoundError:
event = None

if event is not None:
# 0) merge-queue PR override
mq_sha = _detect_base_from_merge_queue_payload(event)
if mq_sha:
return mq_sha

# 1) standard event payload
event_sha = _detect_base_from_event(event)
if event_sha:
return event_sha

# 2) base ref (e.g., PR target branch)
base_ref = os.environ.get("GITHUB_BASE_REF")
if base_ref:
return base_ref

msg = (
"Could not detect base SHA. Ensure checkout has sufficient history "
"(e.g., actions/checkout with fetch-depth: 0) or provide GITHUB_EVENT_PATH / GITHUB_BASE_REF."
)
raise click.ClickException(
msg,
)


def git_changed_files(base: str) -> list[str]:
# Committed changes only between base_sha and HEAD.
# Includes: Added (A), Copied (C), Modified (M), Renamed (R), Type-changed (T), Deleted (D)
# Excludes: Unmerged (U), Unknown (X), Broken (B)
out = _run(
["git", "diff", "--name-only", "--diff-filter=ACMRTD", f"{base}...HEAD"],
)
return [line for line in out.splitlines() if line]


def match_scopes(
config: Config,
files: abc.Iterable[str],
Expand Down Expand Up @@ -210,8 +103,8 @@ def maybe_write_github_outputs(

def detect(config_path: str) -> None:
cfg = Config.from_yaml(config_path)
base = detect_base()
changed = git_changed_files(base)
base = base_detector.detect()
changed = changed_files.git_changed_files(base)
scopes_hit, per_scope = match_scopes(cfg, changed)

click.echo(f"Base: {base}")
Expand Down
108 changes: 108 additions & 0 deletions mergify_cli/tests/ci/scopes/test_base_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import json
import pathlib

import click
import pytest

from mergify_cli.ci.scopes import base_detector


def test_detect_base_github_base_ref(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setenv("GITHUB_BASE_REF", "main")
monkeypatch.delenv("GITHUB_EVENT_PATH", raising=False)

result = base_detector.detect()

assert result == "main"


def test_detect_base_from_event_path(
monkeypatch: pytest.MonkeyPatch,
tmp_path: pathlib.Path,
) -> None:
event_data = {
"pull_request": {
"base": {"sha": "abc123"},
},
}
event_file = tmp_path / "event.json"
event_file.write_text(json.dumps(event_data))

monkeypatch.setenv("GITHUB_EVENT_PATH", str(event_file))
monkeypatch.delenv("GITHUB_BASE_REF", raising=False)

result = base_detector.detect()

assert result == "abc123"


def test_detect_base_merge_queue_override(
monkeypatch: pytest.MonkeyPatch,
tmp_path: pathlib.Path,
) -> None:
event_data = {
"pull_request": {
"title": "merge-queue: Merge group",
"body": "```yaml\nchecking_base_sha: xyz789\n```",
"base": {"sha": "abc123"},
},
}
event_file = tmp_path / "event.json"
event_file.write_text(json.dumps(event_data))

monkeypatch.setenv("GITHUB_EVENT_PATH", str(event_file))

result = base_detector.detect()

assert result == "xyz789"


def test_detect_base_no_info(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("GITHUB_EVENT_PATH", raising=False)
monkeypatch.delenv("GITHUB_BASE_REF", raising=False)

with pytest.raises(click.ClickException, match="Could not detect base SHA"):
base_detector.detect()


def test_yaml_docs_from_fenced_blocks_valid() -> None:
body = """Some text
```yaml
---
checking_base_sha: xyz789
pull_requests: [{"number": 1}]
previous_failed_batches: []
...
```
More text"""

result = base_detector._yaml_docs_from_fenced_blocks(body)

assert result == base_detector.MergeQueueMetadata(
{
"checking_base_sha": "xyz789",
"pull_requests": [{"number": 1}],
"previous_failed_batches": [],
},
)


def test_yaml_docs_from_fenced_blocks_no_yaml() -> None:
body = "No yaml here"

result = base_detector._yaml_docs_from_fenced_blocks(body)

assert result is None


def test_yaml_docs_from_fenced_blocks_empty_yaml() -> None:
body = """Some text
```yaml
```
More text"""

result = base_detector._yaml_docs_from_fenced_blocks(body)

assert result is None
Loading