|
1 | 1 | from typing import TYPE_CHECKING |
2 | 2 |
|
3 | | -import requests |
4 | 3 | from github import Repository |
5 | 4 | from github.PullRequest import PullRequest |
6 | 5 | from unidiff import PatchSet |
@@ -39,28 +38,6 @@ def get_file_to_changed_ranges(pull_patch_set: PatchSet) -> dict[str, list]: |
39 | 38 | return file_to_changed_ranges |
40 | 39 |
|
41 | 40 |
|
42 | | -def get_pull_patch_set(op: RepoOperator, pull: PullRequestContext) -> PatchSet: |
43 | | - # Get the diff directly from GitHub's API |
44 | | - if not op.remote_git_repo: |
45 | | - msg = "GitHub API client is required to get PR diffs" |
46 | | - raise ValueError(msg) |
47 | | - |
48 | | - # Get the diff directly from the PR |
49 | | - diff_url = pull.raw_data.get("diff_url") |
50 | | - if diff_url: |
51 | | - # Fetch the diff content from the URL |
52 | | - response = requests.get(diff_url) |
53 | | - response.raise_for_status() |
54 | | - diff = response.text |
55 | | - else: |
56 | | - # If diff_url not available, get the patch directly |
57 | | - diff = pull.get_patch() |
58 | | - |
59 | | - # Parse the diff into a PatchSet |
60 | | - pull_patch_set = PatchSet(diff) |
61 | | - return pull_patch_set |
62 | | - |
63 | | - |
64 | 41 | def to_1_indexed(zero_indexed_range: range) -> range: |
65 | 42 | """Converts a n-indexed range to n+1-indexed. |
66 | 43 | Primarily to convert 0-indexed ranges to 1 indexed |
@@ -131,7 +108,7 @@ def __init__(self, op: RepoOperator, codebase: "Codebase", pr: PullRequest): |
131 | 108 | def modified_file_ranges(self) -> dict[str, list[tuple[int, int]]]: |
132 | 109 | """Files and the ranges within that are modified""" |
133 | 110 | if not self._modified_file_ranges: |
134 | | - pull_patch_set = get_pull_patch_set(op=self._op, pull=self._gh_pr) |
| 111 | + pull_patch_set = self.get_pull_patch_set() |
135 | 112 | self._modified_file_ranges = get_file_to_changed_ranges(pull_patch_set) |
136 | 113 | return self._modified_file_ranges |
137 | 114 |
|
@@ -174,15 +151,16 @@ def get_pr_diff(self) -> str: |
174 | 151 | raise ValueError(msg) |
175 | 152 |
|
176 | 153 | # Get the diff directly from the PR |
177 | | - diff_url = self._gh_pr.raw_data.get("diff_url") |
178 | | - if diff_url: |
179 | | - # Fetch the diff content from the URL |
180 | | - response = requests.get(diff_url) |
181 | | - response.raise_for_status() |
182 | | - return response.text |
183 | | - else: |
184 | | - # If diff_url not available, get the patch directly |
185 | | - return self._gh_pr.get_patch() |
| 154 | + status, _, res = self._op.remote_git_repo.repo._requester.requestJson("GET", self._gh_pr.url, headers={"Accept": "application/vnd.github.v3.diff"}) |
| 155 | + if status != 200: |
| 156 | + msg = f"Failed to get PR diff: {res}" |
| 157 | + raise Exception(msg) |
| 158 | + return res |
| 159 | + |
| 160 | + def get_pull_patch_set(self) -> PatchSet: |
| 161 | + diff = self.get_pr_diff() |
| 162 | + pull_patch_set = PatchSet(diff) |
| 163 | + return pull_patch_set |
186 | 164 |
|
187 | 165 | def get_commit_sha(self) -> str: |
188 | 166 | """Get the commit SHA of the PR""" |
|
0 commit comments