diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index 5e78db2d1..e191ca7f8 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -7,6 +7,7 @@ from codegen.extensions.linear.linear_client import LinearClient from codegen.extensions.tools.bash import run_bash_command +from codegen.extensions.tools.github.checkout_pr import checkout_pr from codegen.extensions.tools.github.search import search from codegen.extensions.tools.github.view_pr_checks import view_pr_checks from codegen.extensions.tools.linear.linear import ( @@ -605,6 +606,28 @@ def _run(self, pr_id: int) -> str: return result.render() +class GithubCheckoutPRInput(BaseModel): + """Input for checkout out a PR head branch.""" + + pr_number: int = Field(..., description="Number of the PR to checkout") + + +class GithubCheckoutPRTool(BaseTool): + """Tool for checking out a PR head branch.""" + + name: ClassVar[str] = "checkout_pr" + description: ClassVar[str] = "Checkout out a PR head branch" + args_schema: ClassVar[type[BaseModel]] = GithubCheckoutPRInput + codebase: Codebase = Field(exclude=True) + + def __init__(self, codebase: Codebase) -> None: + super().__init__(codebase=codebase) + + def _run(self, pr_number: int) -> str: + result = checkout_pr(self.codebase, pr_number) + return result.render() + + class GithubCreatePRCommentInput(BaseModel): """Input for creating a PR comment""" diff --git a/src/codegen/extensions/tools/github/checkout_pr.py b/src/codegen/extensions/tools/github/checkout_pr.py new file mode 100644 index 000000000..1b4b6d769 --- /dev/null +++ b/src/codegen/extensions/tools/github/checkout_pr.py @@ -0,0 +1,47 @@ +"""Tool for viewing PR contents and modified symbols.""" + +from pydantic import Field + +from codegen.sdk.core.codebase import Codebase + +from ..observation import Observation + + +class CheckoutPRObservation(Observation): + """Response from checking out a PR.""" + + pr_number: int = Field( + description="PR number", + ) + success: bool = Field( + description="Whether the checkout was successful", + default=False, + ) + + +def checkout_pr(codebase: Codebase, pr_number: int) -> CheckoutPRObservation: + """Checkout a PR. + + Args: + codebase: The codebase to operate on + pr_number: Number of the PR to get the contents for + """ + try: + pr = codebase.op.remote_git_repo.get_pull_safe(pr_number) + if not pr: + return CheckoutPRObservation( + pr_number=pr_number, + success=False, + ) + + codebase.checkout(branch=pr.head.ref) + return CheckoutPRObservation( + pr_number=pr_number, + success=True, + ) + except Exception as e: + return CheckoutPRObservation( + pr_number=pr_number, + success=False, + error=f"Failed to checkout PR: {e!s}", + )