Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion patchwork/common/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from patchwork.common.tools.bash_tool import BashTool
from patchwork.common.tools.code_edit_tools import CodeEditTool
from patchwork.common.tools.code_edit_tools import CodeEditTool, FileViewTool
from patchwork.common.tools.grep_tool import FindTextTool, FindTool
from patchwork.common.tools.tool import Tool

__all__ = [
"Tool",
"CodeEditTool",
"BashTool",
"FileViewTool",
"FindTool",
"FindTextTool",
]
6 changes: 3 additions & 3 deletions patchwork/common/tools/bash_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import subprocess
from pathlib import Path

from typing_extensions import Optional

from patchwork.common.tools.tool import Tool


Expand Down Expand Up @@ -35,9 +37,7 @@ def json_schema(self) -> dict:

def execute(
self,
command: str | None = None,
*args,
**kwargs,
command: Optional[str] = None,
) -> str:
"""Execute editor commands on files in the repository."""
if command is None:
Expand Down
144 changes: 90 additions & 54 deletions patchwork/common/tools/code_edit_tools.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,94 @@
from __future__ import annotations

from pathlib import Path
from typing import Literal

from typing_extensions import Literal, Optional, Union

from patchwork.common.tools.tool import Tool
from patchwork.common.utils.utils import detect_newline


class FileViewTool(Tool, tool_name="file_view"):
__TRUNCATION_TOKEN = "<TRUNCATED>"
__VIEW_LIMIT = 3000

def __init__(self, path: Union[Path, str]):
super().__init__()
self.repo_path = Path(path).resolve()

@property
def json_schema(self) -> dict:
return {
"name": "file_view",
"description": f"""\
Custom tool for viewing files

* If `path` is a file, `view` displays the result of applying `cat -n` up to {self.__VIEW_LIMIT} characters. If `path` is a directory, `view` lists non-hidden files and directories.
* The output is too lone, it will be truncated and marked with `{self.__TRUNCATION_TOKEN}`
* The working directory is always {self.repo_path}
""",
"input_schema": {
"type": "object",
"properties": {
"path": {
"description": "Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.",
"type": "string",
},
"view_range": {
"description": "Optional parameter when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.",
"items": {"type": "integer"},
"type": "array",
},
},
"required": ["path"],
},
}

def __get_abs_path(self, path: str):
wanted_path = Path(path).resolve()
if wanted_path.is_relative_to(self.repo_path):
return wanted_path
else:
raise ValueError(f"Path {path} contains illegal path traversal")

def execute(self, path: str, view_range: Optional[list[int]] = None) -> str:
abs_path = self.__get_abs_path(path)
if not abs_path.exists():
return f"Error: Path {abs_path} does not exist"

if abs_path.is_file():
with open(abs_path, "r") as f:
content = f.read()

if view_range:
lines = content.splitlines()
start, end = view_range
content = "\n".join(lines[start - 1 : end])

if len(content) > self.__VIEW_LIMIT:
content = content[: self.__VIEW_LIMIT] + self.__TRUNCATION_TOKEN
return content
elif abs_path.is_dir():
directories = []
files = []
for file in abs_path.iterdir():
directories.append(file.name) if file.is_dir() else files.append(file.name)

rv = ""
if len(directories) > 0:
rv += "Directories: \n"
rv += "\n".join(directories)
rv += "\n"

if len(files) > 0:
rv += "Files: \n"
rv += "\n".join(files)

return rv


class CodeEditTool(Tool, tool_name="code_edit_tool"):
def __init__(self, path: Path | str):
def __init__(self, path: Union[Path, str]):
super().__init__()
self.repo_path = Path(path).resolve()
self.modified_files = set()
Expand All @@ -21,9 +101,7 @@ def json_schema(self) -> dict:
Custom editing tool for viewing, creating and editing files

* State is persistent across command calls and discussions with the user
* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep
* The `create` command cannot be used if the specified `path` already exists as a file
* If a `command` generates a long output, it will be truncated and marked with `<response clipped>`
* The working directory is always {self.repo_path}

Notes for using the `str_replace` command:
Expand All @@ -35,8 +113,8 @@ def json_schema(self) -> dict:
"properties": {
"command": {
"type": "string",
"enum": ["view", "create", "str_replace", "insert"],
"description": "The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`.",
"enum": ["create", "str_replace", "insert"],
"description": "The commands to run. Allowed options are: `create`, `str_replace`, `insert`.",
},
"file_text": {
"description": "Required parameter of `create` command, with the content of the file to be created.",
Expand All @@ -58,27 +136,19 @@ def json_schema(self) -> dict:
"description": "Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.",
"type": "string",
},
"view_range": {
"description": "Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.",
"items": {"type": "integer"},
"type": "array",
},
},
"required": ["command", "path"],
},
}

def execute(
self,
command: Literal["view", "create", "str_replace", "insert"] | None = None,
command: Optional[Literal["create", "str_replace", "insert"]] = None,
file_text: str = "",
insert_line: int | None = None,
insert_line: Optional[int] = None,
new_str: str = "",
old_str: str | None = None,
path: str | None = None,
view_range: list[int] | None = None,
*args,
**kwargs,
old_str: Optional[str] = None,
path: Optional[str] = None,
) -> str:
"""Execute editor commands on files in the repository."""
required_dict = dict(command=command, path=path)
Expand All @@ -88,9 +158,7 @@ def execute(

try:
abs_path = self.__get_abs_path(path)
if command == "view":
result = self.__view(abs_path, view_range)
elif command == "create":
if command == "create":
result = self.__create(file_text, abs_path)
elif command == "str_replace":
result = self.__str_replace(new_str, old_str, abs_path)
Expand All @@ -101,9 +169,8 @@ def execute(
except Exception as e:
return f"Error: {str(e)}"

if command in {"create", "str_replace", "insert"}:
self.modified_files.update({abs_path})

self.modified_files.update({abs_path})
return result

@property
Expand All @@ -117,37 +184,6 @@ def __get_abs_path(self, path: str):
else:
raise ValueError(f"Path {path} contains illegal path traversal")

def __view(self, abs_path: Path, view_range):
if not abs_path.exists():
return f"Error: Path {abs_path} does not exist"

if abs_path.is_file():
with open(abs_path, "r") as f:
content = f.read()

if view_range:
lines = content.splitlines()
start, end = view_range
content = "\n".join(lines[start - 1 : end])
return content
elif abs_path.is_dir():
directories = []
files = []
for file in abs_path.iterdir():
directories.append(file.name) if file.is_dir() else files.append(file.name)

rv = ""
if len(directories) > 0:
rv += "Directories: \n"
rv += "\n".join(directories)
rv += "\n"

if len(files) > 0:
rv += "Files: \n"
rv += "\n".join(files)

return rv

def __create(self, file_text, abs_path):
if abs_path.exists():
return f"Error: File {abs_path} already exists"
Expand Down
Loading