diff --git a/kernel_patches_daemon/branch_worker.py b/kernel_patches_daemon/branch_worker.py index df6363b..d09cbd7 100644 --- a/kernel_patches_daemon/branch_worker.py +++ b/kernel_patches_daemon/branch_worker.py @@ -351,9 +351,9 @@ async def send_email( email_send_fail_counter.add(1) -def _is_pr_flagged(pr: PullRequest) -> bool: - for label in pr.get_labels(): - if MERGE_CONFLICT_LABEL == label.name: +def pr_has_label(pr: PullRequest, label: str) -> bool: + for l in pr.get_labels(): + if l.name == label: return True return False @@ -855,7 +855,7 @@ async def _comment_series_pr( if pr: if (not has_merge_conflict) or ( - has_merge_conflict and not _is_pr_flagged(pr) + has_merge_conflict and not pr_has_label(pr, MERGE_CONFLICT_LABEL) ): if message: self._add_pull_request_comment(pr, message) @@ -1109,7 +1109,7 @@ async def sync_checks(self, pr: PullRequest, series: Series) -> None: pr.update() # if it's merge conflict - report failure ctx = BranchWorker.slugify_context(f"{CI_DESCRIPTION}-{self.repo_branch}") - if _is_pr_flagged(pr): + if pr_has_label(pr, MERGE_CONFLICT_LABEL): await series.set_check( status=Status.CONFLICT, target_url=pr.html_url, diff --git a/kernel_patches_daemon/github_sync.py b/kernel_patches_daemon/github_sync.py index 890ae5f..d0a1704 100644 --- a/kernel_patches_daemon/github_sync.py +++ b/kernel_patches_daemon/github_sync.py @@ -14,8 +14,10 @@ from github import Auth from github.PullRequest import PullRequest from kernel_patches_daemon.branch_worker import ( + MERGE_CONFLICT_LABEL, BranchWorker, parsed_pr_ref_ok, + pr_has_label, same_series_different_target, parse_pr_ref, NewPRWithNoChangeException, @@ -191,6 +193,30 @@ async def checkout_and_patch_safe( ) return None + async def select_target_branches_for_subject( + self, subject: Subject, tag_mapped_branches: List[str] + ) -> List[str]: + if len(tag_mapped_branches) == 1: + return tag_mapped_branches + + # Check if a single relevant open PR without merge conflicts exists. + # If yes, then pick it without trying other target branches. + subject_pr_targets = [] + for branch in tag_mapped_branches: + worker = self.workers[branch] + subj_branch = await worker.subject_to_branch(subject) + for pr in worker.prs.values(): + if pr.head.ref == subj_branch and not pr_has_label( + pr, MERGE_CONFLICT_LABEL + ): + subject_pr_targets.append(branch) + + if len(subject_pr_targets) == 1: + return subject_pr_targets + + # If no sticky target is determined, then return all branches + return tag_mapped_branches + async def sync_relevant_subject(self, subject: Subject) -> None: """ 1. Get Subject's latest series @@ -210,8 +236,11 @@ async def sync_relevant_subject(self, subject: Subject) -> None: ) return - last_branch = mapped_branches[-1] - for branch in mapped_branches: + target_branches = await self.select_target_branches_for_subject( + subject, mapped_branches + ) + last_branch = target_branches[-1] + for branch in target_branches: worker = self.workers[branch] # PR branch name == sid of the first known series pr_branch_name = await worker.subject_to_branch(subject) diff --git a/tests/test_github_sync.py b/tests/test_github_sync.py index e56dce6..7420064 100644 --- a/tests/test_github_sync.py +++ b/tests/test_github_sync.py @@ -14,10 +14,17 @@ from aioresponses import aioresponses -from kernel_patches_daemon.branch_worker import NewPRWithNoChangeException +from kernel_patches_daemon.branch_worker import ( + MERGE_CONFLICT_LABEL, + NewPRWithNoChangeException, +) from kernel_patches_daemon.config import KPDConfig, SERIES_TARGET_SEPARATOR from kernel_patches_daemon.github_sync import GithubSync -from tests.common.patchwork_mock import init_pw_responses, load_test_data, PatchworkMock +from tests.common.patchwork_mock import ( + init_pw_responses, + load_test_data, + PatchworkMock, +) TEST_BRANCH = "test-branch" TEST_BPF_NEXT_BRANCH = "test-bpf-next" @@ -280,6 +287,45 @@ async def test_sync_relevant_subject_success_second_branch(self) -> None: list(self._gh.workers.values()), pr_mock ) + def _setup_test_select_target_branches_for_subject(self): + series_prefix = "series/123123" + subject_mock = MagicMock() + subject_mock.subject = "Test subject" + subject_mock.branch = AsyncMock(return_value=series_prefix) + + pr_mock = MagicMock() + pr_mock.head.ref = ( + f"{series_prefix}{SERIES_TARGET_SEPARATOR}{TEST_BPF_NEXT_BRANCH}" + ) + + worker_mock = self._gh.workers[TEST_BPF_NEXT_BRANCH] + worker_mock.prs = {subject_mock.subject: pr_mock} + + return subject_mock, pr_mock + + async def test_select_target_branches_for_subject(self) -> None: + subject_mock, _ = self._setup_test_select_target_branches_for_subject() + mapped_branches = [TEST_BRANCH, TEST_BPF_NEXT_BRANCH] + + selected_branches = await self._gh.select_target_branches_for_subject( + subject_mock, mapped_branches + ) + + self.assertEqual(selected_branches, [TEST_BPF_NEXT_BRANCH]) + + async def test_select_target_branches_for_subject_merge_conflict(self) -> None: + subject_mock, pr_mock = self._setup_test_select_target_branches_for_subject() + label = MagicMock() + label.name = MERGE_CONFLICT_LABEL + pr_mock.get_labels = MagicMock(return_value=[label]) + mapped_branches = [TEST_BRANCH, TEST_BPF_NEXT_BRANCH] + + selected_branches = await self._gh.select_target_branches_for_subject( + subject_mock, mapped_branches + ) + + self.assertEqual(selected_branches, mapped_branches) + @aioresponses() async def test_sync_patches_pr_summary_success(self, m: aioresponses) -> None: """