diff --git a/README.md b/README.md index 7183ab3..9e0a7a5 100644 --- a/README.md +++ b/README.md @@ -16,9 +16,6 @@ template name. Otherwise an inline prompt. * Only include files that the bot has written in draft commits. * Add `--generate` timeout option. -* Add `Bot.state_folder` class method, returning a path to a folder specific to - the bot implementation (derived from the class' name) for storing arbitrary - data. * Add URL and API key to `openai_bot`. Also add a compatibility version which does not use threads, so that it can be used with tools only. Gemini only supports the latter. diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 34f7f40..4efafb8 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -1,3 +1,5 @@ +"""CLI entry point""" + from __future__ import annotations import importlib.metadata diff --git a/src/git_draft/bots/common.py b/src/git_draft/bots/common.py index 2428044..884b731 100644 --- a/src/git_draft/bots/common.py +++ b/src/git_draft/bots/common.py @@ -2,10 +2,10 @@ import dataclasses from datetime import datetime -from pathlib import PurePosixPath +from pathlib import Path, PurePosixPath from typing import Callable, Sequence -from ..common import JSONObject +from ..common import ensure_state_home, JSONObject class Toolbox: @@ -97,5 +97,12 @@ class Action: class Bot: + @classmethod + def state_folder_path(cls) -> Path: + name = cls.__qualname__ + if cls.__module__: + name = f"{cls.__module__}.{name}" + return ensure_state_home() / "bots" / name + def act(self, prompt: str, toolbox: Toolbox) -> Action: raise NotImplementedError() diff --git a/src/git_draft/common.py b/src/git_draft/common.py index 8202a23..52accf2 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -1,3 +1,5 @@ +"""Miscellaneous utilities""" + from __future__ import annotations import dataclasses diff --git a/src/git_draft/editor.py b/src/git_draft/editor.py index 96fe63c..0cc148c 100644 --- a/src/git_draft/editor.py +++ b/src/git_draft/editor.py @@ -1,3 +1,5 @@ +"""CLI interactive editing utilities""" + import os import shutil import subprocess @@ -24,6 +26,11 @@ def _get_tty_filename(): def open_editor(placeholder="", *, _open_tty=open) -> str: + """Open an editor to edit a file and return its contents + + The method returns once the editor is closed. It respects the `$EDITOR` + environment variable. + """ with tempfile.NamedTemporaryFile(delete_on_close=False) as temp: binpath = _guess_editor_binpath() if not binpath: diff --git a/src/git_draft/prompt.py b/src/git_draft/prompt.py index d523750..c776708 100644 --- a/src/git_draft/prompt.py +++ b/src/git_draft/prompt.py @@ -1,3 +1,5 @@ +"""Prompt templating support""" + import dataclasses import git import jinja2 diff --git a/src/git_draft/store.py b/src/git_draft/store.py index 997bdc4..614c246 100644 --- a/src/git_draft/store.py +++ b/src/git_draft/store.py @@ -1,3 +1,5 @@ +"""Persistent state storage""" + import contextlib from datetime import datetime import functools diff --git a/tests/git_draft/bots/common_test.py b/tests/git_draft/bots/common_test.py index d75cd25..95dc794 100644 --- a/tests/git_draft/bots/common_test.py +++ b/tests/git_draft/bots/common_test.py @@ -46,3 +46,12 @@ def test_delete_file(self): self._toolbox.delete_file(PurePosixPath("/mock/path")) self._hook.assert_called_once() assert self._toolbox.operations[0].tool == "delete_file" + + +class FakeBot(sut.Bot): + pass + + +class TestBot: + def test_state_folder_path(self) -> None: + assert "bots.common_test.FakeBot" in str(FakeBot.state_folder_path()) diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index ffe282c..74fc8b2 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -45,7 +45,7 @@ def test_write_file(self, repo: git.Repo) -> None: assert f.read() == "hi" -class _FakeBot(Bot): +class FakeBot(Bot): def act(self, prompt: str, toolbox: Toolbox) -> Action: toolbox.write_file(PurePosixPath("PROMPT"), prompt) return Action() @@ -72,33 +72,33 @@ def _commits(self) -> Sequence[git.Commit]: return list(self._repo.iter_commits()) def test_generate_draft(self) -> None: - self._drafter.generate_draft("hello", _FakeBot()) + self._drafter.generate_draft("hello", FakeBot()) assert len(self._commits()) == 2 def test_generate_then_discard_draft(self) -> None: - self._drafter.generate_draft("hello", _FakeBot()) + self._drafter.generate_draft("hello", FakeBot()) self._drafter.discard_draft() assert len(self._commits()) == 1 def test_generate_outside_branch(self) -> None: self._repo.git.checkout("--detach") with pytest.raises(RuntimeError): - self._drafter.generate_draft("ok", _FakeBot()) + self._drafter.generate_draft("ok", FakeBot()) def test_generate_empty_prompt(self) -> None: with pytest.raises(ValueError): - self._drafter.generate_draft("", _FakeBot()) + self._drafter.generate_draft("", FakeBot()) def test_generate_dirty_index_no_reset(self) -> None: self._write("log") self._repo.git.add(all=True) with pytest.raises(ValueError): - self._drafter.generate_draft("hi", _FakeBot()) + self._drafter.generate_draft("hi", FakeBot()) def test_generate_dirty_index_reset_sync(self) -> None: self._write("log", "11") self._repo.git.add(all=True) - self._drafter.generate_draft("hi", _FakeBot(), reset=True, sync=True) + self._drafter.generate_draft("hi", FakeBot(), reset=True, sync=True) assert self._read("log") == "11" assert not self._path("PROMPT").exists() self._repo.git.checkout(".") @@ -107,13 +107,13 @@ def test_generate_dirty_index_reset_sync(self) -> None: def test_generate_clean_index_sync(self) -> None: prompt = TemplatedPrompt("add-test", {"symbol": "abc"}) - self._drafter.generate_draft(prompt, _FakeBot(), sync=True) + self._drafter.generate_draft(prompt, FakeBot(), sync=True) self._repo.git.checkout(".") assert "abc" in self._read("PROMPT") assert len(self._commits()) == 2 # init, prompt def test_generate_reuse_branch(self) -> None: - bot = _FakeBot() + bot = FakeBot() self._drafter.generate_draft("prompt1", bot) self._drafter.generate_draft("prompt2", bot) self._repo.git.checkout(".") @@ -121,7 +121,7 @@ def test_generate_reuse_branch(self) -> None: assert len(self._commits()) == 3 # init, prompt, prompt def test_generate_reuse_branch_sync(self) -> None: - bot = _FakeBot() + bot = FakeBot() self._drafter.generate_draft("prompt1", bot) self._drafter.generate_draft("prompt2", bot, sync=True) assert len(self._commits()) == 4 # init, prompt, sync, prompt @@ -132,7 +132,7 @@ def test_discard_outside_draft(self) -> None: def test_discard_after_branch_move(self) -> None: self._write("log", "11") - self._drafter.generate_draft("hi", _FakeBot(), sync=True) + self._drafter.generate_draft("hi", FakeBot(), sync=True) branch = self._repo.active_branch self._repo.git.checkout("main") self._repo.index.commit("advance") @@ -143,7 +143,7 @@ def test_discard_after_branch_move(self) -> None: def test_discard_restores_worktree(self) -> None: self._write("p1.txt", "a1") self._write("p2.txt", "b1") - self._drafter.generate_draft("hello", _FakeBot(), sync=True) + self._drafter.generate_draft("hello", FakeBot(), sync=True) self._write("p1.txt", "a2") self._drafter.discard_draft(delete=True) assert self._read("p1.txt") == "a1" @@ -151,7 +151,7 @@ def test_discard_restores_worktree(self) -> None: def test_finalize_keeps_changes(self) -> None: self._write("p1.txt", "a1") - self._drafter.generate_draft("hello", _FakeBot(), checkout=True) + self._drafter.generate_draft("hello", FakeBot(), checkout=True) self._write("p1.txt", "a2") self._drafter.finalize_draft() assert self._read("p1.txt") == "a2"