Skip to content

Commit 731cdad

Browse files
authored
PATCHED_API_KEY integration (#12)
* patched-api-key integration * fix gitlab comments issues * fix vulnerability limit * Fix PRReview by skipping some extensions
1 parent 5057603 commit 731cdad

File tree

7 files changed

+363
-299
lines changed

7 files changed

+363
-299
lines changed

patchwork/app.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import importlib
22
import json
3+
import traceback
34
from pathlib import Path
45

56
import click
@@ -73,9 +74,13 @@ def cli(log: str, patchflow: str, opts: list[str], config: str | None, output: s
7374
else:
7475
# treat --key=value as a key-value pair
7576
inputs[key] = value
76-
77-
patchflow_instance = patchflow_class(inputs)
78-
patchflow_instance.run()
77+
try:
78+
patchflow_instance = patchflow_class(inputs)
79+
patchflow_instance.run()
80+
except Exception as e:
81+
logger.debug(traceback.format_exc())
82+
logger.error(f"Error running patchflow {patchflow}: {e}")
83+
exit(1)
7984

8085
data_format_mapping = {
8186
"yaml": yaml.dump,

patchwork/common/client/scm.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,13 @@ def set_pr_description(self, body: str) -> None:
137137
self._mr.save()
138138

139139
def create_comment(
140-
self, path: str, body: str, start_line: int | None = None, end_line: int | None = None
140+
self, body: str, path: str | None = None, start_line: int | None = None, end_line: int | None = None
141141
) -> str | None:
142+
final_body = f"{_COMMENT_MARKER} \n{PullRequestProtocol._apply_pr_template(self, body)}"
143+
if path is None:
144+
note = self._mr.notes.create({"body": final_body})
145+
return f"#note_{note.get_id()}"
146+
142147
while True:
143148
try:
144149
commit = self._mr.commits().next()
@@ -161,7 +166,6 @@ def create_comment(
161166
head_commit = diff.head_commit_sha
162167

163168
try:
164-
final_body = f"{_COMMENT_MARKER} \n{PullRequestProtocol._apply_pr_template(self, body)}"
165169
discussion = self._mr.discussions.create(
166170
{
167171
"body": final_body,
@@ -187,14 +191,19 @@ def create_comment(
187191
return None
188192

189193
def reset_comments(self) -> None:
190-
for discussion in self._mr.discussions.list():
194+
for discussion in self._mr.discussions.list(iterator=True):
191195
for note in discussion.attributes["notes"]:
192-
if note["type"] == "DiffNote" and note["body"].startswith(_COMMENT_MARKER):
196+
if note["body"].startswith(_COMMENT_MARKER):
193197
discussion.notes.delete(note["id"])
194198

195199
def file_diffs(self) -> dict[str, str]:
196-
files = self._mr.diffs.list()
197-
return {file.attributes["new_path"]: file.attributes["diff"] for file in files}
200+
diffs = self._mr.diffs.list()
201+
latest_diff = max(diffs, key=lambda diff: diff.created_at, default=None)
202+
if latest_diff is None:
203+
return {}
204+
205+
files = self._mr.diffs.get(latest_diff.id).diffs
206+
return {file["new_path"]: file["diff"] for file in files}
198207

199208

200209
class GithubPullRequest(PullRequestProtocol):
@@ -336,7 +345,7 @@ def get_pr_by_url(self, url: str) -> PullRequestProtocol | None:
336345
logger.error(f"Invalid PR URL: {url}")
337346
return None
338347

339-
slug = "/".join(url_parts[-4:-2])
348+
slug = "/".join(url_parts[-5:-3])
340349

341350
return self.find_pr_by_id(slug, int(pr_id))
342351

patchwork/steps/CallOpenAI/CallOpenAI.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,12 @@
1010
from patchwork.logger import logger
1111
from patchwork.step import Step
1212

13+
_TOKEN_URL = "https://app.patched.codes/signin"
14+
_DEFAULT_PATCH_URL = "https://patchwork.patched.codes/v1"
15+
1316

1417
class CallOpenAI(Step):
15-
required_keys = {"openai_api_key", "prompt_file"}
18+
required_keys = {"prompt_file"}
1619

1720
def __init__(self, inputs: dict):
1821
logger.info(f"Run started {self.__class__.__name__}")
@@ -27,7 +30,29 @@ def __init__(self, inputs: dict):
2730
self.model_args = {key[len("model_") :]: value for key, value in inputs.items() if key.startswith("model_")}
2831
self.client_args = {key[len("client_") :]: value for key, value in inputs.items() if key.startswith("client_")}
2932

30-
self.openai_api_key = inputs["openai_api_key"]
33+
openai_key = inputs.get("openai_api_key") or os.environ.get("OPENAI_API_KEY")
34+
if openai_key is not None:
35+
self.openai_api_key = openai_key
36+
37+
patched_key = inputs.get("patched_api_key")
38+
if patched_key is not None:
39+
self.openai_api_key = patched_key
40+
self.client_args["base_url"] = _DEFAULT_PATCH_URL
41+
42+
if self.openai_api_key is None:
43+
raise ValueError(
44+
f"Model API key not found.\n"
45+
f'Please login at: "{_TOKEN_URL}",\n'
46+
"Please go to the Integration's tab and generate an API key.\n"
47+
"Please copy the access token that is generated, "
48+
"and add `--patched_api_key=<token>` to the command line.\n"
49+
"\n"
50+
"If you are using a OpenAI API Key, please set `--openai_api_key=<token>`.\n"
51+
)
52+
53+
if not self.openai_api_key:
54+
raise ValueError('Missing required data: "openai_api_key"')
55+
3156
self.prompt_file = Path(inputs["prompt_file"])
3257
if not self.prompt_file.is_file():
3358
raise ValueError(f'Unable to find Prompt file: "{self.prompt_file}"')

patchwork/steps/ExtractCode/ExtractCode.py

Lines changed: 71 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,76 @@ def resolve_artifact_location(
9393
return None
9494

9595

96+
def transform_sarif_results(
97+
sarif_data: dict, base_path: Path, context_length: int, vulnerability_limit: int
98+
) -> dict[tuple[str, int, int, int], list[str]]:
99+
# Process each result in SARIF data
100+
grouped_messages = defaultdict(list)
101+
vulnerability_count = 0
102+
for run_idx, run in enumerate(sarif_data.get("runs", [])):
103+
artifact_locations = [
104+
parse_sarif_location(base_path, artifact["location"]["uri"]) for artifact in run.get("artifacts", [])
105+
]
106+
107+
for result_idx, result in enumerate(run.get("results", [])):
108+
for location_idx, location in enumerate(result.get("locations", [])):
109+
physical_location = location.get("physicalLocation", {})
110+
111+
artifact_location = physical_location.get("artifactLocation", {})
112+
uri = resolve_artifact_location(base_path, artifact_location, artifact_locations)
113+
if uri is None:
114+
logger.warn(
115+
f'Unable to find file for ".runs[{run_idx}].results[{result_idx}].locations[{location_idx}]"'
116+
)
117+
continue
118+
119+
region = physical_location.get("region", {})
120+
start_line = region.get("startLine", 1)
121+
end_line = region.get("endLine", start_line)
122+
start_line = start_line - 1
123+
124+
# Generate file path assuming code is in the current working directory
125+
file_path = str(uri.relative_to(base_path))
126+
127+
# Extract lines from the code file
128+
logger.info(f"Extracting context for {file_path} at {start_line}:{end_line}")
129+
try:
130+
with open_with_chardet(file_path, "r") as file:
131+
src = file.read()
132+
133+
source_lines = src.splitlines(keepends=True)
134+
context_start, context_end = get_source_code_context(
135+
file_path, source_lines, start_line, end_line, context_length
136+
)
137+
138+
source_code_context = None
139+
if context_start is not None and context_end is not None:
140+
source_code_context = "".join(source_lines[context_start:context_end])
141+
142+
except FileNotFoundError:
143+
context_start = None
144+
context_end = None
145+
source_code_context = None
146+
logger.info(f"File not found in the current working directory: {file_path}")
147+
148+
if source_code_context is None:
149+
logger.info(f"No context found for {file_path} at {start_line}:{end_line}")
150+
continue
151+
152+
start = context_start if context_start is not None else start_line
153+
end = context_end if context_end is not None else end_line
154+
155+
grouped_messages[(uri, start, end, source_code_context)].append(
156+
result.get("message", {}).get("text", "")
157+
)
158+
159+
vulnerability_count = vulnerability_count + 1
160+
if 0 < vulnerability_limit <= vulnerability_count:
161+
return grouped_messages
162+
163+
return grouped_messages
164+
165+
96166
class ExtractCode(Step):
97167
required_keys = {"sarif_file_path"}
98168

@@ -112,7 +182,6 @@ def __init__(self, inputs: dict):
112182
self.vulnerability_limit = inputs.get("vulnerability_limit", 10)
113183

114184
# Prepare for data extraction
115-
self.extracted_data = []
116185
self.extracted_code_contexts = []
117186

118187
def run(self) -> dict:
@@ -122,77 +191,8 @@ def run(self) -> dict:
122191

123192
vulnerability_count = 0
124193
base_path = Path.cwd()
125-
# Process each result in SARIF data
126-
grouped_messages = defaultdict(list)
127-
for run_idx, run in enumerate(sarif_data.get("runs", [])):
128-
artifact_locations = [
129-
parse_sarif_location(base_path, artifact["location"]["uri"]) for artifact in run.get("artifacts", [])
130-
]
131-
132-
for result_idx, result in enumerate(run.get("results", [])):
133-
for location_idx, location in enumerate(result.get("locations", [])):
134-
physical_location = location.get("physicalLocation", {})
135-
136-
artifact_location = physical_location.get("artifactLocation", {})
137-
uri = resolve_artifact_location(base_path, artifact_location, artifact_locations)
138-
if uri is None:
139-
logger.warn(
140-
f'Unable to find file for ".runs[{run_idx}].results[{result_idx}].locations[{location_idx}]"'
141-
)
142-
continue
143-
144-
region = physical_location.get("region", {})
145-
start_line = region.get("startLine", 1)
146-
end_line = region.get("endLine", start_line)
147-
start_line = start_line - 1
148-
149-
# Generate file path assuming code is in the current working directory
150-
file_path = str(uri.relative_to(base_path))
151-
152-
# Extract lines from the code file
153-
logger.info(f"Extracting context for {file_path} at {start_line}:{end_line}")
154-
try:
155-
with open_with_chardet(file_path, "r") as file:
156-
src = file.read()
157-
158-
source_lines = src.splitlines(keepends=True)
159-
context_start, context_end = get_source_code_context(
160-
file_path, source_lines, start_line, end_line, self.context_length
161-
)
162-
163-
source_code_context = None
164-
if context_start is not None and context_end is not None:
165-
source_code_context = "".join(source_lines[context_start:context_end])
166-
167-
except FileNotFoundError:
168-
context_start = None
169-
context_end = None
170-
source_code_context = None
171-
logger.info(f"File not found in the current working directory: {file_path}")
172-
173-
if source_code_context is None:
174-
logger.info(f"No context found for {file_path} at {start_line}:{end_line}")
175-
continue
176-
177-
start = context_start if context_start is not None else start_line
178-
end = context_end if context_end is not None else end_line
179-
self.extracted_data.append(
180-
{
181-
"affectedCode": source_code_context,
182-
"startLine": start,
183-
"endLine": end,
184-
"uri": file_path,
185-
"messageText": result.get("message", {}).get("text", ""),
186-
}
187-
)
188-
189-
grouped_messages[(uri, start, end, source_code_context)].append(
190-
result.get("message", {}).get("text", "")
191-
)
192194

193-
vulnerability_count = vulnerability_count + 1
194-
if 0 < self.vulnerability_limit <= vulnerability_count:
195-
break
195+
grouped_messages = transform_sarif_results(sarif_data, base_path, self.context_length, self.vulnerability_limit)
196196

197197
self.extracted_code_contexts = [
198198
{

patchwork/steps/ModifyCode/ModifyCode.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def run(self) -> dict:
7474
code_snippets = load_json_file(self.code_snippets_path)
7575

7676
modified_code_files = []
77-
sorted_list = sorted(zip(code_snippets, self.extracted_responses), key=lambda x: x[0]["startLine"], reverse=True)
77+
sorted_list = sorted(
78+
zip(code_snippets, self.extracted_responses), key=lambda x: x[0]["startLine"], reverse=True
79+
)
7880
for code_snippet, extracted_response in sorted_list:
7981
uri = code_snippet["uri"]
8082
start_line = code_snippet["startLine"]

patchwork/steps/ReadPRDiffs/ReadPRDiffs.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,26 @@
55
from patchwork.logger import logger
66
from patchwork.step import Step
77

8+
_IGNORED_EXTENSIONS = [
9+
".png",
10+
".jpg",
11+
".jpeg",
12+
".gif",
13+
".svg",
14+
".pdf",
15+
".docx",
16+
".xlsx",
17+
".pptx",
18+
".zip",
19+
".tar",
20+
".gz",
21+
".lock",
22+
]
23+
24+
25+
def filter_by_extension(file, extensions):
26+
return any(file.endswith(ext) for ext in extensions)
27+
828

929
class ReadPRDiffs(Step):
1030
required_keys = {"pr_url"}
@@ -30,6 +50,8 @@ def __init__(self, inputs: dict):
3050
def run(self) -> dict:
3151
prompt_values = []
3252
for path, diffs in self.pr.file_diffs().items():
53+
if filter_by_extension(path, _IGNORED_EXTENSIONS):
54+
continue
3355
prompt_values.append(dict(path=path, diff=diffs))
3456

3557
prompt_value_file = tempfile.mktemp(".json")

0 commit comments

Comments
 (0)