Skip to content

Commit acb5fb9

Browse files
committed
implement azure devops pr related interfaces
1 parent c0345eb commit acb5fb9

File tree

1 file changed

+103
-32
lines changed

1 file changed

+103
-32
lines changed

patchwork/common/client/scm.py

Lines changed: 103 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,26 @@
44
import hashlib
55
import itertools
66
import time
7+
from difflib import unified_diff
78
from enum import Enum
89
from itertools import chain
910
from pathlib import Path
10-
from urllib.parse import urlparse
1111

1212
import git
1313
import gitlab.const
14-
from attrs import define
1514
from azure.devops.connection import Connection
1615
from azure.devops.released.client_factory import ClientFactory
1716
from azure.devops.released.core.core_client import CoreClient
1817
from azure.devops.released.git.git_client import GitClient
19-
from azure.devops.v7_1.git.models import GitPullRequest, GitPullRequestSearchCriteria, TeamProjectReference, GitRepository
18+
from azure.devops.v7_1.git.models import GitPullRequest, GitPullRequestSearchCriteria, TeamProjectReference, \
19+
GitRepository, Comment, GitPullRequestCommentThread, GitTargetVersionDescriptor, GitBaseVersionDescriptor
2020
from github import Auth, Consts, Github, GithubException, PullRequest
2121
from github.GithubException import UnknownObjectException
2222
from gitlab import Gitlab, GitlabAuthenticationError, GitlabError
2323
from gitlab.v4.objects import ProjectMergeRequest
2424
from giturlparse import GitUrlParsed, parse
2525
from msrest.authentication import BasicAuthentication
26-
from typing_extensions import Protocol, TypedDict
26+
from typing_extensions import Protocol, TypedDict, Iterator
2727

2828
from patchwork.logger import logger
2929

@@ -35,14 +35,6 @@ def get_slug_from_remote_url(remote_url: str) -> str:
3535
return slug
3636

3737

38-
@define
39-
class Comment:
40-
path: str
41-
body: str
42-
start_line: int | None
43-
end_line: int
44-
45-
4638
class IssueText(TypedDict):
4739
title: str
4840
body: str
@@ -386,18 +378,71 @@ def url(self) -> str:
386378
def set_pr_description(self, body: str) -> None:
387379
final_body = PullRequestProtocol._apply_pr_template(self, body)
388380
body = GitPullRequest(description=final_body)
389-
self.git_client.update_pull_request(body, repository_id=self._pr.repository.id, pull_request_id=self._pr.pull_request_id, project=self._pr.repository.project.id)
381+
self._pr = self.git_client.update_pull_request(body, repository_id=self._pr.repository.id, pull_request_id=self._pr.pull_request_id, project=self._pr.repository.project.id)
390382

391383
def create_comment(
392384
self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None
393385
) -> str | None:
394-
...
386+
try:
387+
comment_body = Comment(content=body)
388+
comment_thread_body = GitPullRequestCommentThread(comments=[comment_body])
389+
comment_thread = self.git_client.create_thread(comment_thread_body, repository_id=self._pr.repository.id, pull_request_id=self.id, project=self._pr.repository.project.id)
390+
return body
391+
except Exception as e:
392+
logger.error(e)
393+
return None
394+
395+
def __iter_comments(self) -> Iterator[tuple[GitPullRequestCommentThread, list[Comment]]]:
396+
threads = self.git_client.get_threads(repository_id=self._pr.repository.id, pull_request_id=self.id, project=self._pr.repository.project.id)
397+
for thread in threads:
398+
comments = self.git_client.get_comments(repository_id=self._pr.repository.id, pull_request_id=self.id, thread_id=thread.id, project=self._pr.repository.project.id)
399+
yield thread, comments
395400

396401
def reset_comments(self) -> None:
397-
...
402+
for thread, comments in self.__iter_comments():
403+
comment_ids_to_delete = []
404+
for comment in comments:
405+
if comment.content.startswith(_COMMENT_MARKER):
406+
comment_ids_to_delete.append(comment.id)
407+
if len(comment_ids_to_delete) == len(comments):
408+
for comment_id in comment_ids_to_delete:
409+
self.git_client.delete_comment(repository_id=self._pr.repository.id, pull_request_id=self.id, thread_id=thread.id, comment_id=comment_id, project=self._pr.repository.project.id)
398410

399411
def texts(self) -> PullRequestTexts:
400-
...
412+
self.git_client.get_commit_diffs(
413+
repository_id=self._pr.repository.id,
414+
project=self._pr.repository.project.id
415+
)
416+
417+
target_branch = self._pr.last_merge_source_commit.commit_id
418+
feature_branch = self._pr.last_merge_target_commit.commit_id
419+
420+
repo = git.Repo(path=Path.cwd(), search_parent_directories=True)
421+
repo.git.fetch()
422+
target_commit = repo.commit(target_branch)
423+
feature_commit = repo.commit(feature_branch)
424+
425+
diff_index = feature_commit.diff(target_commit)
426+
diffs = dict()
427+
for diff in diff_index:
428+
a_path = diff.a_path
429+
b_path = diff.b_path
430+
a_blob = diff.a_blob.data_stream.read().decode("utf-8")
431+
b_blob = diff.b_blob.data_stream.read().decode("utf-8")
432+
diff_lines = unified_diff(a_blob.splitlines(keepends=True), b_blob.splitlines(keepends=True), a_path, b_path)
433+
diff_content = "".join(diff_lines)
434+
diffs[a_path] = diff_content
435+
436+
comments: list[PullRequestComment] = []
437+
for _, raw_comments in self.__iter_comments():
438+
for raw_comment in raw_comments:
439+
comments.append(dict(user=raw_comment.author.display_name, body=raw_comment.content))
440+
return dict(
441+
title=self._pr.title or "",
442+
body=self._pr.description or "",
443+
comments=comments,
444+
diffs=diffs,
445+
)
401446

402447
class GithubClient(ScmPlatformClientProtocol):
403448
DEFAULT_URL = Consts.DEFAULT_BASE_URL
@@ -450,11 +495,11 @@ def find_issue_by_id(self, slug: str, issue_id: int) -> IssueText | None:
450495
logger.warn(f"Failed to get issue: {e}")
451496
return None
452497

453-
def get_pr_by_url(self, url: str) -> PullRequestProtocol | None:
498+
def get_pr_by_url(self, url: str) -> GithubPullRequest | None:
454499
slug, pr_id = self.get_slug_and_id_from_url(url)
455500
return self.find_pr_by_id(slug, pr_id)
456501

457-
def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None:
502+
def find_pr_by_id(self, slug: str, pr_id: int) -> GithubPullRequest | None:
458503
repo = self.github.get_repo(slug)
459504
try:
460505
pr = repo.get_pull(pr_id)
@@ -508,7 +553,7 @@ def create_pr(
508553
body: str,
509554
original_branch: str,
510555
feature_branch: str,
511-
) -> PullRequestProtocol:
556+
) -> GithubPullRequest:
512557
# before creating a PR, check if one already exists
513558
repo = self.github.get_repo(slug)
514559
gh_pr = repo.create_pull(title=title, body=body, base=original_branch, head=feature_branch)
@@ -579,11 +624,11 @@ def find_issue_by_id(self, slug: str, issue_id: int) -> IssueText | None:
579624
logger.warn(f"Failed to get issue: {e}")
580625
return None
581626

582-
def get_pr_by_url(self, url: str) -> PullRequestProtocol | None:
627+
def get_pr_by_url(self, url: str) -> GitlabMergeRequest | None:
583628
slug, pr_id = self.get_slug_and_id_from_url(url)
584629
return self.find_pr_by_id(slug, pr_id)
585630

586-
def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None:
631+
def find_pr_by_id(self, slug: str, pr_id: int) -> GitlabMergeRequest | None:
587632
project = self.gitlab.projects.get(slug)
588633
try:
589634
mr = project.mergerequests.get(pr_id)
@@ -599,7 +644,7 @@ def find_prs(
599644
original_branch: str | None = None,
600645
feature_branch: str | None = None,
601646
limit: int | None = None,
602-
) -> list[PullRequestProtocol]:
647+
) -> list[GitlabMergeRequest]:
603648
project = self.gitlab.projects.get(slug)
604649
kwargs_list = dict(iterator=[True], state=[None], target_branch=[None], source_branch=[None])
605650

@@ -630,7 +675,7 @@ def create_pr(
630675
body: str,
631676
original_branch: str,
632677
feature_branch: str,
633-
) -> PullRequestProtocol:
678+
) -> GitlabMergeRequest:
634679
# before creating a PR, check if one already exists
635680
project = self.gitlab.projects.get(slug)
636681
gl_mr = project.mergerequests.create(
@@ -714,23 +759,41 @@ def set_url(self, url: str) -> None:
714759
self.__url = url
715760

716761
def test(self) -> bool:
717-
response = self.core_client.get_projects()
718-
return next(iter(response), None) is not None
762+
try:
763+
proj = self.project
764+
return True
765+
except ValueError:
766+
return False
719767

720768
def get_slug_and_id_from_url(self, url: str) -> tuple[str, int] | None:
721-
...
769+
url_parts = url.split("/")
770+
if len(url_parts) == 1:
771+
logger.error(f"Invalid URL: {url}")
772+
return None
773+
774+
try:
775+
resource_id = int(url_parts[-1])
776+
except ValueError:
777+
logger.error(f"Invalid URL: {url}")
778+
return None
779+
780+
slug = "/".join(url_parts[-6:-3])
781+
782+
return slug, resource_id
722783

723784
def find_issue_by_url(self, url: str) -> IssueText | None:
724785
...
725786

726787
def find_issue_by_id(self, slug: str, issue_id: int) -> IssueText | None:
727788
...
728789

729-
def get_pr_by_url(self, url: str) -> PullRequestProtocol | None:
730-
...
790+
def get_pr_by_url(self, url: str) -> AzureDevopsPullRequest | None:
791+
slug, resource_id = self.get_slug_and_id_from_url(url)
792+
return self.find_pr_by_id(slug, resource_id)
731793

732-
def find_pr_by_id(self, slug: str, pr_id: int) -> PullRequestProtocol | None:
733-
...
794+
def find_pr_by_id(self, slug: str, pr_id: int) -> AzureDevopsPullRequest | None:
795+
pr = self.git_client.get_pull_request(repository_id=self.repo.id, pull_request_id=pr_id, project=self.project.id)
796+
return AzureDevopsPullRequest(pr, self.git_client, self.__pr_resource_html_url())
734797

735798
def find_prs(
736799
self,
@@ -739,7 +802,7 @@ def find_prs(
739802
original_branch: str | None = None,
740803
feature_branch: str | None = None,
741804
limit: int | None = None,
742-
) -> list[PullRequestProtocol]:
805+
) -> list[AzureDevopsPullRequest]:
743806
kwargs_list = dict(status=[None], target_ref_name=[None], source_ref_name=[None])
744807

745808
if state is not None:
@@ -777,7 +840,7 @@ def create_pr(
777840
body: str,
778841
original_branch: str,
779842
feature_branch: str,
780-
) -> PullRequestProtocol:
843+
) -> AzureDevopsPullRequest:
781844
# before creating a PR, check if one already exists
782845
pr_body = GitPullRequest(
783846
source_ref_name=f"refs/heads/{feature_branch}",
@@ -796,3 +859,11 @@ def create_issue_comment(
796859
) -> str:
797860
...
798861

862+
863+
if __name__ == "__main__":
864+
azure_client = AzureDevopsClient("EZ1PAM0W9v0nUZhmgjeTvaUFp9xcCiXTot0ImFdyQg96D1YgCKdCJQQJ99ALACAAAAAAAAAAAAASAZDONH1a")
865+
repository_id = azure_client.repo.id
866+
project_id = azure_client.project.id
867+
pr = azure_client.find_pr_by_id("", 3)
868+
texts = pr.texts()
869+
print(1)

0 commit comments

Comments
 (0)