diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index c63843738..5e78db2d1 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -8,6 +8,7 @@ from codegen.extensions.linear.linear_client import LinearClient from codegen.extensions.tools.bash import run_bash_command from codegen.extensions.tools.github.search import search +from codegen.extensions.tools.github.view_pr_checks import view_pr_checks from codegen.extensions.tools.linear.linear import ( linear_comment_on_issue_tool, linear_create_issue_tool, @@ -669,6 +670,28 @@ def _run( return result.render() +class GithubViewPRCheckInput(BaseModel): + """Input for viewing PR checks""" + + pr_number: int = Field(..., description="The PR number to view checks for") + + +class GithubViewPRCheckTool(BaseTool): + """Tool for viewing PR checks.""" + + name: ClassVar[str] = "view_pr_checks" + description: ClassVar[str] = "View the check suites for a PR" + args_schema: ClassVar[type[BaseModel]] = GithubCreatePRReviewCommentInput + codebase: Codebase = Field(exclude=True) + + def __init__(self, codebase: Codebase) -> None: + super().__init__(codebase=codebase) + + def _run(self, pr_number: int) -> str: + result = view_pr_checks(self.codebase, pr_number=pr_number) + return result.render() + + ######################################################################################################################## # LINEAR ######################################################################################################################## diff --git a/src/codegen/extensions/tools/github/view_pr_checks.py b/src/codegen/extensions/tools/github/view_pr_checks.py new file mode 100644 index 000000000..9f17edaa2 --- /dev/null +++ b/src/codegen/extensions/tools/github/view_pr_checks.py @@ -0,0 +1,54 @@ +"""Tool for creating PR review comments.""" + +import json + +from pydantic import Field + +from codegen.sdk.core.codebase import Codebase + +from ..observation import Observation + + +class PRCheckObservation(Observation): + """Response from retrieving PR checks.""" + + pr_number: int = Field( + description="PR number that was viewed", + ) + head_sha: str | None = Field( + description="SHA of the head commit", + ) + summary: str | None = Field( + description="Summary of the checks", + ) + + +def view_pr_checks(codebase: Codebase, pr_number: int) -> PRCheckObservation: + """Retrieve check information from a Github PR . + + Args: + codebase: The codebase to operate on + pr_number: The PR number to view checks on + """ + try: + pr = codebase.op.remote_git_repo.get_pull_safe(pr_number) + if not pr: + return PRCheckObservation( + pr_number=pr_number, + head_sha=None, + summary=None, + ) + commit = codebase.op.remote_git_repo.get_commit_safe(pr.head.sha) + all_check_suites = commit.get_check_suites() + return PRCheckObservation( + pr_number=pr_number, + head_sha=pr.head.sha, + summary="\n".join([json.dumps(check_suite.raw_data) for check_suite in all_check_suites]), + ) + + except Exception as e: + return PRCheckObservation( + pr_number=pr_number, + head_sha=None, + summary=None, + )