Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions kernel_patches_daemon/branch_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,12 @@ def same_series_different_target(ref: str, other_ref: str) -> bool:
return ref1["series"] == ref2["series"] and ref1["target"] != ref2["target"]


def prs_for_the_same_series(pr1: PullRequest, pr2: PullRequest) -> bool:
return pr1.title == pr2.title or same_series_different_target(
pr1.head.ref, pr2.head.ref
)


def _reset_repo(repo, branch: str) -> None:
"""
Reset the repository into a known good state, with `branch` checked out.
Expand Down Expand Up @@ -767,18 +773,14 @@ async def _guess_pr(
- try to guess based on first series
"""

# try to find amond known relevant PRs
if series.subject in self.prs:
return self.prs[series.subject]

if not branch:
# resolve branch: series -> subject -> branch
subject = Subject(series.subject, self.patchwork)
branch = await self.subject_to_branch(subject)

try:
# we assuming only one PR can be active for one head->base
return self.all_prs[branch][self.repo_branch][0]
return self.all_prs[branch][self.repo_pr_base_branch][0]
except (KeyError, IndexError):
pass

Expand Down Expand Up @@ -1218,6 +1220,7 @@ async def evaluate_ci_result(
new_label = StatusLabelSuffixes.FAIL.to_label(series.version)
not_label = StatusLabelSuffixes.PASS.to_label(series.version)

pr.update() # make sure we are looking at the up to date labels
labels = {label.name for label in pr.labels}
# Always make sure to remove the unused label so that we eventually
# converge on having only one pass/fail label for each version, come
Expand Down
45 changes: 33 additions & 12 deletions kernel_patches_daemon/daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,49 @@ def __init__(
loop_delay: int = DEFAULT_LOOP_DELAY,
) -> None:
self.project: str = kpd_config.patchwork.project
self.github_sync_worker = GithubSync(
kpd_config=kpd_config, labels_cfg=labels_cfg
)
self.kpd_config = kpd_config
self.labels_cfg = labels_cfg
self.loop_delay: Final[int] = loop_delay
self.metrics_logger = metrics_logger
self.github_sync_worker: GithubSync = GithubSync(
kpd_config=self.kpd_config, labels_cfg=self.labels_cfg
)

def reset_github_sync(self) -> bool:
try:
self.github_sync_worker = GithubSync(
kpd_config=self.kpd_config, labels_cfg=self.labels_cfg
)
return True
except Exception:
logger.exception(
"Unhandled exception while creating GithubSync object",
exc_info=True,
)
return False

async def submit_metrics(self) -> None:
if self.metrics_logger:
logger.info("Submitting run metrics into metrics logger")
try:
self.metrics_logger(self.project, self.github_sync_worker.stats)
except Exception:
logger.exception(
"Failed to submit run metrics into metrics logger", exc_info=True
)
else:
if self.metrics_logger is None:
logger.warn(
"Not submitting run metrics because metrics logger is not configured"
)
return
try:
self.metrics_logger(self.project, self.github_sync_worker.stats)
logger.info("Submitted run metrics into metrics logger")
except Exception:
logger.exception(
"Failed to submit run metrics into metrics logger", exc_info=True
)

async def run(self) -> None:
while True:
ok = self.reset_github_sync()
if not ok:
logger.error(
"Most likely something went wrong connecting to GitHub or Patchwork. Skipping this iteration without submitting metrics."
)
continue
try:
await self.github_sync_worker.sync_patches()
self.github_sync_worker.increment_counter("runs_successful")
Expand Down
33 changes: 27 additions & 6 deletions kernel_patches_daemon/github_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
parse_pr_ref,
parsed_pr_ref_ok,
pr_has_label,
same_series_different_target,
prs_for_the_same_series,
)
from kernel_patches_daemon.config import BranchConfig, KPDConfig
from kernel_patches_daemon.github_logs import (
Expand Down Expand Up @@ -156,12 +156,19 @@ def close_existing_prs_for_series(
the same series name, but different remote branch and close them.
"""

prs_to_close = [
existing_prs = [
existing_pr
for worker in workers
for existing_pr in worker.prs.values()
if same_series_different_target(pr.head.ref, existing_pr.head.ref)
if existing_pr.number != pr.number
]

prs_to_close = [
existing_pr
for existing_pr in existing_prs
if prs_for_the_same_series(pr, existing_pr)
]

# Remove matching PRs from other workers
for pr_to_close in prs_to_close:
logging.info(
Expand Down Expand Up @@ -329,21 +336,35 @@ async def sync_patches(self) -> None:
parsed_ref = parse_pr_ref(pr.head.ref)
# ignore unknown format branch/PRs.
if not parsed_pr_ref_ok(parsed_ref):
logger.warning(
logger.info(
f"Unexpected format of the branch name: {pr.head.ref}"
)
continue

if parsed_ref["target"] != worker.repo_branch:
logger.info(
f"Skipping sync of PR {pr.number} ({pr.head.ref}) as it's not for {worker.repo_branch}"
)
continue

series_id = parsed_ref["series_id"]
series = await self.pw.get_series_by_id(series_id)
subject = self.pw.get_subject_by_series(series)
if subject_name != subject.subject:
logger.warning(
logger.info(
f"Renaming {pr} from {subject_name} to {subject.subject} according to {series.id}"
)
pr.edit(title=subject.subject)

latest_series = await subject.latest_series()
if latest_series is None:
logger.warning(
f"Closing {pr} associated with irrelevent or outdated series {series_id}"
)
pr.edit(state="close")
continue

branch_name = await worker.subject_to_branch(subject)
latest_series = await subject.latest_series() or series
pr = await self.checkout_and_patch_safe(
worker, branch_name, latest_series
)
Expand Down
55 changes: 41 additions & 14 deletions tests/test_branch_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@

# pyre-unsafe

import os
import random
import re
import shutil
import tempfile
import unittest
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, patch

import git
Expand All @@ -35,6 +34,7 @@
EmailBodyContext,
furnish_ci_email_body,
parse_pr_ref,
prs_for_the_same_series,
same_series_different_target,
temporary_patch_file,
UPSTREAM_REMOTE_NAME,
Expand Down Expand Up @@ -633,16 +633,6 @@ def test_delete_branches(self) -> None:
ggr.assert_called_once_with(f"heads/{branch_deleted}")
ggr.return_value.delete.assert_called_once()

@aioresponses()
async def test_guess_pr_return_from_active_pr_cache(self, m: aioresponses) -> None:
# Whatever is in our self.prs's cache dictionary will be returned.
series = Series(self._pw, SERIES_DATA)
sentinel = random.random()
# pyre-fixme[6]: For 2nd argument expected `PullRequest` but got `float`.
self._bw.prs["foo"] = sentinel
pr = await self._bw._guess_pr(series)
self.assertEqual(sentinel, pr)

async def test_guess_pr_return_from_secondary_cache_with_specified_branch(
self,
) -> None:
Expand All @@ -654,7 +644,7 @@ async def test_guess_pr_return_from_secondary_cache_with_specified_branch(
series = Series(self._pw, SERIES_DATA)
sentinel = random.random()
self._bw.all_prs[mybranch] = {}
self._bw.all_prs[mybranch][TEST_REPO_BRANCH] = [sentinel]
self._bw.all_prs[mybranch][TEST_REPO_PR_BASE_BRANCH] = [sentinel]
pr = await self._bw._guess_pr(series, mybranch)
self.assertEqual(sentinel, pr)

Expand All @@ -671,7 +661,7 @@ async def test_guess_pr_return_from_secondary_cache_without_specified_branch(

sentinel = random.random()
self._bw.all_prs[mybranch] = {}
self._bw.all_prs[mybranch][TEST_REPO_BRANCH] = [sentinel]
self._bw.all_prs[mybranch][TEST_REPO_PR_BASE_BRANCH] = [sentinel]
pr = await self._bw._guess_pr(series, mybranch)
self.assertEqual(sentinel, pr)

Expand Down Expand Up @@ -902,6 +892,43 @@ def test_same_series_different_target(self) -> None:
)
)

def test_prs_for_the_same_series(self) -> None:
def create_mock_pr(title: Optional[str], head_ref: str) -> MagicMock:
pr = MagicMock()
pr.title = title
pr.head.ref = head_ref
return pr

with self.subTest("same_title_different_series"):
# Title match should return True regardless of series
pr1 = create_mock_pr("Fix memory leak", "series/123=>main")
pr2 = create_mock_pr("Fix memory leak", "series/456=>bpf-next")
self.assertTrue(prs_for_the_same_series(pr1, pr2))

with self.subTest("different_title_same_series_different_target"):
# Same series, different target should return True regardless of title
pr1 = create_mock_pr("Fix memory leak", "series/123=>main")
pr2 = create_mock_pr("Different title", "series/123=>bpf-next")
self.assertTrue(prs_for_the_same_series(pr1, pr2))

with self.subTest("different_title_different_series"):
# No match
pr1 = create_mock_pr("Fix memory leak", "series/123=>main")
pr2 = create_mock_pr("Different title", "series/456=>main")
self.assertFalse(prs_for_the_same_series(pr1, pr2))

with self.subTest("none_vs_string_title"):
# None vs string title should not match
pr1 = create_mock_pr(None, "series/123=>main")
pr2 = create_mock_pr("Fix memory leak", "series/456=>bpf-next")
self.assertFalse(prs_for_the_same_series(pr1, pr2))

with self.subTest("malformed_refs_no_match"):
# Malformed refs with different titles should return False
pr1 = create_mock_pr("Fix memory leak", "invalid-ref-format")
pr2 = create_mock_pr("Different title", "also-invalid")
self.assertFalse(prs_for_the_same_series(pr1, pr2))


class TestGitSeriesAlreadyApplied(unittest.IsolatedAsyncioTestCase):
def setUp(self):
Expand Down
4 changes: 3 additions & 1 deletion tests/test_daemon.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import unittest
from typing import Any, Dict, List
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

from kernel_patches_daemon.config import KPDConfig
from kernel_patches_daemon.daemon import KernelPatchesWorker
Expand Down Expand Up @@ -68,6 +68,7 @@ def setUp(self) -> None:
)

self.worker.github_sync_worker.sync_patches = AsyncMock()
self.worker.reset_github_sync = MagicMock(return_value=True)

async def test_run_ok(self) -> None:
with (
Expand All @@ -81,6 +82,7 @@ async def test_run_ok(self) -> None:

gh_sync = self.worker.github_sync_worker
gh_sync.sync_patches.assert_called_once()
self.worker.reset_github_sync.assert_called_once()
self.assertEqual(len(LOGGED_METRICS), 1)
stats = LOGGED_METRICS[0][self.worker.project]
self.assertEqual(stats["runs_successful"], 1)
Expand Down