Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/lando/api/legacy/workers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import subprocess
from abc import ABC, abstractmethod
from time import sleep
from typing import Callable
from typing import Callable, TypeVar

from celery import Task
from django.db import transaction
Expand Down Expand Up @@ -36,6 +36,8 @@

logger = logging.getLogger(__name__)

T = TypeVar("T")


class Worker(ABC):
"""A base class for repository workers."""
Expand Down Expand Up @@ -314,15 +316,15 @@ def update_repo(

def handle_new_commit_failures(
self,
create_revision_callable: Callable[[Revision], None],
create_revision_callable: Callable[[Revision], T],
repo: Repo,
job: BaseJob,
scm: AbstractSCM,
revision: Revision,
) -> None:
) -> T:
"""Create revisions with job status handling."""
try:
create_revision_callable(revision)
return create_revision_callable(revision)
except NoDiffStartLine as exc:
message = (
"Lando encountered a malformed patch, please try again. "
Expand Down
13 changes: 12 additions & 1 deletion src/lando/api/legacy/workers/landing_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ def convert_patches_to_diff(self, scm: AbstractSCM, job: LandingJob):
# at this time. In theory this would work for any provided patches in a
# standard format.

def get_diff_from_patches(revision: Revision) -> str:
logger.debug(f"Converting paches to single diff for {revision} ...")
return scm.get_diff_from_patches(revision.patches)

# NOTE: this is only supported for jobs with a single revision at this time.
# See bug 2001185.

Expand All @@ -188,7 +192,14 @@ def convert_patches_to_diff(self, scm: AbstractSCM, job: LandingJob):
if not revision.patches:
raise ValueError("Revision is missing patches.")

diff = scm.get_diff_from_patches(revision.patches)
diff = self.handle_new_commit_failures(
get_diff_from_patches,
job.target_repo,
job,
scm,
revision,
)

revision.set_patch(diff)
revision.save()

Expand Down
31 changes: 23 additions & 8 deletions src/lando/main/scm/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any
from typing import Any, Callable, TypeVar

from typing_extensions import override

Expand Down Expand Up @@ -36,6 +36,24 @@
ENV_COMMITTER_EMAIL = "GIT_COMMITTER_EMAIL"


T = TypeVar("T")


def detect_patch_conflict(fn: Callable[..., T]) -> Callable[..., T]:
"""Decorator transforming SCMExceptions to PatchConflict as appropriate."""

def wrapper(*args, **kwargs) -> T:
try:
return fn(*args, **kwargs)
except SCMException as exc:
if "error: patch" in exc.err:
raise PatchConflict(exc.err) from exc

raise exc

return wrapper


class GitSCM(AbstractSCM):
"""An implementation of the AbstractVCS for Git, for use by the Repo and LandingWorkers."""

Expand Down Expand Up @@ -122,6 +140,7 @@ def last_commit_for_path(self, path: str) -> str:
return self._git_run(*command, cwd=self.path)

@override
@detect_patch_conflict
def apply_patch(
self, diff: str, commit_description: str, commit_author: str, commit_date: str
):
Expand Down Expand Up @@ -152,13 +171,7 @@ def apply_patch(
]

for c in cmds:
try:
self._git_run(*c, cwd=self.path)
except SCMException as exc:
if "error: patch" in exc.err:
raise PatchConflict(exc.err) from exc

raise exc
self._git_run(*c, cwd=self.path)

@override
def apply_patch_git(self, patch_bytes: bytes):
Expand Down Expand Up @@ -189,6 +202,7 @@ def apply_patch_git(self, patch_bytes: bytes):
# Re-raise the exception from the failed `git am`.
raise exc

@detect_patch_conflict
def get_diff_from_patches(self, patches: str) -> str:
"""Apply multiple patches and return the diff output."""
# TODO: add error handling so that if something goes wrong here,
Expand All @@ -202,6 +216,7 @@ def get_diff_from_patches(self, patches: str) -> str:
patch_file.flush()

self._git_run("apply", "--reject", patch_file.name, cwd=self.path)

self._git_run("add", "-A", "-f", cwd=self.path)
return self._git_run(
"diff", "--staged", "--binary", cwd=self.path, rstrip=False
Expand Down