Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ jobs:
path: dist/

- name: Sign the dists with Sigstore
uses: sigstore/gh-action-sigstore-python@v3
uses: sigstore/gh-action-sigstore-python@v3.0.0
with:
inputs: >-
./dist/*.tar.gz
Expand Down
5 changes: 3 additions & 2 deletions patchwork/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,10 @@ def cli(
if not disable_telemetry:
patched.send_public_telemetry(patchflow_name, inputs)

with patched.patched_telemetry(patchflow_name, {}):
with patched.patched_telemetry(patchflow_name, {}) as output_dict:
patchflow_instance = patchflow_class(inputs)
patchflow_instance.run()
patchflow_output = patchflow_instance.run()
output_dict.update(patchflow_output)
except Exception as e:
logger.debug(traceback.format_exc())
logger.error(f"Error running patchflow {patchflow}: {e}")
Expand Down
36 changes: 27 additions & 9 deletions patchwork/common/client/patched.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ class PatchedClient(click.ParamType):
ALLOWED_TELEMETRY_KEYS = {
"model",
}
ALLOWED_TELEMETRY_OUTPUT_KEYS = {
"pr_url",
"issue_url",
}

def __init__(self, access_token: str, url: str = DEFAULT_PATCH_URL):
self.access_token = access_token
Expand Down Expand Up @@ -140,6 +144,15 @@ def __handle_telemetry_inputs(self, inputs: dict[str, Any]) -> dict:

return inputs_copy

def __handle_telemetry_outputs(self, outputs: dict[str, Any]) -> dict:
diff_keys = set(outputs.keys()).difference(self.ALLOWED_TELEMETRY_OUTPUT_KEYS)

outputs_copy = outputs.copy()
for key in diff_keys:
del outputs_copy[key]

return outputs_copy

async def _public_telemetry(self, patchflow: str, inputs: dict[str, Any]):
user_config = get_user_config()
requests.post(
Expand Down Expand Up @@ -169,38 +182,42 @@ def send_public_telemetry(self, patchflow: str, inputs: dict):

@contextlib.contextmanager
def patched_telemetry(self, patchflow: str, inputs: dict):
outputs = dict()

if not self.access_token:
yield
yield outputs
return

try:
is_valid_client = self.test_token()
except Exception as e:
logger.error(f"Access Token test failed: {e}")
yield
yield outputs
return

if not is_valid_client:
yield
yield outputs
return

try:
repo = Repo(Path.cwd(), search_parent_directories=True)
patchflow_run_id = self.record_patchflow_run(patchflow, repo, self.__handle_telemetry_inputs(inputs))
except Exception as e:
logger.error(f"Failed to record patchflow run: {e}")
yield
yield outputs
return

if patchflow_run_id is None:
yield
yield outputs
return

try:
yield
yield outputs
finally:
try:
self.finish_record_patchflow_run(patchflow_run_id, patchflow, repo)
self.finish_record_patchflow_run(
patchflow_run_id, patchflow, repo, self.__handle_telemetry_outputs(outputs)
)
except Exception as e:
logger.error(f"Failed to finish patchflow run: {e}")

Expand All @@ -222,16 +239,17 @@ def record_patchflow_run(self, patchflow: str, repo: Repo, inputs: dict) -> int
return None

logger.debug(f"Patchflow run recorded for {patchflow}")
return response.json()["id"]
return response.json().get("id")

def finish_record_patchflow_run(self, id: int, patchflow: str, repo: Repo) -> None:
def finish_record_patchflow_run(self, id: int, patchflow: str, repo: Repo, outputs: dict) -> None:
response = self._post(
url=self.url + "/v1/patchwork/",
headers={"Authorization": f"Bearer {self.access_token}"},
json={
"id": id,
"url": repo.remotes.origin.url,
"patchflow": patchflow,
"outputs": outputs
},
)

Expand Down
196 changes: 192 additions & 4 deletions patchwork/common/client/scm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,23 @@
import time
from enum import Enum
from itertools import chain
from pathlib import Path
from urllib.parse import urlparse

import git
import gitlab.const
from attrs import define
from azure.devops.connection import Connection
from azure.devops.released.client_factory import ClientFactory
from azure.devops.released.core.core_client import CoreClient
from azure.devops.released.git.git_client import GitClient
from azure.devops.v7_1.git.models import GitPullRequest, GitPullRequestSearchCriteria, TeamProjectReference, GitRepository
from github import Auth, Consts, Github, GithubException, PullRequest
from github.GithubException import UnknownObjectException
from gitlab import Gitlab, GitlabAuthenticationError, GitlabError
from gitlab.v4.objects import ProjectMergeRequest
from giturlparse import GitUrlParsed, parse
from msrest.authentication import BasicAuthentication
from typing_extensions import Protocol, TypedDict

from patchwork.logger import logger
Expand Down Expand Up @@ -53,12 +62,13 @@ class PullRequestTexts(TypedDict):


class PullRequestState(Enum):
OPEN = (["open"], ["opened"])
CLOSED = (["closed"], ["closed", "merged"])
OPEN = (["open"], ["opened"], ["active"])
CLOSED = (["closed"], ["closed", "merged"], ["completed", "abandoned", "notSet"])

def __init__(self, github_state: list[str], gitlab_state: list[str]):
def __init__(self, github_state: list[str], gitlab_state: list[str], azure_devops_state: list[str]):
self.github_state: list[str] = github_state
self.gitlab_state: list[str] = gitlab_state
self.azure_devops_state: list[str] = azure_devops_state


_COMMENT_MARKER = "<!-- PatchWork comment marker -->"
Expand Down Expand Up @@ -112,6 +122,11 @@ def _apply_pr_template(pr: "PullRequestProtocol", body: str) -> str:
# chunk_link_format = file_link_format + "_{start_line}_{end_line}"
chunk_link_format = file_link_format + ""
anchor_hash = hashlib.sha1
elif isinstance(pr, AzureDevopsPullRequest):
backup_link_format = "{url}?_a=files"
file_link_format = backup_link_format + "&path=/{path}"
chunk_link_format = file_link_format + ""
anchor_hash = hashlib.md5
else:
return pr.url()

Expand All @@ -135,7 +150,7 @@ def _apply_pr_template(pr: "PullRequestProtocol", body: str) -> str:
format_to_use = chunk_link_format

replacement_value = format_to_use.format(
url=pr.url(), diff_anchor=diff_anchor, start_line=start, end_line=end
url=pr.url(), path=path, diff_anchor=diff_anchor, start_line=start, end_line=end
)
template = template[:start_idx] + replacement_value + template[end_idx + 2 :]
start_idx, end_idx = PullRequestProtocol._get_template_indexes(template)
Expand Down Expand Up @@ -352,6 +367,37 @@ def texts(self) -> PullRequestTexts:
diffs={file.filename: file.patch for file in self._pr.get_files() if file.patch is not None},
)

class AzureDevopsPullRequest(PullRequestProtocol):
def __init__(self, pr: GitPullRequest, git_client: GitClient, pr_base_url: str):
self._pr: GitPullRequest = pr
self.git_client: GitClient = git_client
self.pr_base_url = pr_base_url

@property
def id(self) -> int:
return self._pr.pull_request_id

def url(self) -> str:
final_pr_url = self.pr_base_url
if not final_pr_url.endswith("/"):
final_pr_url += "/"
return final_pr_url + str(self.id)

def set_pr_description(self, body: str) -> None:
final_body = PullRequestProtocol._apply_pr_template(self, body)
body = GitPullRequest(description=final_body)
self.git_client.update_pull_request(body, repository_id=self._pr.repository.id, pull_request_id=self._pr.pull_request_id, project=self._pr.repository.project.id)

def create_comment(
self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None
) -> str | None:
...

def reset_comments(self) -> None:
...

def texts(self) -> PullRequestTexts:
...

class GithubClient(ScmPlatformClientProtocol):
DEFAULT_URL = Consts.DEFAULT_BASE_URL
Expand Down Expand Up @@ -608,3 +654,145 @@ def create_issue_comment(

obj = self.gitlab.projects.get(slug).issues.create({"title": title, "description": issue_text})
return obj["web_url"]


class AzureDevopsClient(ScmPlatformClientProtocol):
DEFAULT_URL = "https://dev.azure.com/"

def __init__(self, access_token: str, url: str = DEFAULT_URL, remote: str = "origin"):
self.credentials = BasicAuthentication('', access_token)
self.__url = url
self.__remote = remote
git_repo = git.Repo(Path.cwd(), search_parent_directories=True)
original_remote_url = git_repo.remotes[remote].url
parsed_repo: GitUrlParsed = parse(original_remote_url)
self.__org_name = parsed_repo.owner
self.__project_name = parsed_repo.groups_path.replace("/_git", "")
self.__repo_name = parsed_repo.repo

def __pr_resource_html_url(self):
url = self.__url
if not url.endswith("/"):
url += "/"
return f"{url}{self.__org_name}/{self.__project_name}/_git/{self.__repo_name}/pullrequest/"


@functools.cached_property
def clients(self) -> ClientFactory:
url = self.__url
if not url.endswith("/"):
url += "/"

conn = Connection(base_url=f"{url}{self.__org_name}", creds=self.credentials)
return conn.clients

@functools.cached_property
def git_client(self) -> GitClient:
return self.clients.get_git_client()

@functools.cached_property
def core_client(self) -> CoreClient:
return self.clients.get_core_client()

@functools.cached_property
def project(self) -> TeamProjectReference:
projs = self.core_client.get_projects()
proj = next((proj for proj in projs if proj.name == self.__project_name), None)
if proj is None:
raise ValueError(f"Unable to determine project name from remote {self.__remote} url. Parsed project name: {self.__project_name}")
return proj

@functools.cached_property
def repo(self) -> GitRepository:
repos = self.git_client.get_repositories(project=self.project.id)
git_repo = next((r for r in repos if r.name == self.__repo_name), None)
if git_repo is None:
raise ValueError(f"Unable to determine repository name from remote {self.__remote} url. Parsed repository name: {self.__repo_name}")
return git_repo

def set_url(self, url: str) -> None:
self.__url = url

def test(self) -> bool:
response = self.core_client.get_projects()
return next(iter(response), None) is not None

def get_slug_and_id_from_url(self, url: str) -> tuple[str, int] | None:
...

def find_issue_by_url(self, url: str) -> IssueText | None:
...

def find_issue_by_id(self, slug: str, issue_id: int) -> IssueText | None:
...

def get_pr_by_url(self, url: str) -> PullRequestProtocol | None:
...

def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None:
...

def find_prs(
self,
slug: str,
state: PullRequestState | None = None,
original_branch: str | None = None,
feature_branch: str | None = None,
limit: int | None = None,
) -> list[PullRequestProtocol]:
kwargs_list = dict(status=[None], target_ref_name=[None], source_ref_name=[None])

if state is not None:
kwargs_list["status"] = state.gitlab_state # type: ignore
if original_branch is not None:
kwargs_list["target_ref_name"] = [f"refs/heads/{original_branch}"] # type: ignore
if feature_branch is not None:
kwargs_list["source_ref_name"] = [f"refs/heads/{feature_branch}"] # type: ignore

page_list = []
keys = kwargs_list.keys()
for instance in itertools.product(*kwargs_list.values()):
kwargs = dict(((key, value) for key, value in zip(keys, instance) if value is not None))
git_pr_search = GitPullRequestSearchCriteria(
repository_id=self.repo.id,
**kwargs,
)
pr_instances = self.git_client.get_pull_requests(
project=self.project.id,
repository_id=self.repo.id,
search_criteria=git_pr_search
)
page_list.append(pr_instances)

rv_list = []
for mr in itertools.islice(itertools.chain(*page_list), limit):
rv_list.append(AzureDevopsPullRequest(mr, self.git_client, self.__pr_resource_html_url()))

return rv_list

def create_pr(
self,
slug: str,
title: str,
body: str,
original_branch: str,
feature_branch: str,
) -> PullRequestProtocol:
# before creating a PR, check if one already exists
pr_body = GitPullRequest(
source_ref_name=f"refs/heads/{feature_branch}",
target_ref_name=f"refs/heads/{original_branch}",
title=title,
description=body,
# should be web tag definition
# labels="patchwork",
)
pr_instance = self.git_client.create_pull_request(pr_body, repository_id=self.repo.id, project=self.project.id)
mr = AzureDevopsPullRequest(pr_instance, self.git_client, self.__pr_resource_html_url()) # type: ignore
return mr

def create_issue_comment(
self, slug: str, issue_text: str, title: str | None = None, issue_id: int | None = None
) -> str:
...

Loading
Loading