diff --git a/README.md b/README.md index 23a692c..a1ea604 100644 --- a/README.md +++ b/README.md @@ -23,5 +23,4 @@ pipx install git-draft[openai] * Mechanism for reporting feedback from a bot, and possibly allowing user to interactively respond. -* Support file rename tool. * Add MCP bot. diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 84b2e8f..0c2102f 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -156,6 +156,14 @@ def on_write_file( def on_delete_file(self, path: PurePosixPath, _reason: str | None) -> None: print(f"Deleted {path!r}.") + def on_rename_file( + self, + src_path: PurePosixPath, + dst_path: PurePosixPath, + _reason: str | None + ) -> None: + print(f"Renamed {src_path!r} to {dst_path!r}.") + def edit(*, path: Path | None = None, text: str | None = None) -> str: if sys.stdin.isatty(): diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index 7e994a5..1b9921d 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -127,6 +127,20 @@ def params(self) -> Sequence[openai.types.chat.ChatCompletionToolParam]: }, }, ), + self._param( + name="rename_file", + description="Rename a file", + inputs={ + "src_path": { + "type": "string", + "description": "Old file path", + }, + "dst_path": { + "type": "string", + "description": "New file path", + }, + }, + ), ] @@ -159,29 +173,39 @@ def _on_write_file(self, path: PurePosixPath) -> V: def _on_delete_file(self, path: PurePosixPath) -> V: raise NotImplementedError() + def _on_rename_file( + self, src_path: PurePosixPath, dst_path: PurePosixPath + ) -> V: + raise NotImplementedError() + def _on_list_files(self, paths: Sequence[PurePosixPath]) -> V: raise NotImplementedError() def handle_function(self, function: Any) -> V: - name = function.name inputs = json.loads(function.arguments) _logger.info("Requested function: %s", function) - if name == "read_file": - path = PurePosixPath(inputs["path"]) - return self._on_read_file(path, self._toolbox.read_file(path)) - elif name == "write_file": - path = PurePosixPath(inputs["path"]) - contents = inputs["contents"] - self._toolbox.write_file(path, contents) - return self._on_write_file(path) - elif name == "delete_file": - path = PurePosixPath(inputs["path"]) - self._toolbox.delete_file(path) - return self._on_delete_file(path) - else: - assert name == "list_files" and not inputs - paths = self._toolbox.list_files() - return self._on_list_files(paths) + match function.name: + case "read_file": + path = PurePosixPath(inputs["path"]) + return self._on_read_file(path, self._toolbox.read_file(path)) + case "write_file": + path = PurePosixPath(inputs["path"]) + contents = inputs["contents"] + self._toolbox.write_file(path, contents) + return self._on_write_file(path) + case "delete_file": + path = PurePosixPath(inputs["path"]) + self._toolbox.delete_file(path) + return self._on_delete_file(path) + case "rename_file": + src_path = PurePosixPath(inputs["src_path"]) + dst_path = PurePosixPath(inputs["dst_path"]) + self._toolbox.rename_file(src_path, dst_path) + return self._on_rename_file(src_path, dst_path) + case _ as name: + assert name == "list_files" and not inputs + paths = self._toolbox.list_files() + return self._on_list_files(paths) class _CompletionsBot(Bot): @@ -234,6 +258,11 @@ def _on_write_file(self, _path: PurePosixPath) -> None: def _on_delete_file(self, _path: PurePosixPath) -> None: return None + def _on_rename_file( + self, _src_path: PurePosixPath, _dst_path: PurePosixPath + ) -> None: + return None + def _on_list_files(self, paths: Sequence[PurePosixPath]) -> str: joined = "\n".join(f"* {p}" for p in paths) return f"Here are the available files: {joined}" @@ -360,5 +389,10 @@ def _on_write_file(self, _path: PurePosixPath) -> _ToolOutput: def _on_delete_file(self, _path: PurePosixPath) -> _ToolOutput: return self._wrap("OK") + def _on_rename_file( + self, _src_path: PurePosixPath, _dst_path: PurePosixPath + ) -> _ToolOutput: + return self._wrap("OK") + def _on_list_files(self, paths: Sequence[PurePosixPath]) -> _ToolOutput: return self._wrap("\n".join(str(p) for p in paths)) diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index c74c2b4..b3b7401 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -388,6 +388,19 @@ def on_write_file( def on_delete_file(self, path: PurePosixPath, reason: str | None) -> None: self._record(reason, "delete_file", path=str(path)) + def on_rename_file( + self, + src_path: PurePosixPath, + dst_path: PurePosixPath, + reason: str | None, + ) -> None: + self._record( + reason, + "rename_file", + src_path=str(src_path), + dst_path=str(dst_path), + ) + def _record(self, reason: str | None, tool: str, **kwargs) -> None: op = _Operation( tool=tool, details=kwargs, reason=reason, start=datetime.now() diff --git a/src/git_draft/toolbox.py b/src/git_draft/toolbox.py index 304f632..1b10f7c 100644 --- a/src/git_draft/toolbox.py +++ b/src/git_draft/toolbox.py @@ -69,6 +69,15 @@ def delete_file( self._dispatch(lambda v: v.on_delete_file(path, reason)) return self._delete(path) + def rename_file( + self, + src_path: PurePosixPath, + dst_path: PurePosixPath, + reason: str | None = None, + ) -> None: + self._dispatch(lambda v: v.on_rename_file(src_path, dst_path, reason)) + self._rename(src_path, dst_path) + def _list(self) -> Sequence[PurePosixPath]: # pragma: no cover raise NotImplementedError() @@ -83,6 +92,14 @@ def _write( def _delete(self, path: PurePosixPath) -> bool: # pragma: no cover raise NotImplementedError() + def _rename( + self, src_path: PurePosixPath, dst_path: PurePosixPath + ) -> None: + # We can provide a default implementation here. + contents = self._read(src_path) + self._write(dst_path, contents) + self._delete(src_path) + class ToolVisitor(Protocol): """Tool usage hook""" @@ -103,6 +120,13 @@ def on_delete_file( self, path: PurePosixPath, reason: str | None ) -> None: ... # pragma: no cover + def on_rename_file( + self, + src_path: PurePosixPath, + dst_path: PurePosixPath, + reason: str | None, + ) -> None: ... # pragma: no cover + class StagingToolbox(Toolbox): """Git-index backed toolbox implementation diff --git a/tests/git_draft/toolbox_test.py b/tests/git_draft/toolbox_test.py index 4430263..83b1aa4 100644 --- a/tests/git_draft/toolbox_test.py +++ b/tests/git_draft/toolbox_test.py @@ -15,13 +15,13 @@ def test_list_files(self, repo: git.Repo) -> None: assert self._toolbox.list_files() == [] names = set(["one.txt", "two.txt"]) for name in names: - with open(Path(repo.working_dir, name), "w") as f: + with Path(repo.working_dir, name).open("w") as f: f.write("ok") repo.git.add(all=True) assert set(self._toolbox.list_files()) == names def test_read_file(self, repo: git.Repo) -> None: - with open(Path(repo.working_dir, "one"), "w") as f: + with Path(repo.working_dir, "one").open("w") as f: f.write("ok") path = PurePosixPath("one") @@ -38,5 +38,14 @@ def test_write_file(self, repo: git.Repo) -> None: assert not path.exists() repo.git.checkout_index(all=True) - with open(path) as f: + with path.open() as f: + assert f.read() == "hi" + + def test_rename_file(self, repo: git.Repo) -> None: + self._toolbox.write_file(PurePosixPath("one"), "hi") + self._toolbox.rename_file(PurePosixPath("one"), PurePosixPath("two")) + + repo.git.checkout_index(all=True) + assert not Path(repo.working_dir, "one").exists() + with Path(repo.working_dir, "two").open() as f: assert f.read() == "hi"