diff --git a/src/codegen/git/utils/pr_review.py b/src/codegen/git/utils/pr_review.py index 4ebdc204a..ffb3f52f0 100644 --- a/src/codegen/git/utils/pr_review.py +++ b/src/codegen/git/utils/pr_review.py @@ -1,6 +1,5 @@ from typing import TYPE_CHECKING -import requests from github import Repository from github.PullRequest import PullRequest from unidiff import PatchSet @@ -39,28 +38,6 @@ def get_file_to_changed_ranges(pull_patch_set: PatchSet) -> dict[str, list]: return file_to_changed_ranges -def get_pull_patch_set(op: RepoOperator, pull: PullRequestContext) -> PatchSet: - # Get the diff directly from GitHub's API - if not op.remote_git_repo: - msg = "GitHub API client is required to get PR diffs" - raise ValueError(msg) - - # Get the diff directly from the PR - diff_url = pull.raw_data.get("diff_url") - if diff_url: - # Fetch the diff content from the URL - response = requests.get(diff_url) - response.raise_for_status() - diff = response.text - else: - # If diff_url not available, get the patch directly - diff = pull.get_patch() - - # Parse the diff into a PatchSet - pull_patch_set = PatchSet(diff) - return pull_patch_set - - def to_1_indexed(zero_indexed_range: range) -> range: """Converts a n-indexed range to n+1-indexed. Primarily to convert 0-indexed ranges to 1 indexed @@ -131,7 +108,7 @@ def __init__(self, op: RepoOperator, codebase: "Codebase", pr: PullRequest): def modified_file_ranges(self) -> dict[str, list[tuple[int, int]]]: """Files and the ranges within that are modified""" if not self._modified_file_ranges: - pull_patch_set = get_pull_patch_set(op=self._op, pull=self._gh_pr) + pull_patch_set = self.get_pull_patch_set() self._modified_file_ranges = get_file_to_changed_ranges(pull_patch_set) return self._modified_file_ranges @@ -174,15 +151,16 @@ def get_pr_diff(self) -> str: raise ValueError(msg) # Get the diff directly from the PR - diff_url = self._gh_pr.raw_data.get("diff_url") - if diff_url: - # Fetch the diff content from the URL - response = requests.get(diff_url) - response.raise_for_status() - return response.text - else: - # If diff_url not available, get the patch directly - return self._gh_pr.get_patch() + status, _, res = self._op.remote_git_repo.repo._requester.requestJson("GET", self._gh_pr.url, headers={"Accept": "application/vnd.github.v3.diff"}) + if status != 200: + msg = f"Failed to get PR diff: {res}" + raise Exception(msg) + return res + + def get_pull_patch_set(self) -> PatchSet: + diff = self.get_pr_diff() + pull_patch_set = PatchSet(diff) + return pull_patch_set def get_commit_sha(self) -> str: """Get the commit SHA of the PR"""