Skip to content

Commit 2d28160

Browse files
committed
add new tools
1 parent 4cc8d3e commit 2d28160

File tree

5 files changed

+301
-61
lines changed

5 files changed

+301
-61
lines changed

patchwork/common/tools/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
from patchwork.common.tools.bash_tool import BashTool
2-
from patchwork.common.tools.code_edit_tools import CodeEditTool
2+
from patchwork.common.tools.code_edit_tools import CodeEditTool, FileViewTool
3+
from patchwork.common.tools.grep_tool import FindTextTool, FindTool
34
from patchwork.common.tools.tool import Tool
45

56
__all__ = [
67
"Tool",
78
"CodeEditTool",
89
"BashTool",
10+
"FileViewTool",
11+
"FindTool",
12+
"FindTextTool",
913
]

patchwork/common/tools/bash_tool.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import subprocess
44
from pathlib import Path
55

6+
from typing_extensions import Optional
7+
68
from patchwork.common.tools.tool import Tool
79

810

@@ -35,9 +37,7 @@ def json_schema(self) -> dict:
3537

3638
def execute(
3739
self,
38-
command: str | None = None,
39-
*args,
40-
**kwargs,
40+
command: Optional[str] = None,
4141
) -> str:
4242
"""Execute editor commands on files in the repository."""
4343
if command is None:

patchwork/common/tools/code_edit_tools.py

Lines changed: 90 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,94 @@
11
from __future__ import annotations
22

33
from pathlib import Path
4-
from typing import Literal
4+
5+
from typing_extensions import Literal, Optional, Union
56

67
from patchwork.common.tools.tool import Tool
78
from patchwork.common.utils.utils import detect_newline
89

910

11+
class FileViewTool(Tool, tool_name="file_view"):
12+
__TRUNCATION_TOKEN = "<TRUNCATED>"
13+
__VIEW_LIMIT = 3000
14+
15+
def __init__(self, path: Union[Path, str]):
16+
super().__init__()
17+
self.repo_path = Path(path).resolve()
18+
19+
@property
20+
def json_schema(self) -> dict:
21+
return {
22+
"name": "file_view",
23+
"description": f"""\
24+
Custom tool for viewing files
25+
26+
* If `path` is a file, `view` displays the result of applying `cat -n` up to {self.__VIEW_LIMIT} characters. If `path` is a directory, `view` lists non-hidden files and directories.
27+
* The output is too lone, it will be truncated and marked with `{self.__TRUNCATION_TOKEN}`
28+
* The working directory is always {self.repo_path}
29+
""",
30+
"input_schema": {
31+
"type": "object",
32+
"properties": {
33+
"path": {
34+
"description": "Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.",
35+
"type": "string",
36+
},
37+
"view_range": {
38+
"description": "Optional parameter when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.",
39+
"items": {"type": "integer"},
40+
"type": "array",
41+
},
42+
},
43+
"required": ["path"],
44+
},
45+
}
46+
47+
def __get_abs_path(self, path: str):
48+
wanted_path = Path(path).resolve()
49+
if wanted_path.is_relative_to(self.repo_path):
50+
return wanted_path
51+
else:
52+
raise ValueError(f"Path {path} contains illegal path traversal")
53+
54+
def execute(self, path: str, view_range: Optional[list[int]] = None) -> str:
55+
abs_path = self.__get_abs_path(path)
56+
if not abs_path.exists():
57+
return f"Error: Path {abs_path} does not exist"
58+
59+
if abs_path.is_file():
60+
with open(abs_path, "r") as f:
61+
content = f.read()
62+
63+
if view_range:
64+
lines = content.splitlines()
65+
start, end = view_range
66+
content = "\n".join(lines[start - 1 : end])
67+
68+
if len(content) > self.__VIEW_LIMIT:
69+
content = content[: self.__VIEW_LIMIT] + self.__TRUNCATION_TOKEN
70+
return content
71+
elif abs_path.is_dir():
72+
directories = []
73+
files = []
74+
for file in abs_path.iterdir():
75+
directories.append(file.name) if file.is_dir() else files.append(file.name)
76+
77+
rv = ""
78+
if len(directories) > 0:
79+
rv += "Directories: \n"
80+
rv += "\n".join(directories)
81+
rv += "\n"
82+
83+
if len(files) > 0:
84+
rv += "Files: \n"
85+
rv += "\n".join(files)
86+
87+
return rv
88+
89+
1090
class CodeEditTool(Tool, tool_name="code_edit_tool"):
11-
def __init__(self, path: Path | str):
91+
def __init__(self, path: Union[Path, str]):
1292
super().__init__()
1393
self.repo_path = Path(path).resolve()
1494
self.modified_files = set()
@@ -21,9 +101,7 @@ def json_schema(self) -> dict:
21101
Custom editing tool for viewing, creating and editing files
22102
23103
* State is persistent across command calls and discussions with the user
24-
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
25104
* The `create` command cannot be used if the specified `path` already exists as a file
26-
* If a `command` generates a long output, it will be truncated and marked with `<response clipped>`
27105
* The working directory is always {self.repo_path}
28106
29107
Notes for using the `str_replace` command:
@@ -35,8 +113,8 @@ def json_schema(self) -> dict:
35113
"properties": {
36114
"command": {
37115
"type": "string",
38-
"enum": ["view", "create", "str_replace", "insert"],
39-
"description": "The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`.",
116+
"enum": ["create", "str_replace", "insert"],
117+
"description": "The commands to run. Allowed options are: `create`, `str_replace`, `insert`.",
40118
},
41119
"file_text": {
42120
"description": "Required parameter of `create` command, with the content of the file to be created.",
@@ -58,27 +136,19 @@ def json_schema(self) -> dict:
58136
"description": "Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.",
59137
"type": "string",
60138
},
61-
"view_range": {
62-
"description": "Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.",
63-
"items": {"type": "integer"},
64-
"type": "array",
65-
},
66139
},
67140
"required": ["command", "path"],
68141
},
69142
}
70143

71144
def execute(
72145
self,
73-
command: Literal["view", "create", "str_replace", "insert"] | None = None,
146+
command: Optional[Literal["create", "str_replace", "insert"]] = None,
74147
file_text: str = "",
75-
insert_line: int | None = None,
148+
insert_line: Optional[int] = None,
76149
new_str: str = "",
77-
old_str: str | None = None,
78-
path: str | None = None,
79-
view_range: list[int] | None = None,
80-
*args,
81-
**kwargs,
150+
old_str: Optional[str] = None,
151+
path: Optional[str] = None,
82152
) -> str:
83153
"""Execute editor commands on files in the repository."""
84154
required_dict = dict(command=command, path=path)
@@ -88,9 +158,7 @@ def execute(
88158

89159
try:
90160
abs_path = self.__get_abs_path(path)
91-
if command == "view":
92-
result = self.__view(abs_path, view_range)
93-
elif command == "create":
161+
if command == "create":
94162
result = self.__create(file_text, abs_path)
95163
elif command == "str_replace":
96164
result = self.__str_replace(new_str, old_str, abs_path)
@@ -101,9 +169,8 @@ def execute(
101169
except Exception as e:
102170
return f"Error: {str(e)}"
103171

104-
if command in {"create", "str_replace", "insert"}:
105-
self.modified_files.update({abs_path})
106172

173+
self.modified_files.update({abs_path})
107174
return result
108175

109176
@property
@@ -117,37 +184,6 @@ def __get_abs_path(self, path: str):
117184
else:
118185
raise ValueError(f"Path {path} contains illegal path traversal")
119186

120-
def __view(self, abs_path: Path, view_range):
121-
if not abs_path.exists():
122-
return f"Error: Path {abs_path} does not exist"
123-
124-
if abs_path.is_file():
125-
with open(abs_path, "r") as f:
126-
content = f.read()
127-
128-
if view_range:
129-
lines = content.splitlines()
130-
start, end = view_range
131-
content = "\n".join(lines[start - 1 : end])
132-
return content
133-
elif abs_path.is_dir():
134-
directories = []
135-
files = []
136-
for file in abs_path.iterdir():
137-
directories.append(file.name) if file.is_dir() else files.append(file.name)
138-
139-
rv = ""
140-
if len(directories) > 0:
141-
rv += "Directories: \n"
142-
rv += "\n".join(directories)
143-
rv += "\n"
144-
145-
if len(files) > 0:
146-
rv += "Files: \n"
147-
rv += "\n".join(files)
148-
149-
return rv
150-
151187
def __create(self, file_text, abs_path):
152188
if abs_path.exists():
153189
return f"Error: File {abs_path} already exists"

0 commit comments

Comments
 (0)