Skip to content

Commit 5b0d48e

Browse files
committed
update
1 parent 748c689 commit 5b0d48e

File tree

8 files changed

+179
-43
lines changed

8 files changed

+179
-43
lines changed

patchwork/steps/ResolveIssue/ResolveIssue.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import re
22
from pathlib import Path
3-
from typing import Any
3+
from typing import Any, Optional
44

55
from git import Repo
66
from openai.types.chat import ChatCompletionMessageParam
@@ -20,7 +20,7 @@
2020

2121
class _ResolveIssue(AnalyzeImplementStrategy):
2222
def __init__(self, repo_path: str, llm_client: LlmClient, issue_description: Any, **kwargs):
23-
self.tool_set = Tool.get_tools(repo_path=repo_path)
23+
self.tool_set = Tool.get_tools(path=repo_path)
2424
super().__init__(
2525
llm_client=llm_client,
2626
initial_template_data=dict(issue=issue_description),
@@ -74,10 +74,10 @@ def __init__(self, repo_path: str, llm_client: LlmClient, issue_description: Any
7474
**kwargs,
7575
)
7676

77-
def extract_analysis_message(self, message: ChatCompletionMessageParam) -> dict[str, str]:
77+
def extract_analysis_message(self, message: ChatCompletionMessageParam) -> Optional[dict[str, str]]:
7878
analysis_match = re.search(r"<analysis>(.*?)</analysis>", message.get("content"), re.DOTALL)
7979
if not analysis_match:
80-
return dict()
80+
return None
8181

8282
content = analysis_match.group(1)
8383
sections = dict()
@@ -112,12 +112,11 @@ def __init__(self, inputs):
112112
"\n"
113113
"If you are using an OpenAI API Key, please set `--openai_api_key=<token>`.\n"
114114
)
115-
issue_description = inputs.get("issue_description")
116115

117116
self.multiturn_llm_call = _ResolveIssue(
118117
repo_path=self.base_path,
119118
llm_client=llm_client,
120-
issue_description=issue_description,
119+
issue_description=inputs["issue_description"],
121120
)
122121

123122
def run(self):

patchwork/steps/ResolveIssue/multiturn_strategy/analyze_implement.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def __init__(
3131
implementation_prompt_template: str,
3232
**kwargs,
3333
):
34-
super().__init__(tool_set, **kwargs)
34+
super().__init__(tool_set=tool_set, **kwargs)
3535
self.llm_client = llm_client
3636
self.template_data = initial_template_data
3737
self.analysis_prompt_template = analysis_prompt_template
@@ -54,7 +54,8 @@ def __run_prompt(self, messages: list[ChatCompletionMessageParam]) -> list[ChatC
5454
if is_prompt_safe < 0:
5555
raise ValueError("The subsequent prompt is not supported, due to large size.")
5656
response = self.llm_client.chat_completion(**input_kwargs)
57-
messages.append(response.choices[0].message.to_dict())
57+
new_messages = [choice.message.to_dict() for choice in response.choices]
58+
messages.extend(new_messages)
5859
return messages
5960

6061
def __render_prompt(self, prompt: str) -> str:
@@ -81,12 +82,12 @@ def run_subsequent_prompt(self, messages: list[ChatCompletionMessageParam]) -> l
8182
self.template_data["analysis_results"] = possible_analysis_message
8283
implement_prompt = self.__render_prompt(self.implementation_prompt_template)
8384
messages = [dict(role="user", content=implement_prompt)]
84-
else:
85-
if last_message.get("tool_calls") is not None:
86-
tool_messages = self.execute_tools(last_message)
87-
messages.extend(tool_messages)
88-
messages.append(dict(role="user", content=f"Continue with the {self._stage.name.lower()} stage."))
85+
return self.__run_prompt(messages)
8986

87+
if last_message.get("tool_calls") is not None:
88+
tool_messages = self.execute_tools(last_message)
89+
messages.extend(tool_messages)
90+
messages.append(dict(role="user", content=f"Continue with the {self._stage.name.upper()} stage."))
9091
return self.__run_prompt(messages)
9192

9293
@abstractmethod

patchwork/steps/ResolveIssue/multiturn_strategy/multiturn_strategy.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,38 +25,40 @@ def get_tools_spec(self) -> list[ChatCompletionToolParam]:
2525
return [
2626
dict(
2727
type="function",
28-
function=dict(name=k, **v.json_schema),
28+
function={"name": k, **v.json_schema},
2929
)
3030
for k, v in self.tool_set.items()
3131
]
3232

3333
@abstractmethod
34-
def run_initial_prompt(self) -> list[ChatCompletionMessageParam]:
34+
def run_initial_prompt(self) -> tuple[ChatCompletionMessage, list[ChatCompletionMessageParam]]:
3535
pass
3636

3737
@abstractmethod
38-
def run_subsequent_prompt(self, messages: list[ChatCompletionMessageParam]) -> ChatCompletionMessage:
38+
def run_subsequent_prompt(self, messages: list[ChatCompletionMessageParam]) -> list[ChatCompletionMessageParam]:
3939
pass
4040

4141
def is_stop(self, messages: list[ChatCompletionMessageParam]) -> bool:
4242
return False
4343

44-
def execute_tools(self, chat_completion_message: ChatCompletionMessageParam) -> list[ChatCompletionMessageParam]:
44+
def execute_tools(self, last_message: ChatCompletionMessageParam) -> list[ChatCompletionMessageParam]:
4545
rv = []
46-
for tool_call in chat_completion_message.tool_calls:
47-
tooling_to_use = self.tooling.get(tool_call.function.name, None)
48-
if tooling_to_use is None:
46+
for tool_call in last_message.get("tool_calls", []):
47+
tool_name_to_use = tool_call.get("function", {}).get("name")
48+
tool_to_use = self.tool_set.get(tool_name_to_use, None)
49+
if tool_to_use is None:
4950
logging.info("LLM just used an non-existent tool!")
5051
continue
5152

52-
logging.info(f"Running tool: {tool_call.function.name}")
53+
logging.info(f"Running tool: {tool_name_to_use}")
5354
try:
54-
tool_kwargs = json.loads(tool_call.function.arguments)
55-
tooling_output = tooling_to_use.execute(**tool_kwargs)
55+
tool_arguments = tool_call.get("function", {}).get("arguments", "{}")
56+
tool_kwargs = json.loads(tool_arguments)
57+
tool_output = tool_to_use.execute(**tool_kwargs)
5658
except JSONDecodeError:
57-
tooling_output = "Arguments must be passed through a valid JSON object"
59+
tool_output = "Arguments must be passed through a valid JSON object"
5860

59-
rv.append({"tool_call_id": tool_call.id, "role": "tool", "content": tooling_output})
61+
rv.append({"tool_call_id": tool_call.get("id", ""), "role": "tool", "content": tool_output})
6062

6163
return rv
6264

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
from __future__ import annotations
2+
3+
import os
4+
import subprocess
5+
from pathlib import Path
6+
from typing import Literal
7+
8+
from patchwork.steps.ResolveIssue.tools.tool import Tool
9+
10+
11+
class BashTool(Tool, tool_name="bash"):
12+
def __init__(self, path: str):
13+
super().__init__()
14+
self.path = Path(path)
15+
self.modified_files = []
16+
17+
@property
18+
def json_schema(self) -> dict:
19+
return {
20+
"name": "bash",
21+
"description": """Run commands in a bash shell
22+
23+
* When invoking this tool, the contents of the "command" parameter does NOT need to be XML-escaped.
24+
* You don't have access to the internet via this tool.
25+
* You do have access to a mirror of common linux and python packages via apt and pip.
26+
* State is persistent across command calls and discussions with the user.
27+
* To inspect a particular line range of a file, e.g. lines 10-25, try 'sed -n 10,25p /path/to/the/file'.
28+
* Please avoid commands that may produce a very large amount of output.
29+
* Please run long lived commands in the background, e.g. 'sleep 10 &' or start a server in the background.""",
30+
"input_schema": {
31+
"type": "object",
32+
"properties": {
33+
"command": {
34+
"type": "string",
35+
"description": "The bash command to run."
36+
}
37+
},
38+
"required": ["command"]
39+
}
40+
}
41+
42+
def execute(
43+
self,
44+
command: str | None = None,
45+
*args,
46+
**kwargs,
47+
) -> str:
48+
"""Execute editor commands on files in the repository."""
49+
if command is None:
50+
return f"Error: `command` parameter must be set and cannot be empty"
51+
52+
try:
53+
result = subprocess.run(
54+
command,
55+
shell=True,
56+
cwd=self.path,
57+
capture_output=True,
58+
text=True,
59+
timeout=60 # Add timeout for safety
60+
)
61+
return result.stdout if result.returncode == 0 else f"Error: {result.stderr}"
62+
except subprocess.TimeoutExpired:
63+
return "Error: Command timed out after 60 seconds"
64+
except Exception as e:
65+
return f"Error: {str(e)}"
66+
67+
@property
68+
def tool_records(self):
69+
return dict(modified_files=[{"path": file} for file in self.modified_files])
70+
71+
def __get_abs_path(self, path: str):
72+
abs_path = (self.repo_path / path.lstrip("/")).resolve()
73+
if not abs_path.is_relative_to(self.repo_path):
74+
raise ValueError(f"Path {path} contains illegal path traversal")
75+
76+
return abs_path
77+
78+
def __view(self, path, view_range):
79+
abs_path = self.__get_abs_path(path)
80+
if not abs_path.exists():
81+
return f"Error: Path {path} does not exist"
82+
83+
if abs_path.is_file():
84+
with open(abs_path, "r") as f:
85+
content = f.read()
86+
87+
if view_range:
88+
lines = content.splitlines()
89+
start, end = view_range
90+
content = "\n".join(lines[start - 1 : end])
91+
return content
92+
elif abs_path.is_dir():
93+
result = []
94+
for root, dirs, files in os.walk(abs_path):
95+
level = root[len(abs_path) :].count(os.sep)
96+
if level <= 2:
97+
for d in dirs:
98+
result.append(d)
99+
for f in files:
100+
result.append(f)
101+
return "\n".join(result)
102+
103+
def __create(self, file_text, path):
104+
abs_path = self.__get_abs_path(path)
105+
if abs_path.exists():
106+
return f"Error: File {path} already exists"
107+
abs_path.parent.mkdir(parents=True, exist_ok=True)
108+
abs_path.write_text(file_text)
109+
return f"File created successfully at: {path}"
110+
111+
def __str_replace(self, new_str, old_str, path):
112+
abs_path = self.__get_abs_path(path)
113+
if not abs_path.exists():
114+
return f"Error: File {path} does not exist"
115+
if not abs_path.is_file():
116+
return f"Error: File {path} is not a file"
117+
content = abs_path.read_text()
118+
occurrences = content.count(old_str)
119+
if occurrences != 1:
120+
return f"Error: Found {occurrences} occurrences of old_str, expected exactly 1"
121+
new_content = content.replace(old_str, new_str)
122+
with open(abs_path, "w") as f:
123+
f.write(new_content)
124+
return "Replacement successful"
125+
126+
def __insert(self, insert_line, new_str, path):
127+
abs_path = self.__get_abs_path(path)
128+
if not abs_path.is_file():
129+
return f"Error: File {path} does not exist or is not a file"
130+
131+
lines = abs_path.read_text().splitlines(keepends=True)
132+
if insert_line is None or insert_line < 1 or insert_line > len(lines):
133+
return f"Error: Invalid insert line {insert_line}"
134+
135+
lines.insert(insert_line, new_str + "\n")
136+
abs_path.write_text("".join(lines))
137+
return "Insert successful"

patchwork/steps/ResolveIssue/tools/code_edit_tools.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from patchwork.steps.ResolveIssue.tools.tool import Tool
88

99

10-
class CodeEditTool(Tool):
11-
def __init__(self, repo_path: Path):
10+
class CodeEditTool(Tool, tool_name="code_edit_tool"):
11+
def __init__(self, path: str):
1212
super().__init__()
13-
self.repo_path = repo_path
13+
self.repo_path = Path(path)
1414
self.modified_files = []
1515

1616
@property

patchwork/steps/ResolveIssue/tools/tool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ class Tool(ABC):
66
__internal_map: dict[str, Type["Tool"]] = dict()
77

88
def __init_subclass__(cls, **kwargs):
9-
cls_name = kwargs.get("name", cls.__name__)
9+
cls_name = kwargs.get("tool_name", cls.__name__)
1010
if cls_name in cls.__internal_map.keys():
1111
raise ValueError(f"Duplicate subclass name for class {cls.__name__}: {cls_name}")
1212
cls.name = cls_name
@@ -24,7 +24,7 @@ def execute(self, *args, **kwargs) -> str:
2424
@staticmethod
2525
def get_tools(**kwargs) -> dict[str, "Tool"]:
2626
rv = dict()
27-
for k, v in kwargs.items():
27+
for k, v in Tool.__internal_map.items():
2828
try:
2929
rv[k] = v(**kwargs)
3030
except Exception as e:

poetry.lock

Lines changed: 9 additions & 11 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,13 @@ requests = "~2.31.0"
4141
chardet = "~5.2.0"
4242
attrs = "~23.2.0"
4343
google-generativeai = "~0.8.1"
44-
anthropic = "~0.34.2"
44+
anthropic = "~0.40.0"
4545
rich = "~13.7.1"
4646
chevron = "~0.14.0"
4747
giturlparse = "~0.12.0"
4848
scikit-learn = "^1.3.2"
4949
json-repair = "~0.30.0"
5050
# pinning transitive dependencies
51-
httpx = "<0.28.0"
5251
tree-sitter = "~0.21.3"
5352
numpy = [
5453
{ version = "^1.26", python = "^3.12", optional = true},

0 commit comments

Comments
 (0)