Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/codegen/extensions/langchain/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from codegen.extensions.linear.linear_client import LinearClient
from codegen.extensions.tools.bash import run_bash_command
from codegen.extensions.tools.github.checkout_pr import checkout_pr
from codegen.extensions.tools.github.search import search
from codegen.extensions.tools.github.view_pr_checks import view_pr_checks
from codegen.extensions.tools.linear.linear import (
Expand Down Expand Up @@ -605,6 +606,28 @@ def _run(self, pr_id: int) -> str:
return result.render()


class GithubCheckoutPRInput(BaseModel):
"""Input for checkout out a PR head branch."""

pr_number: int = Field(..., description="Number of the PR to checkout")


class GithubCheckoutPRTool(BaseTool):
"""Tool for checking out a PR head branch."""

name: ClassVar[str] = "checkout_pr"
description: ClassVar[str] = "Checkout out a PR head branch"
args_schema: ClassVar[type[BaseModel]] = GithubCheckoutPRInput
codebase: Codebase = Field(exclude=True)

def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def _run(self, pr_number: int) -> str:
result = checkout_pr(self.codebase, pr_number)
return result.render()


class GithubCreatePRCommentInput(BaseModel):
"""Input for creating a PR comment"""

Expand Down
47 changes: 47 additions & 0 deletions src/codegen/extensions/tools/github/checkout_pr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Tool for viewing PR contents and modified symbols."""

from pydantic import Field

from codegen.sdk.core.codebase import Codebase

from ..observation import Observation


class CheckoutPRObservation(Observation):
"""Response from checking out a PR."""

pr_number: int = Field(
description="PR number",
)
success: bool = Field(
description="Whether the checkout was successful",
default=False,
)


def checkout_pr(codebase: Codebase, pr_number: int) -> CheckoutPRObservation:
"""Checkout a PR.

Args:
codebase: The codebase to operate on
pr_number: Number of the PR to get the contents for
"""
try:
pr = codebase.op.remote_git_repo.get_pull_safe(pr_number)
if not pr:
return CheckoutPRObservation(
pr_number=pr_number,
success=False,
)

codebase.checkout(branch=pr.head.ref)
return CheckoutPRObservation(
pr_number=pr_number,
success=True,
)
except Exception as e:
return CheckoutPRObservation(
pr_number=pr_number,
success=False,
error=f"Failed to checkout PR: {e!s}",
)
Loading