From 270d2bf3315218cb735138e3c65b73fc76ebb870 Mon Sep 17 00:00:00 2001 From: TIANYOU CHEN <42710806+CTY-git@users.noreply.github.com> Date: Tue, 10 Dec 2024 14:42:44 +0800 Subject: [PATCH] Better line-ending handling and better prompts to direct the llm to the right path immediately --- patchwork/common/tools/bash_tool.py | 7 ++- patchwork/common/tools/code_edit_tools.py | 72 ++++++++++++----------- patchwork/common/utils/utils.py | 20 ++++++- patchwork/steps/FixIssue/FixIssue.py | 19 +++--- pyproject.toml | 2 +- 5 files changed, 72 insertions(+), 48 deletions(-) diff --git a/patchwork/common/tools/bash_tool.py b/patchwork/common/tools/bash_tool.py index cd6c51fa9..77e9052ae 100644 --- a/patchwork/common/tools/bash_tool.py +++ b/patchwork/common/tools/bash_tool.py @@ -7,7 +7,7 @@ class BashTool(Tool, tool_name="bash"): - def __init__(self, path: str): + def __init__(self, path: Path): super().__init__() self.path = Path(path) self.modified_files = [] @@ -16,7 +16,7 @@ def __init__(self, path: str): def json_schema(self) -> dict: return { "name": "bash", - "description": """Run commands in a bash shell + "description": f"""Run commands in a bash shell * When invoking this tool, the contents of the "command" parameter does NOT need to be XML-escaped. * You don't have access to the internet via this tool. @@ -24,7 +24,8 @@ def json_schema(self) -> dict: * State is persistent across command calls and discussions with the user. * To inspect a particular line range of a file, e.g. lines 10-25, try 'sed -n 10,25p /path/to/the/file'. * Please avoid commands that may produce a very large amount of output. -* Please run long lived commands in the background, e.g. 'sleep 10 &' or start a server in the background.""", +* Please run long lived commands in the background, e.g. 'sleep 10 &' or start a server in the background. +* The working directory is always {self.path}""", "input_schema": { "type": "object", "properties": {"command": {"type": "string", "description": "The bash command to run."}}, diff --git a/patchwork/common/tools/code_edit_tools.py b/patchwork/common/tools/code_edit_tools.py index fb44ea115..387226c9d 100644 --- a/patchwork/common/tools/code_edit_tools.py +++ b/patchwork/common/tools/code_edit_tools.py @@ -5,25 +5,28 @@ from typing import Literal from patchwork.common.tools.tool import Tool +from patchwork.common.utils.utils import detect_newline class CodeEditTool(Tool, tool_name="code_edit_tool"): - def __init__(self, path: str): + def __init__(self, path: Path): super().__init__() - self.repo_path = Path(path) + self.repo_path = path self.modified_files = set() @property def json_schema(self) -> dict: return { "name": "code_edit_tool", - "description": """Custom editing tool for viewing, creating and editing files + "description": f"""\ +Custom editing tool for viewing, creating and editing files * State is persistent across command calls and discussions with the user * If `path` is a file, `view` displays the result of applying `cat -n`. If `path` is a directory, `view` lists non-hidden files and directories up to 2 levels deep * The `create` command cannot be used if the specified `path` already exists as a file * If a `command` generates a long output, it will be truncated and marked with `` * The `undo_edit` command will revert the last edit made to the file at `path` +* The working directory is always {self.repo_path} Notes for using the `str_replace` command: * The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces! @@ -86,40 +89,39 @@ def execute( return f"Error: `{'` and `'.join(missing_required)}` parameters must be set and cannot be empty" try: + abs_path = self.__get_abs_path(path) if command == "view": - result = self.__view(path, view_range) + result = self.__view(abs_path, view_range) elif command == "create": - result = self.__create(file_text, path) + result = self.__create(file_text, abs_path) elif command == "str_replace": - result = self.__str_replace(new_str, old_str, path) + result = self.__str_replace(new_str, old_str, abs_path) elif command == "insert": - result = self.__insert(insert_line, new_str, path) + result = self.__insert(insert_line, new_str, abs_path) else: return f"Error: Unknown action {command}" - - if command in {"create", "str_replace", "insert"}: - self.modified_files.update({path.lstrip("/")}) - - return result - except Exception as e: return f"Error: {str(e)}" + if command in {"create", "str_replace", "insert"}: + self.modified_files.update({abs_path.relative_to(self.repo_path)}) + + return result + @property def tool_records(self): - return dict(modified_files=[{"path": file} for file in self.modified_files]) + return dict(modified_files=[{"path": str(file)} for file in self.modified_files]) def __get_abs_path(self, path: str): - abs_path = (self.repo_path / path.lstrip("/")).resolve() - if not abs_path.is_relative_to(self.repo_path.resolve()): + wanted_path = Path(path).resolve() + if wanted_path.is_relative_to(self.repo_path): + return wanted_path + else: raise ValueError(f"Path {path} contains illegal path traversal") - return abs_path - - def __view(self, path, view_range): - abs_path = self.__get_abs_path(path) + def __view(self, abs_path: Path, view_range): if not abs_path.exists(): - return f"Error: Path {path} does not exist" + return f"Error: Path {abs_path} does not exist" if abs_path.is_file(): with open(abs_path, "r") as f: @@ -141,38 +143,38 @@ def __view(self, path, view_range): result.append(f) return "\n".join(result) - def __create(self, file_text, path): - abs_path = self.__get_abs_path(path) + def __create(self, file_text, abs_path): if abs_path.exists(): - return f"Error: File {path} already exists" + return f"Error: File {abs_path} already exists" abs_path.parent.mkdir(parents=True, exist_ok=True) abs_path.write_text(file_text) - return f"File created successfully at: {path}" + return f"File created successfully at: {abs_path}" - def __str_replace(self, new_str, old_str, path): - abs_path = self.__get_abs_path(path) + def __str_replace(self, new_str, old_str, abs_path): if not abs_path.exists(): - return f"Error: File {path} does not exist" + return f"Error: File {abs_path} does not exist" if not abs_path.is_file(): - return f"Error: File {path} is not a file" + return f"Error: File {abs_path} is not a file" content = abs_path.read_text() occurrences = content.count(old_str) if occurrences != 1: return f"Error: Found {occurrences} occurrences of old_str, expected exactly 1" new_content = content.replace(old_str, new_str) - with open(abs_path, "w") as f: - f.write(new_content) + newline = detect_newline(abs_path) + with abs_path.open("w", newline=newline) as fp: + fp.write(new_content) return "Replacement successful" - def __insert(self, insert_line, new_str, path): - abs_path = self.__get_abs_path(path) + def __insert(self, insert_line, new_str, abs_path): if not abs_path.is_file(): - return f"Error: File {path} does not exist or is not a file" + return f"Error: File {abs_path} does not exist or is not a file" lines = abs_path.read_text().splitlines(keepends=True) if insert_line is None or insert_line < 1 or insert_line > len(lines): return f"Error: Invalid insert line {insert_line}" lines.insert(insert_line, new_str + "\n") - abs_path.write_text("".join(lines)) + newline = detect_newline(abs_path) + with abs_path.open("w", newline=newline) as fp: + fp.write("".join(lines)) return "Insert successful" diff --git a/patchwork/common/utils/utils.py b/patchwork/common/utils/utils.py index 766114e6c..4ddf2ef74 100644 --- a/patchwork/common/utils/utils.py +++ b/patchwork/common/utils/utils.py @@ -9,13 +9,31 @@ import tiktoken from chardet.universaldetector import UniversalDetector from git import Head, Repo -from typing_extensions import Any, Callable +from typing_extensions import Any, Callable, Iterable, Counter from patchwork.common.utils.dependency import chromadb from patchwork.logger import logger from patchwork.managed_files import HOME_FOLDER _CLEANUP_FILES: set[Path] = set() +_NEWLINES = {"\n", "\r\n", "\r"} + +def detect_newline(path: str | Path) -> str | None: + with open(path, "r", newline="") as f: + lines = f.read().splitlines(keepends=True) + if len(lines) < 1: + return None + + counter = Counter(_NEWLINES) + for line in lines: + newline_len = 0 + newline = "\n" + for possible_newline in _NEWLINES: + if line.endswith(possible_newline) and len(possible_newline) > newline_len: + newline_len = len(possible_newline) + newline = possible_newline + counter[newline] += 1 + return counter.most_common(1)[0][0] def _cleanup_files(): diff --git a/patchwork/steps/FixIssue/FixIssue.py b/patchwork/steps/FixIssue/FixIssue.py index cb2df70b3..c7a1dc051 100644 --- a/patchwork/steps/FixIssue/FixIssue.py +++ b/patchwork/steps/FixIssue/FixIssue.py @@ -20,19 +20,21 @@ class _ResolveIssue(AnalyzeImplementStrategy): def __init__(self, repo_path: str, llm_client: LlmClient, issue_description: Any, **kwargs): - self.tool_set = Tool.get_tools(path=repo_path) + path = Path(repo_path).resolve() + self.tool_set = Tool.get_tools(path=path) super().__init__( llm_client=llm_client, initial_template_data=dict(issue=issue_description), - analysis_prompt_template=""" -. + analysis_prompt_template=f"""\ + +{path} -I've uploaded a code repository in the current working directory (not in /tmp/inputs). +I've uploaded a code repository in the current working directory. Consider the following issue: -{{issue}} +{{{{issue}}}} Let's first explore and analyze the repository to understand where the issue is located. @@ -49,15 +51,16 @@ def __init__(self, repo_path: str, llm_client: LlmClient, issue_description: Any The error reproduction script and its output Description of the specific changes needed """, - implementation_prompt_template=""" -. + implementation_prompt_template=f"""\ + +{path} I've uploaded a code repository in the current working directory (not in /tmp/inputs). Based on our previous analysis: -{{analysis_results}} +{{{{analysis_results}}}} Let's implement the necessary changes: diff --git a/pyproject.toml b/pyproject.toml index 740a88437..d72165e45 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "patchwork-cli" -version = "0.0.81" +version = "0.0.82" description = "" authors = ["patched.codes"] license = "AGPL"