Skip to content

Commit cbf65cd

Browse files
kodjima33claude
andcommitted
Add merge_pr and list_prs tools
Add standalone merge and list PR functionality so users can ask to merge PRs directly (e.g. "merge PR #5") instead of only being able to merge as part of the code_feature flow. - Add list_pull_requests, get_pull_request, merge_pull_request to GitHubClient - Add list_prs tool: list open/closed/all PRs in a repo - Add merge_pr tool: merge a PR by number with squash/merge/rebase support - Register both tools in the omi-tools manifest Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c2e8b99 commit cbf65cd

File tree

2 files changed

+339
-0
lines changed

2 files changed

+339
-0
lines changed

github_client.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,158 @@ def get_repo_labels_with_details(
370370
print(f"⚠️ Error fetching labels: {e}")
371371
return []
372372

373+
def list_pull_requests(
374+
self,
375+
access_token: str,
376+
repo_full_name: str,
377+
state: str = "open",
378+
per_page: int = 10
379+
) -> List[Dict]:
380+
"""
381+
List pull requests in a repository.
382+
Returns list of PR dicts.
383+
"""
384+
try:
385+
response = requests.get(
386+
f"{self.api_base}/repos/{repo_full_name}/pulls",
387+
headers={
388+
"Authorization": f"Bearer {access_token}",
389+
"Accept": "application/vnd.github.v3+json"
390+
},
391+
params={
392+
"state": state,
393+
"per_page": per_page,
394+
"sort": "created",
395+
"direction": "desc"
396+
}
397+
)
398+
399+
if response.status_code == 200:
400+
pulls = response.json()
401+
return [
402+
{
403+
"number": pr["number"],
404+
"title": pr["title"],
405+
"state": pr["state"],
406+
"user": pr["user"]["login"] if pr.get("user") else None,
407+
"head": pr["head"]["ref"],
408+
"base": pr["base"]["ref"],
409+
"mergeable": pr.get("mergeable"),
410+
"url": pr["html_url"],
411+
"created_at": pr["created_at"],
412+
"draft": pr.get("draft", False),
413+
}
414+
for pr in pulls
415+
]
416+
else:
417+
print(f"Error listing PRs: {response.status_code}")
418+
return []
419+
420+
except Exception as e:
421+
print(f"Error listing PRs: {e}")
422+
return []
423+
424+
def get_pull_request(
425+
self,
426+
access_token: str,
427+
repo_full_name: str,
428+
pr_number: int
429+
) -> Optional[Dict]:
430+
"""
431+
Get details of a specific pull request.
432+
"""
433+
try:
434+
response = requests.get(
435+
f"{self.api_base}/repos/{repo_full_name}/pulls/{pr_number}",
436+
headers={
437+
"Authorization": f"Bearer {access_token}",
438+
"Accept": "application/vnd.github.v3+json"
439+
}
440+
)
441+
442+
if response.status_code == 200:
443+
pr = response.json()
444+
return {
445+
"number": pr["number"],
446+
"title": pr["title"],
447+
"state": pr["state"],
448+
"body": pr.get("body", ""),
449+
"user": pr["user"]["login"] if pr.get("user") else None,
450+
"head": pr["head"]["ref"],
451+
"base": pr["base"]["ref"],
452+
"mergeable": pr.get("mergeable"),
453+
"mergeable_state": pr.get("mergeable_state"),
454+
"merged": pr.get("merged", False),
455+
"url": pr["html_url"],
456+
"created_at": pr["created_at"],
457+
"updated_at": pr["updated_at"],
458+
"draft": pr.get("draft", False),
459+
"labels": [label["name"] for label in pr.get("labels", [])],
460+
"reviewers": [r["login"] for r in pr.get("requested_reviewers", [])],
461+
}
462+
elif response.status_code == 404:
463+
return None
464+
else:
465+
print(f"Error getting PR: {response.status_code}")
466+
return None
467+
468+
except Exception as e:
469+
print(f"Error getting PR: {e}")
470+
return None
471+
472+
def merge_pull_request(
473+
self,
474+
access_token: str,
475+
repo_full_name: str,
476+
pr_number: int,
477+
merge_method: str = "squash"
478+
) -> Dict:
479+
"""
480+
Merge a pull request.
481+
merge_method: 'merge', 'squash', or 'rebase'
482+
Returns dict with success status and message.
483+
"""
484+
try:
485+
response = requests.put(
486+
f"{self.api_base}/repos/{repo_full_name}/pulls/{pr_number}/merge",
487+
headers={
488+
"Authorization": f"Bearer {access_token}",
489+
"Accept": "application/vnd.github.v3+json"
490+
},
491+
json={"merge_method": merge_method}
492+
)
493+
494+
if response.status_code == 200:
495+
data = response.json()
496+
return {
497+
"success": True,
498+
"sha": data.get("sha"),
499+
"message": data.get("message", "Pull request merged")
500+
}
501+
elif response.status_code == 405:
502+
return {
503+
"success": False,
504+
"error": "PR cannot be merged (not mergeable, or merge blocked by branch protection rules)"
505+
}
506+
elif response.status_code == 409:
507+
return {
508+
"success": False,
509+
"error": "Merge conflict — the PR has conflicts that must be resolved first"
510+
}
511+
else:
512+
error_msg = response.json().get("message", response.text)
513+
return {
514+
"success": False,
515+
"error": f"GitHub API error ({response.status_code}): {error_msg}"
516+
}
517+
518+
except Exception as e:
519+
print(f"Error merging PR: {e}")
520+
return {
521+
"success": False,
522+
"error": str(e)
523+
}
524+
373525
def get_repo_permissions(self, access_token: str, repo_full_name: str) -> Optional[Dict]:
374526
"""
375527
Get repository permissions for the authenticated user.

main.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,58 @@ async def get_omi_tools_manifest():
234234
},
235235
"auth_required": True,
236236
"status_message": "Adding comment..."
237+
},
238+
{
239+
"name": "list_prs",
240+
"description": "List pull requests in a GitHub repository. Use this when the user wants to see pull requests, check open PRs, view recent PRs, or find a PR number.",
241+
"endpoint": "/tools/list_prs",
242+
"method": "POST",
243+
"parameters": {
244+
"properties": {
245+
"repo": {
246+
"type": "string",
247+
"description": "Repository to list PRs from (format: 'owner/repo'). If not provided, uses the user's default repository."
248+
},
249+
"state": {
250+
"type": "string",
251+
"enum": ["open", "closed", "all"],
252+
"description": "Filter by PR state. Defaults to 'open'."
253+
},
254+
"limit": {
255+
"type": "integer",
256+
"description": "Maximum number of PRs to return (default: 10, max: 50)"
257+
}
258+
},
259+
"required": []
260+
},
261+
"auth_required": True,
262+
"status_message": "Getting pull requests..."
263+
},
264+
{
265+
"name": "merge_pr",
266+
"description": "Merge a pull request in a GitHub repository. Use this when the user asks to merge a PR, accept a PR, merge changes, or apply a pull request. Requires the PR number.",
267+
"endpoint": "/tools/merge_pr",
268+
"method": "POST",
269+
"parameters": {
270+
"properties": {
271+
"pr_number": {
272+
"type": "integer",
273+
"description": "The pull request number to merge. Required."
274+
},
275+
"repo": {
276+
"type": "string",
277+
"description": "Repository the PR is in (format: 'owner/repo'). If not provided, uses the user's default repository."
278+
},
279+
"merge_method": {
280+
"type": "string",
281+
"enum": ["squash", "merge", "rebase"],
282+
"description": "How to merge the PR. Defaults to 'squash'. Use 'merge' for a merge commit, 'rebase' for rebasing."
283+
}
284+
},
285+
"required": ["pr_number"]
286+
},
287+
"auth_required": True,
288+
"status_message": "Merging pull request..."
237289
}
238290
]
239291
}
@@ -588,6 +640,141 @@ async def tool_add_comment(request: Request):
588640
return ChatToolResponse(error=f"Failed to add comment: {str(e)}")
589641

590642

643+
@app.post("/tools/list_prs", tags=["chat_tools"], response_model=ChatToolResponse)
644+
async def tool_list_prs(request: Request):
645+
"""
646+
List pull requests in a GitHub repository.
647+
"""
648+
try:
649+
body = await request.json()
650+
uid = body.get("uid")
651+
repo = body.get("repo")
652+
state = body.get("state", "open")
653+
limit = min(body.get("limit", 10), 50)
654+
655+
if not uid:
656+
return ChatToolResponse(error="User ID is required")
657+
658+
user = SimpleUserStorage.get_user(uid)
659+
if not user or not user.get("access_token"):
660+
return ChatToolResponse(
661+
error="Please connect your GitHub account first in the app settings."
662+
)
663+
664+
repo_full_name, error = get_repo_for_request(user, repo)
665+
if error:
666+
return ChatToolResponse(error=error)
667+
668+
prs = github_client.list_pull_requests(
669+
access_token=user["access_token"],
670+
repo_full_name=repo_full_name,
671+
state=state,
672+
per_page=limit
673+
)
674+
675+
if not prs:
676+
return ChatToolResponse(result=f"No {state} pull requests found in {repo_full_name}.")
677+
678+
result_parts = [f"**{state.title()} Pull Requests in {repo_full_name} ({len(prs)})**", ""]
679+
for pr in prs:
680+
draft_str = " (Draft)" if pr.get("draft") else ""
681+
result_parts.append(
682+
f"• **#{pr['number']}** - {pr['title']}{draft_str} — by {pr.get('user', 'unknown')} ({pr['head']}{pr['base']})"
683+
)
684+
685+
return ChatToolResponse(result="\n".join(result_parts))
686+
687+
except Exception as e:
688+
log(f"Error listing PRs: {e}")
689+
return ChatToolResponse(error=f"Failed to list pull requests: {str(e)}")
690+
691+
692+
@app.post("/tools/merge_pr", tags=["chat_tools"], response_model=ChatToolResponse)
693+
async def tool_merge_pr(request: Request):
694+
"""
695+
Merge a pull request in a GitHub repository.
696+
"""
697+
try:
698+
body = await request.json()
699+
uid = body.get("uid")
700+
pr_number = body.get("pr_number")
701+
repo = body.get("repo")
702+
merge_method = body.get("merge_method", "squash")
703+
704+
if not uid:
705+
return ChatToolResponse(error="User ID is required")
706+
707+
if not pr_number:
708+
return ChatToolResponse(error="Pull request number is required")
709+
710+
if merge_method not in ("squash", "merge", "rebase"):
711+
return ChatToolResponse(error="merge_method must be 'squash', 'merge', or 'rebase'")
712+
713+
user = SimpleUserStorage.get_user(uid)
714+
if not user or not user.get("access_token"):
715+
return ChatToolResponse(
716+
error="Please connect your GitHub account first in the app settings."
717+
)
718+
719+
repo_full_name, error = get_repo_for_request(user, repo)
720+
if error:
721+
return ChatToolResponse(error=error)
722+
723+
access_token = user["access_token"]
724+
725+
# Check permissions
726+
permissions = github_client.get_repo_permissions(access_token, repo_full_name)
727+
if not permissions or not (permissions.get("push") or permissions.get("admin")):
728+
return ChatToolResponse(
729+
error="You don't have write access to this repository. Cannot merge PRs."
730+
)
731+
732+
# Get PR details first to validate it exists and is open
733+
pr = github_client.get_pull_request(access_token, repo_full_name, int(pr_number))
734+
if not pr:
735+
return ChatToolResponse(error=f"Pull request #{pr_number} not found in {repo_full_name}")
736+
737+
if pr.get("merged"):
738+
return ChatToolResponse(
739+
result=f"Pull request **#{pr_number}** ({pr['title']}) is already merged."
740+
)
741+
742+
if pr["state"] != "open":
743+
return ChatToolResponse(
744+
error=f"Pull request **#{pr_number}** is {pr['state']}. Only open PRs can be merged."
745+
)
746+
747+
if pr.get("draft"):
748+
return ChatToolResponse(
749+
error=f"Pull request **#{pr_number}** is a draft. Please mark it as ready for review before merging."
750+
)
751+
752+
# Merge the PR
753+
log(f"Merging PR #{pr_number} in {repo_full_name} using {merge_method}...")
754+
result = github_client.merge_pull_request(
755+
access_token=access_token,
756+
repo_full_name=repo_full_name,
757+
pr_number=int(pr_number),
758+
merge_method=merge_method
759+
)
760+
761+
if result.get("success"):
762+
return ChatToolResponse(
763+
result=f"**PR #{pr_number} Merged!**\n\n"
764+
f"**{pr['title']}**\n"
765+
f"Merged `{pr['head']}` into `{pr['base']}` ({merge_method})\n"
766+
f"URL: {pr['url']}"
767+
)
768+
else:
769+
return ChatToolResponse(
770+
error=f"Failed to merge PR #{pr_number}: {result.get('error', 'Unknown error')}"
771+
)
772+
773+
except Exception as e:
774+
log(f"Error merging PR: {e}")
775+
return ChatToolResponse(error=f"Failed to merge pull request: {str(e)}")
776+
777+
591778
# ============================================
592779
# OAuth & Setup Endpoints
593780
# ============================================

0 commit comments

Comments
 (0)