diff --git a/patchwork/common/tools/__init__.py b/patchwork/common/tools/__init__.py index e813cb9d9..1501b3687 100644 --- a/patchwork/common/tools/__init__.py +++ b/patchwork/common/tools/__init__.py @@ -1,9 +1,13 @@ from patchwork.common.tools.bash_tool import BashTool -from patchwork.common.tools.code_edit_tools import CodeEditTool +from patchwork.common.tools.code_edit_tools import CodeEditTool, FileViewTool +from patchwork.common.tools.grep_tool import FindTextTool, FindTool from patchwork.common.tools.tool import Tool __all__ = [ "Tool", "CodeEditTool", "BashTool", + "FileViewTool", + "FindTool", + "FindTextTool", ] diff --git a/patchwork/common/tools/bash_tool.py b/patchwork/common/tools/bash_tool.py index 77e9052ae..8440f179a 100644 --- a/patchwork/common/tools/bash_tool.py +++ b/patchwork/common/tools/bash_tool.py @@ -3,6 +3,8 @@ import subprocess from pathlib import Path +from typing_extensions import Optional + from patchwork.common.tools.tool import Tool @@ -35,9 +37,7 @@ def json_schema(self) -> dict: def execute( self, - command: str | None = None, - *args, - **kwargs, + command: Optional[str] = None, ) -> str: """Execute editor commands on files in the repository.""" if command is None: diff --git a/patchwork/common/tools/code_edit_tools.py b/patchwork/common/tools/code_edit_tools.py index 48c7701b0..9197fb9c5 100644 --- a/patchwork/common/tools/code_edit_tools.py +++ b/patchwork/common/tools/code_edit_tools.py @@ -1,14 +1,94 @@ from __future__ import annotations from pathlib import Path -from typing import Literal + +from typing_extensions import Literal, Optional, Union from patchwork.common.tools.tool import Tool from patchwork.common.utils.utils import detect_newline +class FileViewTool(Tool, tool_name="file_view"): + __TRUNCATION_TOKEN = "" + __VIEW_LIMIT = 3000 + + def __init__(self, path: Union[Path, str]): + super().__init__() + self.repo_path = Path(path).resolve() + + @property + def json_schema(self) -> dict: + return { + "name": "file_view", + "description": f"""\ +Custom tool for viewing files + +* If `path` is a file, `view` displays the result of applying `cat -n` up to {self.__VIEW_LIMIT} characters. If `path` is a directory, `view` lists non-hidden files and directories. +* The output is too lone, it will be truncated and marked with `{self.__TRUNCATION_TOKEN}` +* The working directory is always {self.repo_path} +""", + "input_schema": { + "type": "object", + "properties": { + "path": { + "description": "Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.", + "type": "string", + }, + "view_range": { + "description": "Optional parameter when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.", + "items": {"type": "integer"}, + "type": "array", + }, + }, + "required": ["path"], + }, + } + + def __get_abs_path(self, path: str): + wanted_path = Path(path).resolve() + if wanted_path.is_relative_to(self.repo_path): + return wanted_path + else: + raise ValueError(f"Path {path} contains illegal path traversal") + + def execute(self, path: str, view_range: Optional[list[int]] = None) -> str: + abs_path = self.__get_abs_path(path) + if not abs_path.exists(): + return f"Error: Path {abs_path} does not exist" + + if abs_path.is_file(): + with open(abs_path, "r") as f: + content = f.read() + + if view_range: + lines = content.splitlines() + start, end = view_range + content = "\n".join(lines[start - 1 : end]) + + if len(content) > self.__VIEW_LIMIT: + content = content[: self.__VIEW_LIMIT] + self.__TRUNCATION_TOKEN + return content + elif abs_path.is_dir(): + directories = [] + files = [] + for file in abs_path.iterdir(): + directories.append(file.name) if file.is_dir() else files.append(file.name) + + rv = "" + if len(directories) > 0: + rv += "Directories: \n" + rv += "\n".join(directories) + rv += "\n" + + if len(files) > 0: + rv += "Files: \n" + rv += "\n".join(files) + + return rv + + class CodeEditTool(Tool, tool_name="code_edit_tool"): - def __init__(self, path: Path | str): + def __init__(self, path: Union[Path, str]): super().__init__() self.repo_path = Path(path).resolve() self.modified_files = set() @@ -21,9 +101,7 @@ def json_schema(self) -> dict: Custom editing tool for viewing, creating and editing files * State is persistent across command calls and discussions with the user -* If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep * The `create` command cannot be used if the specified `path` already exists as a file -* If a `command` generates a long output, it will be truncated and marked with `` * The working directory is always {self.repo_path} Notes for using the `str_replace` command: @@ -35,8 +113,8 @@ def json_schema(self) -> dict: "properties": { "command": { "type": "string", - "enum": ["view", "create", "str_replace", "insert"], - "description": "The commands to run. Allowed options are: `view`, `create`, `str_replace`, `insert`.", + "enum": ["create", "str_replace", "insert"], + "description": "The commands to run. Allowed options are: `create`, `str_replace`, `insert`.", }, "file_text": { "description": "Required parameter of `create` command, with the content of the file to be created.", @@ -58,11 +136,6 @@ def json_schema(self) -> dict: "description": "Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.", "type": "string", }, - "view_range": { - "description": "Optional parameter of `view` command when `path` points to a file. If none is given, the full file is shown. If provided, the file will be shown in the indicated line number range, e.g. [11, 12] will show lines 11 and 12. Indexing at 1 to start. Setting `[start_line, -1]` shows all lines from `start_line` to the end of the file.", - "items": {"type": "integer"}, - "type": "array", - }, }, "required": ["command", "path"], }, @@ -70,15 +143,12 @@ def json_schema(self) -> dict: def execute( self, - command: Literal["view", "create", "str_replace", "insert"] | None = None, + command: Optional[Literal["create", "str_replace", "insert"]] = None, file_text: str = "", - insert_line: int | None = None, + insert_line: Optional[int] = None, new_str: str = "", - old_str: str | None = None, - path: str | None = None, - view_range: list[int] | None = None, - *args, - **kwargs, + old_str: Optional[str] = None, + path: Optional[str] = None, ) -> str: """Execute editor commands on files in the repository.""" required_dict = dict(command=command, path=path) @@ -88,9 +158,7 @@ def execute( try: abs_path = self.__get_abs_path(path) - if command == "view": - result = self.__view(abs_path, view_range) - elif command == "create": + if command == "create": result = self.__create(file_text, abs_path) elif command == "str_replace": result = self.__str_replace(new_str, old_str, abs_path) @@ -101,9 +169,8 @@ def execute( except Exception as e: return f"Error: {str(e)}" - if command in {"create", "str_replace", "insert"}: - self.modified_files.update({abs_path}) + self.modified_files.update({abs_path}) return result @property @@ -117,37 +184,6 @@ def __get_abs_path(self, path: str): else: raise ValueError(f"Path {path} contains illegal path traversal") - def __view(self, abs_path: Path, view_range): - if not abs_path.exists(): - return f"Error: Path {abs_path} does not exist" - - if abs_path.is_file(): - with open(abs_path, "r") as f: - content = f.read() - - if view_range: - lines = content.splitlines() - start, end = view_range - content = "\n".join(lines[start - 1 : end]) - return content - elif abs_path.is_dir(): - directories = [] - files = [] - for file in abs_path.iterdir(): - directories.append(file.name) if file.is_dir() else files.append(file.name) - - rv = "" - if len(directories) > 0: - rv += "Directories: \n" - rv += "\n".join(directories) - rv += "\n" - - if len(files) > 0: - rv += "Files: \n" - rv += "\n".join(files) - - return rv - def __create(self, file_text, abs_path): if abs_path.exists(): return f"Error: File {abs_path} already exists" diff --git a/patchwork/common/tools/grep_tool.py b/patchwork/common/tools/grep_tool.py new file mode 100644 index 000000000..ac173875d --- /dev/null +++ b/patchwork/common/tools/grep_tool.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import fnmatch +import itertools +import os +from pathlib import Path + +from typing_extensions import Optional + +from patchwork.common.tools.tool import Tool + + +class FindTool(Tool, tool_name="find_files"): + def __init__(self, path: Path | str, **kwargs): + self.__working_dir = Path(path).resolve() + + @property + def json_schema(self) -> dict: + return { + "name": "find_files", + "description": f"""\ +Tool to find files in {self.__working_dir} using a pattern based on the Unix shell style. +Unix shell style: + * matches everything + ? matches any single character + [seq] matches any character in seq + [!seq] matches any char not in seq + +""", + "input_schema": { + "type": "object", + "properties": { + "pattern": { + "description": """\ +The Unix shell style pattern to match files using. + +Unix shell style: + * matches everything + ? matches any single character + [seq] matches any character in seq + [!seq] matches any char not in seq + +Example: +* '*macs' will match the file '.emacs' +* '*.py' will match all files with the '.py' extension +""", + "type": "string", + }, + "depth": { + "description": "The maximum depth files in directories to look for. Default is 1.", + "type": "integer", + }, + "is_case_sensitive": { + "description": "Whether the pattern should be case-sensitive.", + "type": "boolean", + }, + }, + "required": ["pattern"], + }, + } + + def __is_dot(self, path: Path | str) -> bool: + return any(part.startswith(".") for part in path.relative_to(self.__working_dir).parts) + + def execute(self, pattern: Optional[str] = None, depth: int = 1, is_case_sensitive: bool = False) -> str: + if pattern is None: + raise ValueError("Pattern argument is required!") + + matcher = fnmatch.fnmatch + if is_case_sensitive: + matcher = fnmatch.fnmatchcase + + file_matches = [] + dir_matches = [] + for root, dirs, files in os.walk(self.__working_dir): + root_path = Path(root) + if len(root_path.resolve().relative_to(self.__working_dir).parts) > depth: + continue + + for file in itertools.chain(dirs, files): + file_path = root_path / file + if self.__is_dot(file_path): + continue + + if file_path.is_file(): + list_to_append = file_matches + else: + list_to_append = dir_matches + + if matcher(str(file_path), pattern): + relative_file_path = file_path.relative_to(self.__working_dir) + list_to_append.append(str(relative_file_path)) + + delim = "\n * " + files_part = (delim + delim.join(file_matches)) if len(file_matches) > 0 else "\n No files found" + dirs_part = (delim + delim.join(dir_matches)) if len(dir_matches) > 0 else "\n No directories found" + return f"""\ +Files:{files_part} + +Directories:{dirs_part} + +""" + + +class FindTextTool(Tool, tool_name="find_text"): + __CHAR_LIMIT = 200 + __CHAR_LIMIT_TEXT = "" + + def __init__(self, path: Path | str, **kwargs): + self.__working_dir = Path(path).resolve() + + @property + def json_schema(self) -> dict: + return { + "name": "find_text", + "description": f"""\ +Tool to find text in a file using a pattern based on the Unix shell style. +The current working directory is always {self.__working_dir}. +The path provided should either be absolute or relative to the current working directory. + +This tool will match each line of the file with the provided pattern and prints the line number and the line content. +If the line contains more than {self.__CHAR_LIMIT} characters, the line content will be replaced with {self.__CHAR_LIMIT_TEXT}. +""", + "input_schema": { + "type": "object", + "properties": { + "path": { + "description": "The path to the file to find text in.", + "type": "string", + }, + "pattern": { + "description": """\ +The Unix shell style pattern to match files using. + +Unix shell style: + * matches everything + ? matches any single character + [seq] matches any character in seq + [!seq] matches any char not in seq + +Example: +* '*macs' will match the file '.emacs' +* '*.py' will match all files with the '.py' extension +""", + "type": "string", + }, + "is_case_sensitive": { + "description": "Whether the pattern should be case-sensitive.", + "type": "boolean", + }, + }, + "required": ["path", "pattern"], + }, + } + + def execute( + self, + path: Optional[Path] = None, + pattern: Optional[str] = None, + is_case_sensitive: bool = False, + ) -> str: + if path is None: + raise ValueError("Path argument is required!") + + if pattern is None: + raise ValueError("pattern argument is required!") + + matcher = fnmatch.fnmatch + if is_case_sensitive: + matcher = fnmatch.fnmatchcase + + path = Path(path).resolve() + if not path.is_relative_to(self.__working_dir): + raise ValueError("Path must be relative to working dir") + + matches = [] + with path.open("r") as f: + for i, line in enumerate(f.readlines()): + if not matcher(line, pattern): + continue + + content = f"Line {i + 1}: {line}" + if len(line) > self.__CHAR_LIMIT: + content = f"Line {i + 1}: {self.__CHAR_LIMIT_TEXT}" + + matches.append(content) + return f"Pattern matches found in '{path}':\n" + "\n".join(matches) diff --git a/patchwork/common/tools/tool.py b/patchwork/common/tools/tool.py index f9011c14b..af88dead3 100644 --- a/patchwork/common/tools/tool.py +++ b/patchwork/common/tools/tool.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import Type + +from typing_extensions import Type class Tool(ABC): @@ -36,9 +37,9 @@ def get_tools(**kwargs) -> dict[str, "Tool"]: return rv @staticmethod - def get_description(tooling: "ToolProtocol") -> str: + def get_description(tooling: "Tool") -> str: return tooling.json_schema.get("description", "") @staticmethod - def get_parameters(tooling: "ToolProtocol") -> str: + def get_parameters(tooling: "Tool") -> str: return ", ".join(tooling.json_schema.get("required", [])) diff --git a/pyproject.toml b/pyproject.toml index 02d0823ba..ca242733d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "patchwork-cli" -version = "0.0.99" +version = "0.0.100.dev0" description = "" authors = ["patched.codes"] license = "AGPL"