Skip to content

Commit e3aac40

Browse files
chore: CG-10986 checkout pr tool
1 parent 3eb3324 commit e3aac40

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

src/codegen/extensions/langchain/tools.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from codegen.extensions.linear.linear_client import LinearClient
99
from codegen.extensions.tools.bash import run_bash_command
10+
from codegen.extensions.tools.github.checkout_pr import checkout_pr
1011
from codegen.extensions.tools.github.search import search
1112
from codegen.extensions.tools.github.view_pr_checks import view_pr_checks
1213
from codegen.extensions.tools.linear.linear import (
@@ -605,6 +606,28 @@ def _run(self, pr_id: int) -> str:
605606
return result.render()
606607

607608

609+
class GithubCheckoutPRInput(BaseModel):
610+
"""Input for checkout out a PR head branch."""
611+
612+
pr_number: int = Field(..., description="Number of the PR to checkout")
613+
614+
615+
class GithubCheckoutPRTool(BaseTool):
616+
"""Tool for checking out a PR head branch."""
617+
618+
name: ClassVar[str] = "checkout_pr"
619+
description: ClassVar[str] = "Checkout out a PR head branch"
620+
args_schema: ClassVar[type[BaseModel]] = GithubCheckoutPRInput
621+
codebase: Codebase = Field(exclude=True)
622+
623+
def __init__(self, codebase: Codebase) -> None:
624+
super().__init__(codebase=codebase)
625+
626+
def _run(self, pr_number: int) -> str:
627+
result = checkout_pr(self.codebase, pr_number)
628+
return result.render()
629+
630+
608631
class GithubCreatePRCommentInput(BaseModel):
609632
"""Input for creating a PR comment"""
610633

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
"""Tool for viewing PR contents and modified symbols."""
2+
3+
from pydantic import Field
4+
5+
from codegen.sdk.core.codebase import Codebase
6+
7+
from ..observation import Observation
8+
9+
10+
class CheckoutPRObservation(Observation):
11+
"""Response from checking out a PR."""
12+
13+
pr_number: int = Field(
14+
description="PR number",
15+
)
16+
success: bool = Field(
17+
description="Whether the checkout was successful",
18+
default=False,
19+
)
20+
21+
22+
def checkout_pr(codebase: Codebase, pr_number: int) -> CheckoutPRObservation:
23+
"""Checkout a PR.
24+
25+
Args:
26+
codebase: The codebase to operate on
27+
pr_number: Number of the PR to get the contents for
28+
"""
29+
try:
30+
pr = codebase.op.remote_git_repo.get_pull_safe(pr_number)
31+
if not pr:
32+
return CheckoutPRObservation(
33+
pr_number=pr_number,
34+
success=False,
35+
)
36+
37+
codebase.checkout(branch=pr.head.ref)
38+
return CheckoutPRObservation(
39+
pr_number=pr_number,
40+
success=True,
41+
)
42+
except Exception as e:
43+
return CheckoutPRObservation(
44+
pr_number=pr_number,
45+
success=False,
46+
error=f"Failed to checkout PR: {e!s}",
47+
)

0 commit comments

Comments
 (0)