diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index fa4364fe3..6c5ac7a5e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -108,6 +108,7 @@ jobs: poetry run patchwork AutoFix --log debug \ --patched_api_key=${{ secrets.PATCHED_API_KEY }} \ --github_api_key=${{ secrets.SCM_GITHUB_KEY }} \ + --issue_url=https://github.com/patched-codes/patchwork/issues/1039 \ --force_pr_creation \ --disable_telemetry diff --git a/patchwork/common/client/llm/anthropic.py b/patchwork/common/client/llm/anthropic.py index 0bc7f98ae..49a5660c7 100644 --- a/patchwork/common/client/llm/anthropic.py +++ b/patchwork/common/client/llm/anthropic.py @@ -84,6 +84,7 @@ def __get_model_limit(self, model: str) -> int: return 200_000 - safety_margin def __adapt_input_messages(self, messages: Iterable[ChatCompletionMessageParam]) -> list[MessageParam]: + system: Union[str, Iterable[TextBlockParam]] | NotGiven = NOT_GIVEN new_messages = [] for message in messages: if message.get("role") == "system": @@ -128,22 +129,22 @@ def __adapt_input_messages(self, messages: Iterable[ChatCompletionMessageParam]) return new_messages def __adapt_chat_completion_request( - self, - messages: Iterable[ChatCompletionMessageParam], - model: str, - frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, - logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, - logprobs: Optional[bool] | NotGiven = NOT_GIVEN, - max_tokens: Optional[int] | NotGiven = NOT_GIVEN, - n: Optional[int] | NotGiven = NOT_GIVEN, - presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, - response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, - stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, - tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN, - top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, + self, + messages: Iterable[ChatCompletionMessageParam], + model: str, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, + tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, ): system: Union[str, Iterable[TextBlockParam]] | NotGiven = NOT_GIVEN adapted_messages = self.__adapt_input_messages(messages) @@ -207,22 +208,22 @@ def is_model_supported(self, model: str) -> bool: return model in self.__definitely_allowed_models or model.startswith(self.__allowed_model_prefix) def is_prompt_supported( - self, - messages: Iterable[ChatCompletionMessageParam], - model: str, - frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, - logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, - logprobs: Optional[bool] | NotGiven = NOT_GIVEN, - max_tokens: Optional[int] | NotGiven = NOT_GIVEN, - n: Optional[int] | NotGiven = NOT_GIVEN, - presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, - response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, - stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, - tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN, - top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, + self, + messages: Iterable[ChatCompletionMessageParam], + model: str, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, + tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, ) -> int: model_limit = self.__get_model_limit(model) input_kwargs = self.__adapt_chat_completion_request( @@ -251,27 +252,27 @@ def is_prompt_supported( return model_limit - message_token_count.input_tokens def truncate_messages( - self, messages: Iterable[ChatCompletionMessageParam], model: str + self, messages: Iterable[ChatCompletionMessageParam], model: str ) -> Iterable[ChatCompletionMessageParam]: return self._truncate_messages(self, messages, model) def chat_completion( - self, - messages: Iterable[ChatCompletionMessageParam], - model: str, - frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, - logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, - logprobs: Optional[bool] | NotGiven = NOT_GIVEN, - max_tokens: Optional[int] | NotGiven = NOT_GIVEN, - n: Optional[int] | NotGiven = NOT_GIVEN, - presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, - response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, - stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN, - temperature: Optional[float] | NotGiven = NOT_GIVEN, - tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, - tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN, - top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, - top_p: Optional[float] | NotGiven = NOT_GIVEN, + self, + messages: Iterable[ChatCompletionMessageParam], + model: str, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN, + tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, ) -> ChatCompletion: input_kwargs = self.__adapt_chat_completion_request( messages=messages, diff --git a/patchwork/common/client/scm.py b/patchwork/common/client/scm.py index 22704f9a4..49a85b57d 100644 --- a/patchwork/common/client/scm.py +++ b/patchwork/common/client/scm.py @@ -17,7 +17,8 @@ from azure.devops.released.core.core_client import CoreClient from azure.devops.released.git.git_client import GitClient from azure.devops.v7_1.git.models import GitPullRequest, GitPullRequestSearchCriteria, TeamProjectReference, GitRepository -from github import Auth, Consts, Github, GithubException, PullRequest +from github import Auth, Consts, Github, GithubException, PullRequest, Issue +from github.GithubObject import NotSet from github.GithubException import UnknownObjectException from gitlab import Gitlab, GitlabAuthenticationError, GitlabError from gitlab.v4.objects import ProjectMergeRequest @@ -197,6 +198,7 @@ def create_pr( body: str, original_branch: str, feature_branch: str, + issue_url: str | None = None, ) -> PullRequestProtocol: ... @@ -434,18 +436,26 @@ def get_slug_and_id_from_url(self, url: str) -> tuple[str, int] | None: return slug, resource_id def find_issue_by_url(self, url: str) -> IssueText | None: - slug, issue_id = self.get_slug_and_id_from_url(url) + resource_slug_and_id = self.get_slug_and_id_from_url(url) + if resource_slug_and_id is None: + return None + slug, issue_id = resource_slug_and_id return self.find_issue_by_id(slug, issue_id) def find_issue_by_id(self, slug: str, issue_id: int) -> IssueText | None: - repo = self.github.get_repo(slug) + issue = self.__find_issue_by_id(slug, issue_id) + if issue is None: + return None + return dict( + title=issue.title, + body=issue.body, + comments=[issue_comment.body for issue_comment in issue.get_comments()], + ) + + def __find_issue_by_id(self, slug: str, issue_id: int) -> Issue | None: try: - issue = repo.get_issue(issue_id) - return dict( - title=issue.title, - body=issue.body, - comments=[issue_comment.body for issue_comment in issue.get_comments()], - ) + repo = self.github.get_repo(slug) + return repo.get_issue(issue_id) except GithubException as e: logger.warn(f"Failed to get issue: {e}") return None @@ -508,10 +518,19 @@ def create_pr( body: str, original_branch: str, feature_branch: str, + issue_url: str | None = None, ) -> PullRequestProtocol: # before creating a PR, check if one already exists repo = self.github.get_repo(slug) - gh_pr = repo.create_pull(title=title, body=body, base=original_branch, head=feature_branch) + + issue_obj = NotSet + if issue_url is not None: + resource_slug_and_id = self.get_slug_and_id_from_url(issue_url) + if resource_slug_and_id is not None: + slug, issue_id = resource_slug_and_id + issue_obj = self.__find_issue_by_id(slug, issue_id) + + gh_pr = repo.create_pull(title=title, body=body, base=original_branch, head=feature_branch, issue=issue_obj) pr = GithubPullRequest(gh_pr) return pr @@ -630,7 +649,9 @@ def create_pr( body: str, original_branch: str, feature_branch: str, + issue_url: str | None = None, ) -> PullRequestProtocol: + # issue_url is unused here because we usually set it in the MR body instead for gitlab. # before creating a PR, check if one already exists project = self.gitlab.projects.get(slug) gl_mr = project.mergerequests.create( @@ -777,6 +798,7 @@ def create_pr( body: str, original_branch: str, feature_branch: str, + issue_url: str | None = None, ) -> PullRequestProtocol: # before creating a PR, check if one already exists pr_body = GitPullRequest( diff --git a/patchwork/steps/CreatePR/CreatePR.py b/patchwork/steps/CreatePR/CreatePR.py index ba2be895e..c960aecd5 100644 --- a/patchwork/steps/CreatePR/CreatePR.py +++ b/patchwork/steps/CreatePR/CreatePR.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing_extensions import Optional import git from git.exc import GitCommandError @@ -49,6 +50,7 @@ def __init__(self, inputs: dict): ) self.enabled = False + self.issue_url = inputs.get("issue_url") self.pr_body = inputs.get("pr_body", "") self.title = inputs.get("pr_title", "Patchwork PR") self.force = bool(inputs.get("force_pr_creation", False)) @@ -107,6 +109,7 @@ def run(self) -> dict: base_branch_name=self.base_branch, target_branch_name=self.target_branch, scm_client=self.scm_client, + issue_url=self.issue_url, force=self.force, ) @@ -147,17 +150,19 @@ def create_pr( base_branch_name: str, target_branch_name: str, scm_client: ScmPlatformClientProtocol, + issue_url: Optional[str] = None, force: bool = False, ): prs = scm_client.find_prs(repo_slug, original_branch=base_branch_name, feature_branch=target_branch_name) pr = next(iter(prs), None) if pr is None: pr = scm_client.create_pr( - repo_slug, - title, - body, - base_branch_name, - target_branch_name, + slug=repo_slug, + title=title, + body=body, + original_branch=base_branch_name, + feature_branch=target_branch_name, + issue_url=issue_url ) pr.set_pr_description(body) diff --git a/patchwork/steps/CreatePR/typed.py b/patchwork/steps/CreatePR/typed.py index 9bd7e4401..0b7ac4b3b 100644 --- a/patchwork/steps/CreatePR/typed.py +++ b/patchwork/steps/CreatePR/typed.py @@ -14,6 +14,7 @@ class CreatePRInputs(__CreatePRRequiredInputs, total=False): force_pr_creation: Annotated[bool, StepTypeConfig(is_config=True)] disable_pr: Annotated[bool, StepTypeConfig(is_config=True)] scm_url: Annotated[str, StepTypeConfig(is_config=True)] + issue_url: Annotated[str, StepTypeConfig(is_config=True)] gitlab_api_key: Annotated[str, StepTypeConfig(is_config=True)] github_api_key: Annotated[str, StepTypeConfig(is_config=True)] azuredevops_api_key: Annotated[str, StepTypeConfig(is_config=True)] diff --git a/pyproject.toml b/pyproject.toml index 1bb476a20..3cd563b39 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "patchwork-cli" -version = "0.0.84" +version = "0.0.85" description = "" authors = ["patched.codes"] license = "AGPL"