Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 11 additions & 33 deletions src/codegen/git/utils/pr_review.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import TYPE_CHECKING

import requests
from github import Repository
from github.PullRequest import PullRequest
from unidiff import PatchSet
Expand Down Expand Up @@ -39,28 +38,6 @@ def get_file_to_changed_ranges(pull_patch_set: PatchSet) -> dict[str, list]:
return file_to_changed_ranges


def get_pull_patch_set(op: RepoOperator, pull: PullRequestContext) -> PatchSet:
# Get the diff directly from GitHub's API
if not op.remote_git_repo:
msg = "GitHub API client is required to get PR diffs"
raise ValueError(msg)

# Get the diff directly from the PR
diff_url = pull.raw_data.get("diff_url")
if diff_url:
# Fetch the diff content from the URL
response = requests.get(diff_url)
response.raise_for_status()
diff = response.text
else:
# If diff_url not available, get the patch directly
diff = pull.get_patch()

# Parse the diff into a PatchSet
pull_patch_set = PatchSet(diff)
return pull_patch_set


def to_1_indexed(zero_indexed_range: range) -> range:
"""Converts a n-indexed range to n+1-indexed.
Primarily to convert 0-indexed ranges to 1 indexed
Expand Down Expand Up @@ -131,7 +108,7 @@ def __init__(self, op: RepoOperator, codebase: "Codebase", pr: PullRequest):
def modified_file_ranges(self) -> dict[str, list[tuple[int, int]]]:
"""Files and the ranges within that are modified"""
if not self._modified_file_ranges:
pull_patch_set = get_pull_patch_set(op=self._op, pull=self._gh_pr)
pull_patch_set = self.get_pull_patch_set()
self._modified_file_ranges = get_file_to_changed_ranges(pull_patch_set)
return self._modified_file_ranges

Expand Down Expand Up @@ -174,15 +151,16 @@ def get_pr_diff(self) -> str:
raise ValueError(msg)

# Get the diff directly from the PR
diff_url = self._gh_pr.raw_data.get("diff_url")
if diff_url:
# Fetch the diff content from the URL
response = requests.get(diff_url)
response.raise_for_status()
return response.text
else:
# If diff_url not available, get the patch directly
return self._gh_pr.get_patch()
status, _, res = self._op.remote_git_repo.repo._requester.requestJson("GET", self._gh_pr.url, headers={"Accept": "application/vnd.github.v3.diff"})
if status != 200:
msg = f"Failed to get PR diff: {res}"
raise Exception(msg)
return res

def get_pull_patch_set(self) -> PatchSet:
diff = self.get_pr_diff()
pull_patch_set = PatchSet(diff)
return pull_patch_set

def get_commit_sha(self) -> str:
"""Get the commit SHA of the PR"""
Expand Down
Loading