Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions kernel_patches_daemon/branch_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,9 @@ async def send_email(
email_send_fail_counter.add(1)


def _is_pr_flagged(pr: PullRequest) -> bool:
for label in pr.get_labels():
if MERGE_CONFLICT_LABEL == label.name:
def pr_has_label(pr: PullRequest, label: str) -> bool:
for l in pr.get_labels():
if l.name == label:
return True
return False

Expand Down Expand Up @@ -855,7 +855,7 @@ async def _comment_series_pr(

if pr:
if (not has_merge_conflict) or (
has_merge_conflict and not _is_pr_flagged(pr)
has_merge_conflict and not pr_has_label(pr, MERGE_CONFLICT_LABEL)
):
if message:
self._add_pull_request_comment(pr, message)
Expand Down Expand Up @@ -1109,7 +1109,7 @@ async def sync_checks(self, pr: PullRequest, series: Series) -> None:
pr.update()
# if it's merge conflict - report failure
ctx = BranchWorker.slugify_context(f"{CI_DESCRIPTION}-{self.repo_branch}")
if _is_pr_flagged(pr):
if pr_has_label(pr, MERGE_CONFLICT_LABEL):
await series.set_check(
status=Status.CONFLICT,
target_url=pr.html_url,
Expand Down
33 changes: 31 additions & 2 deletions kernel_patches_daemon/github_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from github import Auth
from github.PullRequest import PullRequest
from kernel_patches_daemon.branch_worker import (
MERGE_CONFLICT_LABEL,
BranchWorker,
parsed_pr_ref_ok,
pr_has_label,
same_series_different_target,
parse_pr_ref,
NewPRWithNoChangeException,
Expand Down Expand Up @@ -191,6 +193,30 @@ async def checkout_and_patch_safe(
)
return None

async def select_target_branches_for_subject(
self, subject: Subject, tag_mapped_branches: List[str]
) -> List[str]:
if len(tag_mapped_branches) == 1:
return tag_mapped_branches

# Check if a single relevant open PR without merge conflicts exists.
# If yes, then pick it without trying other target branches.
subject_pr_targets = []
for branch in tag_mapped_branches:
worker = self.workers[branch]
subj_branch = await worker.subject_to_branch(subject)
for pr in worker.prs.values():
if pr.head.ref == subj_branch and not pr_has_label(
pr, MERGE_CONFLICT_LABEL
):
subject_pr_targets.append(branch)

if len(subject_pr_targets) == 1:
return subject_pr_targets

# If no sticky target is determined, then return all branches
return tag_mapped_branches

async def sync_relevant_subject(self, subject: Subject) -> None:
"""
1. Get Subject's latest series
Expand All @@ -210,8 +236,11 @@ async def sync_relevant_subject(self, subject: Subject) -> None:
)
return

last_branch = mapped_branches[-1]
for branch in mapped_branches:
target_branches = await self.select_target_branches_for_subject(
subject, mapped_branches
)
last_branch = target_branches[-1]
for branch in target_branches:
worker = self.workers[branch]
# PR branch name == sid of the first known series
pr_branch_name = await worker.subject_to_branch(subject)
Expand Down
50 changes: 48 additions & 2 deletions tests/test_github_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@

from aioresponses import aioresponses

from kernel_patches_daemon.branch_worker import NewPRWithNoChangeException
from kernel_patches_daemon.branch_worker import (
MERGE_CONFLICT_LABEL,
NewPRWithNoChangeException,
)
from kernel_patches_daemon.config import KPDConfig, SERIES_TARGET_SEPARATOR
from kernel_patches_daemon.github_sync import GithubSync
from tests.common.patchwork_mock import init_pw_responses, load_test_data, PatchworkMock
from tests.common.patchwork_mock import (
init_pw_responses,
load_test_data,
PatchworkMock,
)

TEST_BRANCH = "test-branch"
TEST_BPF_NEXT_BRANCH = "test-bpf-next"
Expand Down Expand Up @@ -280,6 +287,45 @@ async def test_sync_relevant_subject_success_second_branch(self) -> None:
list(self._gh.workers.values()), pr_mock
)

def _setup_test_select_target_branches_for_subject(self):
series_prefix = "series/123123"
subject_mock = MagicMock()
subject_mock.subject = "Test subject"
subject_mock.branch = AsyncMock(return_value=series_prefix)

pr_mock = MagicMock()
pr_mock.head.ref = (
f"{series_prefix}{SERIES_TARGET_SEPARATOR}{TEST_BPF_NEXT_BRANCH}"
)

worker_mock = self._gh.workers[TEST_BPF_NEXT_BRANCH]
worker_mock.prs = {subject_mock.subject: pr_mock}

return subject_mock, pr_mock

async def test_select_target_branches_for_subject(self) -> None:
subject_mock, _ = self._setup_test_select_target_branches_for_subject()
mapped_branches = [TEST_BRANCH, TEST_BPF_NEXT_BRANCH]

selected_branches = await self._gh.select_target_branches_for_subject(
subject_mock, mapped_branches
)

self.assertEqual(selected_branches, [TEST_BPF_NEXT_BRANCH])

async def test_select_target_branches_for_subject_merge_conflict(self) -> None:
subject_mock, pr_mock = self._setup_test_select_target_branches_for_subject()
label = MagicMock()
label.name = MERGE_CONFLICT_LABEL
pr_mock.get_labels = MagicMock(return_value=[label])
mapped_branches = [TEST_BRANCH, TEST_BPF_NEXT_BRANCH]

selected_branches = await self._gh.select_target_branches_for_subject(
subject_mock, mapped_branches
)

self.assertEqual(selected_branches, mapped_branches)

@aioresponses()
async def test_sync_patches_pr_summary_success(self, m: aioresponses) -> None:
"""
Expand Down