53 changes: 37 additions & 16 deletions kernel_patches_daemon/branch_worker.py
@@ -37,7 +37,11 @@
from github.Repository import Repository
from github.WorkflowJob import WorkflowJob

from kernel_patches_daemon.config import EmailConfig
from kernel_patches_daemon.config import (
SERIES_ID_SEPARATOR,
SERIES_TARGET_SEPARATOR,
EmailConfig,
)
from kernel_patches_daemon.github_connector import GithubConnector
from kernel_patches_daemon.github_logs import GithubFailedJobLog, GithubLogExtractor
from kernel_patches_daemon.patchwork import Patchwork, Series, Subject
@@ -79,7 +83,7 @@
ALREADY_MERGED_LOOKBACK = 100
BRANCH_TTL = 172800  # 2 days
PULL_REQUEST_TTL = timedelta(days=7)
HEAD_BASE_SEPARATOR = "=>"

KNOWN_OK_COMMENT_EXCEPTIONS = {
"Commenting is disabled on issues with more than 2500 comments"
}
@@ -398,21 +402,36 @@ def create_color_labels(labels_cfg: Dict[str, str], repo: Repository) -> None:
repo.create_label(name=label, color=color)


def get_base_branch_from_ref(ref: str) -> str:
    return ref.split(HEAD_BASE_SEPARATOR)[0]


def has_same_base_different_remote(ref: str, other_ref: str) -> bool:
    if ref == other_ref:
        return False

    base = get_base_branch_from_ref(ref)
    other_base = get_base_branch_from_ref(other_ref)

    if base != other_base:
        return False

    return True


def parse_pr_ref(ref: str) -> Dict[str, Any]:
    # "series/123456=>target-branch" ->
    # {
    #     "series": "series/123456",
    #     "series_id": 123456,
    #     "target": "target-branch",
    # }
    res: Dict[str, Any] = {}
    tmp = ref.split(SERIES_TARGET_SEPARATOR, maxsplit=1)
    res["series"] = tmp[0]
    if len(tmp) >= 2:
        res["target"] = tmp[1]

    tmp = res["series"].split(SERIES_ID_SEPARATOR, maxsplit=1)
    if len(tmp) >= 2:
        res["series_id"] = int(tmp[1])

    return res


def parsed_pr_ref_ok(parsed_ref: Dict[str, Any]) -> bool:
    return "target" in parsed_ref and "series_id" in parsed_ref


def same_series_different_target(ref: str, other_ref: str) -> bool:
    if ref == other_ref:
        return False
    ref1 = parse_pr_ref(ref)
    ref2 = parse_pr_ref(other_ref)
    return ref1["series"] == ref2["series"] and ref1["target"] != ref2["target"]
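# A quick sketch of the new helpers' behavior; the separator values ("=>" and
# "/") come from config.py, while the series ids and branch names are made up:
#
#   parse_pr_ref("series/123456=>bpf-next")
#   -> {"series": "series/123456", "series_id": 123456, "target": "bpf-next"}
#   parse_pr_ref("some-odd-branch")
#   -> {"series": "some-odd-branch"}  (no "target", no "series_id")
#   parsed_pr_ref_ok(parse_pr_ref("some-odd-branch"))
#   -> False
#   same_series_different_target("series/1=>bpf", "series/1=>bpf-next")
#   -> True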


def _reset_repo(repo, branch: str) -> None:
@@ -1081,7 +1100,8 @@ def filter_closed_pr(self, head: str) -> Optional[PullRequest]:
return res

async def subject_to_branch(self, subject: Subject) -> str:
return f"{await subject.branch}{HEAD_BASE_SEPARATOR}{self.repo_branch}"
subj_branch = await subject.branch()
return f"{subj_branch}{SERIES_TARGET_SEPARATOR}{self.repo_branch}"

async def sync_checks(self, pr: PullRequest, series: Series) -> None:
# Make sure that we are working with up-to-date data (as opposed to
@@ -1224,9 +1244,10 @@ def expire_branches(self) -> None:
            # that do not belong to any known open PRs
continue

if HEAD_BASE_SEPARATOR in branch:
split = branch.split(HEAD_BASE_SEPARATOR)
if len(split) > 1 and split[1] == self.repo_branch:
parsed_ref = parse_pr_ref(branch)

if parsed_pr_ref_ok(parsed_ref):
if parsed_ref["target"] == self.repo_branch:
                    # which have our repo_branch as target
                    # and don't have any closed PRs
                    # whose last update is within the defined TTL
3 changes: 3 additions & 0 deletions kernel_patches_daemon/config.py
@@ -17,6 +17,9 @@

logger = logging.getLogger(__name__)

SERIES_TARGET_SEPARATOR = "=>"
SERIES_ID_SEPARATOR = "/"
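# Together these encode a series and its target branch in a single git ref,
# e.g. (sketch, with a made-up series id and branch):
#   f"series{SERIES_ID_SEPARATOR}{series_id}{SERIES_TARGET_SEPARATOR}{target}"
#   -> "series/123456=>bpf-next"
# BranchWorker.subject_to_branch() builds such refs; parse_pr_ref() in
# branch_worker.py splits them back apart.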


class UnsupportedConfigVersion(ValueError):
def __init__(self, version: int) -> None:
133 changes: 73 additions & 60 deletions kernel_patches_daemon/github_sync.py
@@ -15,12 +15,16 @@
from github.PullRequest import PullRequest
from kernel_patches_daemon.branch_worker import (
BranchWorker,
get_base_branch_from_ref,
has_same_base_different_remote,
HEAD_BASE_SEPARATOR,
parsed_pr_ref_ok,
same_series_different_target,
parse_pr_ref,
NewPRWithNoChangeException,
)
from kernel_patches_daemon.config import BranchConfig, KPDConfig
from kernel_patches_daemon.config import (
SERIES_TARGET_SEPARATOR,
BranchConfig,
KPDConfig,
)
from kernel_patches_daemon.github_logs import (
BpfGithubLogExtractor,
DefaultGithubLogExtractor,
@@ -133,7 +137,8 @@ def __init__(

async def get_mapped_branches(self, series: Series) -> List[str]:
for tag in self.tag_to_branch_mapping:
if tag in await series.all_tags():
series_tags = await series.all_tags()
if tag in series_tags:
mapped_branches = self.tag_to_branch_mapping[tag]
logging.info(f"Tag '{tag}' mapped to branch order {mapped_branches}")
return mapped_branches
@@ -142,20 +147,20 @@ async def get_mapped_branches(self, series: Series) -> List[str]:
logging.info(f"Mapped to default branch order: {mapped_branches}")
return mapped_branches

def close_existing_prs_with_same_base(
def close_existing_prs_for_series(
self, workers: Sequence["BranchWorker"], pr: PullRequest
) -> None:
"""Close existing pull requests with the same base, but different remote branch
"""Close existing pull requests for the same series, but different target branch

For given pull request, find all other pull requests with
the same base, but different remote branch and close them.
the same series name, but different remote branch and close them.
"""

prs_to_close = [
existing_pr
for worker in workers
for existing_pr in worker.prs.values()
if has_same_base_different_remote(pr.head.ref, existing_pr.head.ref)
if same_series_different_target(pr.head.ref, existing_pr.head.ref)
]
# Remove matching PRs from other workers
for pr_to_close in prs_to_close:
@@ -169,7 +174,7 @@ def close_existing_prs_with_same_base(
del worker.prs[pr_to_close.title]
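    # E.g. if the new PR's head is "series/42=>bpf" (made-up ref), an open PR
    # with head "series/42=>bpf-next" on another worker is closed, while
    # "series/43=>bpf-next" belongs to a different series and is kept.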

async def checkout_and_patch_safe(
self, worker, branch_name: str, series_to_apply: Series
self, worker: BranchWorker, branch_name: str, series_to_apply: Series
) -> Optional[PullRequest]:
try:
self.increment_counter("all_known_subjects")
@@ -186,6 +191,56 @@ async def checkout_and_patch_safe(
)
return None

    async def sync_relevant_subject(self, subject: Subject) -> None:
        """
        1. Get the subject's latest series
        2. Get the series tags
        3. Map the tags to branches
        4. Starting from the first branch, try to apply the series and
           generate a PR; if that fails, continue to the next branch;
           if no branches remain, generate a merge-conflict PR
        """
series = none_throws(await subject.latest_series())
tags = await series.all_tags()
logging.info(f"Processing {series.id}: {subject.subject} (tags: {tags})")

mapped_branches = await self.get_mapped_branches(series)
if len(mapped_branches) == 0:
logging.info(
f"Skipping {series.id}: {subject.subject} for no mapped branches."
)
return

last_branch = mapped_branches[-1]
for branch in mapped_branches:
worker = self.workers[branch]
# PR branch name == sid of the first known series
pr_branch_name = await worker.subject_to_branch(subject)
(applied, _, _) = await worker.try_apply_mailbox_series(
pr_branch_name, series
)
if not applied:
msg = f"Failed to apply series to {branch}, "
if branch != last_branch:
logging.info(msg + "moving to next.")
continue
else:
logging.info(msg + "no more next, staying.")

logging.info(f"Choosing branch {branch} to create/update PR.")
pr = await self.checkout_and_patch_safe(worker, pr_branch_name, series)
if pr is None:
continue

logging.info(
f"Created/updated {pr} ({pr.head.ref}): {pr.url} for series {series.id}"
)
await worker.sync_checks(pr, series)
            # Close out other PRs for this series, if any exist
self.close_existing_prs_for_series(list(self.workers.values()), pr)

            break
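    # Note: the break above means at most one PR is kept per series; once a
    # branch accepts the series, close_existing_prs_for_series() retires any
    # PRs opened earlier for the same series against other target branches.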

async def sync_patches(self) -> None:
"""
One subject = one branch
@@ -208,7 +263,8 @@ async def sync_patches(self) -> None:
await loop.run_in_executor(None, worker.get_pulls)
await loop.run_in_executor(None, worker.do_sync)
worker._closed_prs = None
worker.branches = [x.name for x in worker.repo.get_branches()]
branches = worker.repo.get_branches()
worker.branches = [b.name for b in branches]

mirror_done = time.time()

@@ -223,52 +279,8 @@

pw_done = time.time()

# 1. Get Subject's latest series
# 2. Get series tags
# 3. Map tags to a branches
# 4. Start from first branch, try to apply and generate PR,
# if fails continue to next branch, if no more branches, generate a merge-conflict PR
for subject in self.subjects:
series = none_throws(await subject.latest_series)
logging.info(
f"Processing {series.id}: {subject.subject} (tags: {await series.all_tags()})"
)

mapped_branches = await self.get_mapped_branches(series)
# series to apply - last known series
if len(mapped_branches) == 0:
logging.info(
f"Skipping {series.id}: {subject.subject} for no mapped branches."
)
continue
last_branch = mapped_branches[-1]
for branch in mapped_branches:
worker = self.workers[branch]
# PR branch name == sid of the first known series
pr_branch_name = await worker.subject_to_branch(subject)
apply_mbox = await worker.try_apply_mailbox_series(
pr_branch_name, series
)
if not apply_mbox[0]:
msg = f"Failed to apply series to {branch}, "
if branch != last_branch:
logging.info(msg + "moving to next.")
continue
else:
logging.info(msg + "no more next, staying.")
logging.info(f"Choosing branch {branch} to create/update PR.")
pr = await self.checkout_and_patch_safe(worker, pr_branch_name, series)
if pr is None:
continue

logging.info(
f"Created/updated {pr} ({pr.head.ref}): {pr.url} for series {series.id}"
)
await worker.sync_checks(pr, series)
# Close out other PRs if exists
self.close_existing_prs_with_same_base(list(self.workers.values()), pr)

break
await self.sync_relevant_subject(subject)

# sync old subjects
subject_names = {x.subject for x in self.subjects}
@@ -278,20 +290,21 @@ async def sync_patches(self) -> None:
continue

if worker._is_relevant_pr(pr):
parsed_ref = parse_pr_ref(pr.head.ref)
# ignore unknown format branch/PRs.
if "/" not in pr.head.ref or HEAD_BASE_SEPARATOR not in pr.head.ref:
                if not parsed_pr_ref_ok(parsed_ref):
continue

series_id = int(get_base_branch_from_ref(pr.head.ref.split("/")[1]))
series_id = parsed_ref["series_id"]
series = await self.pw.get_series_by_id(series_id)
subject = self.pw.get_subject_by_series(series)
if subject_name != subject.subject:
logger.warning(
f"Renaming {pr} from {subject_name} to {subject.subject} according to {series.id}"
)
pr.edit(title=subject.subject)
branch_name = f"{await subject.branch}{HEAD_BASE_SEPARATOR}{worker.repo_branch}"
latest_series = await subject.latest_series or series
branch_name = await worker.subject_to_branch(subject)
latest_series = await subject.latest_series() or series
pr = await self.checkout_and_patch_safe(
worker, branch_name, latest_series
)
20 changes: 9 additions & 11 deletions kernel_patches_daemon/patchwork.py
@@ -23,6 +23,7 @@
from aiohttp_retry import ExponentialRetry, RetryClient
from cachetools import TTLCache

from kernel_patches_daemon.config import SERIES_ID_SEPARATOR
from kernel_patches_daemon.status import Status
from multidict import MultiDict

@@ -288,21 +289,18 @@ def __init__(self, subject: str, pw_client: "Patchwork") -> None:
self.pw_client = pw_client
self.subject = subject

@property
async def branch(self) -> Optional[str]:
relevant_series = await self.relevant_series
relevant_series = await self.relevant_series()
if len(relevant_series) == 0:
return None
return f"series/{relevant_series[0].id}"
return f"series{SERIES_ID_SEPARATOR}{relevant_series[0].id}"

@property
async def latest_series(self) -> Optional["Series"]:
relevant_series = await self.relevant_series
relevant_series = await self.relevant_series()
if len(relevant_series) == 0:
return None
return relevant_series[-1]
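    # Callers now make the coroutine call explicit, e.g.:
    #   branch = await subject.branch()          # was: await subject.branch
    #   series = await subject.latest_series()   # was: await subject.latest_series
    # Dropping the async @property (which presumably composed poorly with
    # @cached) makes it visible at the call site that a request may be made.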

@property
@cached(cache=TTLCache(maxsize=1, ttl=600))
async def relevant_series(self) -> List["Series"]:
"""
@@ -580,7 +578,8 @@ async def __get_objects_recursive(
params = {}
while True:
response = await self.__get(path, params=params)
items += await response.json()
j = await response.json()
items += j

if "next" not in response.links:
break
@@ -695,9 +694,8 @@ async def post_check_for_patch_id(
async def get_series_by_id(self, series_id: int) -> Series:
# fetches directly only if series is not available in local scope
if series_id not in self.known_series:
self.known_series[series_id] = Series(
self, await self.__get_object_by_id("series", series_id)
)
series_json = await self.__get_object_by_id("series", series_id)
self.known_series[series_id] = Series(self, series_json)

return self.known_series[series_id]

@@ -767,7 +765,7 @@ async def get_relevant_subjects(self) -> Sequence[Subject]:
async def fetch_latest_series(
subject_name, subject_obj
        ) -> Tuple[str, Subject, Optional[Series]]:
return (subject_name, subject_obj, await subject_obj.latest_series)
return (subject_name, subject_obj, await subject_obj.latest_series())

tasks = [fetch_latest_series(k, v) for k, v in subjects.items()]
tasks = await asyncio.gather(*tasks)
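        # asyncio.gather() runs the per-subject lookups concurrently; each
        # task resolves to a (subject_name, subject_obj, latest_series) tuple.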