diff --git a/pyproject.toml b/pyproject.toml index 9af51b0..c0ea674 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,7 @@ ignore = ["E203", "E501", "E704", "W503"] [tool.isort] profile = "black" force_sort_within_sections = true +lines_after_imports = 2 [tool.mypy] disable_error_code = "import-untyped" diff --git a/src/git_draft/__init__.py b/src/git_draft/__init__.py index 6433e36..64b74e5 100644 --- a/src/git_draft/__init__.py +++ b/src/git_draft/__init__.py @@ -2,6 +2,7 @@ from .bots import Action, Bot, Toolbox + __all__ = [ "Action", "Bot", diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index e27c3f8..edf246f 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -15,6 +15,9 @@ from .store import Store +_logger = logging.getLogger(__name__) + + def new_parser() -> optparse.OptionParser: parser = optparse.OptionParser( prog=PROGRAM, @@ -134,12 +137,19 @@ def main() -> None: prompt, bot, checkout=opts.checkout, reset=opts.reset ) elif command == "finalize": - drafter.finalize_draft(delete=opts.delete) + name = drafter.finalize_draft(delete=opts.delete) + print(f"Finalized {name}") elif command == "revert": - drafter.revert_draft(delete=opts.delete) + name = drafter.revert_draft(delete=opts.delete) + print(f"Reverted {name}") else: raise UnreachableError() if __name__ == "__main__": - main() + try: + main() + except Exception as err: + _logger.exception("Program failed.") + print(f"Error: {err}", file=sys.stderr) + sys.exit(1) diff --git a/src/git_draft/bots/__init__.py b/src/git_draft/bots/__init__.py index 86a2cd6..85aacfa 100644 --- a/src/git_draft/bots/__init__.py +++ b/src/git_draft/bots/__init__.py @@ -8,7 +8,9 @@ import sys from ..common import BotConfig, reindent -from .common import Action, Bot, Goal, Operation, OperationHook, Toolbox +from ..toolbox import Operation, OperationHook, Toolbox +from .common import Action, Bot, Goal + __all__ = [ "Action", diff --git a/src/git_draft/bots/common.py b/src/git_draft/bots/common.py index 3b95c72..4eb2313 100644 --- a/src/git_draft/bots/common.py +++ b/src/git_draft/bots/common.py @@ -1,94 +1,10 @@ from __future__ import annotations import dataclasses -from datetime import datetime -from pathlib import Path, PurePosixPath -from typing import Callable, Sequence +from pathlib import Path -from ..common import JSONObject, ensure_state_home - - -class Toolbox: - """File-system intermediary - - Note that the toolbox is not thread-safe. Concurrent operations should be - serialized by the caller. - """ - - # TODO: Something similar to https://aider.chat/docs/repomap.html, - # including inferring the most important files, and allowing returning - # signature-only versions. - - # TODO: Support a diff-based edit method. - # https://gist.github.com/noporpoise/16e731849eb1231e86d78f9dfeca3abc - - def __init__(self, hook: OperationHook | None = None) -> None: - self.operations = list[Operation]() - self._operation_hook = hook - - def _record_operation( - self, reason: str | None, tool: str, **kwargs - ) -> None: - op = Operation( - tool=tool, details=kwargs, reason=reason, start=datetime.now() - ) - self.operations.append(op) - if self._operation_hook: - self._operation_hook(op) - - def list_files(self, reason: str | None = None) -> Sequence[PurePosixPath]: - self._record_operation(reason, "list_files") - return self._list() - - def read_file( - self, - path: PurePosixPath, - reason: str | None = None, - ) -> str: - self._record_operation(reason, "read_file", path=str(path)) - return self._read(path) - - def write_file( - self, - path: PurePosixPath, - contents: str, - reason: str | None = None, - ) -> None: - self._record_operation( - reason, "write_file", path=str(path), size=len(contents) - ) - return self._write(path, contents) - - def delete_file( - self, - path: PurePosixPath, - reason: str | None = None, - ) -> None: - self._record_operation(reason, "delete_file", path=str(path)) - return self._delete(path) - - def _list(self) -> Sequence[PurePosixPath]: - raise NotImplementedError() - - def _read(self, path: PurePosixPath) -> str: - raise NotImplementedError() - - def _write(self, path: PurePosixPath, contents: str) -> None: - raise NotImplementedError() - - def _delete(self, path: PurePosixPath) -> None: - raise NotImplementedError() - - -@dataclasses.dataclass(frozen=True) -class Operation: - tool: str - details: JSONObject - reason: str | None - start: datetime - - -type OperationHook = Callable[[Operation], None] +from ..common import ensure_state_home +from ..toolbox import Toolbox @dataclasses.dataclass(frozen=True) diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index 39e9dcf..453aa03 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -23,6 +23,7 @@ from ..common import JSONObject, reindent from .common import Action, Bot, Goal, Toolbox + _logger = logging.getLogger(__name__) @@ -132,7 +133,7 @@ class _ToolHandler[V]: def __init__(self, toolbox: Toolbox) -> None: self._toolbox = toolbox - def _on_read_file(self, path: PurePosixPath, contents: str) -> V: + def _on_read_file(self, path: PurePosixPath, contents: str | None) -> V: raise NotImplementedError() def _on_write_file(self, path: PurePosixPath) -> V: @@ -196,10 +197,10 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action: class _CompletionsToolHandler(_ToolHandler[str | None]): - def _on_read_file(self, path: PurePosixPath, contents: str) -> str: - return ( - f"Here are the contents of `{path}`:\n\n```\n{contents}\n```\n" "" - ) + def _on_read_file(self, path: PurePosixPath, contents: str | None) -> str: + if contents is None: + return f"`{path}` does not exist." + return f"The contents of `{path}` are:\n\n```\n{contents}\n```\n" def _on_write_file(self, path: PurePosixPath) -> None: return None @@ -303,8 +304,10 @@ def __init__(self, toolbox: Toolbox, call_id: str) -> None: def _wrap(self, output: str) -> _ToolOutput: return _ToolOutput(tool_call_id=self._call_id, output=output) - def _on_read_file(self, path: PurePosixPath, contents: str) -> _ToolOutput: - return self._wrap(contents) + def _on_read_file( + self, path: PurePosixPath, contents: str | None + ) -> _ToolOutput: + return self._wrap(contents or "") def _on_write_file(self, path: PurePosixPath) -> _ToolOutput: return self._wrap("OK") diff --git a/src/git_draft/common.py b/src/git_draft/common.py index ace2b6f..4b251e9 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -14,6 +14,7 @@ import xdg_base_dirs + PROGRAM = "git-draft" diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index b9eb12d..f973197 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -3,19 +3,19 @@ import dataclasses import json import logging -from pathlib import PurePosixPath import re -import tempfile import textwrap import time -from typing import Match, Sequence, override +from typing import Match, Sequence import git -from .bots import Bot, Goal, OperationHook, Toolbox +from .bots import Bot, Goal, OperationHook from .common import random_id from .prompt import PromptRenderer, TemplatedPrompt from .store import Store, sql +from .toolbox import StagingToolbox + _logger = logging.getLogger(__name__) @@ -49,53 +49,6 @@ def new_suffix(): return random_id(9) -class _Toolbox(Toolbox): - """Git-index backed toolbox - - All files are directly read from and written to the index. This allows - concurrent editing without interference. - """ - - def __init__(self, repo: git.Repo, hook: OperationHook | None) -> None: - super().__init__(hook) - self._repo = repo - self._written = set[str]() - - @override - def _list(self) -> Sequence[PurePosixPath]: - # Show staged files. - return self._repo.git.ls_files().splitlines() - - @override - def _read(self, path: PurePosixPath) -> str: - # Read the file from the index. - return self._repo.git.show(f":{path}") - - @override - def _write(self, path: PurePosixPath, contents: str) -> None: - self._written.add(str(path)) - # Update the index without touching the worktree. - # https://stackoverflow.com/a/25352119 - with tempfile.NamedTemporaryFile(delete_on_close=False) as temp: - temp.write(contents.encode("utf8")) - temp.close() - sha = self._repo.git.hash_object("-w", temp.name, path=path) - mode = 644 # TODO: Read from original file if it exists. - self._repo.git.update_index( - f"{mode},{sha},{path}", add=True, cacheinfo=True - ) - - def trim_index(self) -> None: - diff = self._repo.git.diff(name_only=True, cached=True) - untouched = [ - path - for path in diff.splitlines() - if path and path not in self._written - ] - if untouched: - self._repo.git.reset("--", *untouched) - - class Drafter: """Draft state orchestrator""" @@ -139,17 +92,19 @@ def generate_draft( branch = _Branch.active(self._repo) if branch: - _logger.debug("Reusing active branch %s.", branch) self._stage_changes(sync) + _logger.debug("Reusing active branch %s.", branch) else: branch = self._create_branch(sync) _logger.debug("Created branch %s.", branch) + toolbox = StagingToolbox(self._repo, self._operation_hook) if isinstance(prompt, TemplatedPrompt): - renderer = PromptRenderer.for_repo(self._repo) + renderer = PromptRenderer.for_toolbox(toolbox) prompt_contents = renderer.render(prompt) else: prompt_contents = prompt + with self._store.cursor() as cursor: [(prompt_id,)] = cursor.execute( sql("add-prompt"), @@ -161,7 +116,6 @@ def generate_draft( start_time = time.perf_counter() goal = Goal(prompt_contents, timeout) - toolbox = _Toolbox(self._repo, self._operation_hook) action = bot.act(goal, toolbox) end_time = time.perf_counter() @@ -201,11 +155,11 @@ def generate_draft( if checkout: self._repo.git.checkout("--", ".") - def finalize_draft(self, delete=False) -> None: - self._exit_draft(revert=False, delete=delete) + def finalize_draft(self, delete=False) -> str: + return self._exit_draft(revert=False, delete=delete) - def revert_draft(self, delete=False) -> None: - self._exit_draft(revert=True, delete=delete) + def revert_draft(self, delete=False) -> str: + return self._exit_draft(revert=True, delete=delete) def _create_branch(self, sync: bool) -> _Branch: if self._repo.head.is_detached: @@ -241,7 +195,7 @@ def _stage_changes(self, sync: bool) -> str | None: ref = self._repo.index.commit("draft! sync") return ref.hexsha - def _exit_draft(self, *, revert: bool, delete: bool) -> None: + def _exit_draft(self, *, revert: bool, delete: bool) -> str: branch = _Branch.active(self._repo) if not branch: raise RuntimeError("Not currently on a draft branch") @@ -268,7 +222,7 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> None: self._repo.git.reset("-N", origin_branch) self._repo.git.checkout(origin_branch) - # Finally, we revert the relevant files if needed. If a sync commit had + # Next, we revert the relevant files if needed. If a sync commit had # been created, we simply revert to it. Otherwise we compute which # files have changed due to draft commits and revert only those. if revert: @@ -283,6 +237,8 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> None: if delete: self._repo.git.branch("-D", branch.name) + return branch.name + def _changed_files(self, spec) -> Sequence[str]: return self._repo.git.diff(spec, name_only=True).splitlines() diff --git a/src/git_draft/editor.py b/src/git_draft/editor.py index 7327b61..0cc148c 100644 --- a/src/git_draft/editor.py +++ b/src/git_draft/editor.py @@ -6,6 +6,7 @@ import sys import tempfile + _default_editors = ["vim", "emacs", "nano"] diff --git a/src/git_draft/prompt.py b/src/git_draft/prompt.py index 1cb4076..05e0949 100644 --- a/src/git_draft/prompt.py +++ b/src/git_draft/prompt.py @@ -3,11 +3,12 @@ import dataclasses from typing import Mapping, Self -import git import jinja2 +from .bots import Toolbox from .common import Config, package_root + _prompt_root = package_root / "prompts" @@ -35,7 +36,7 @@ def __init__(self, env: jinja2.Environment) -> None: self._environment = env @classmethod - def for_repo(cls, repo: git.Repo) -> Self: + def for_toolbox(cls, toolbox: Toolbox) -> Self: env = jinja2.Environment( auto_reload=False, autoescape=False, @@ -46,7 +47,7 @@ def for_repo(cls, repo: git.Repo) -> Self: undefined=jinja2.StrictUndefined, ) env.globals["repo"] = { - "file_paths": repo.git.ls_files().splitlines(), + "file_paths": [str(p) for p in toolbox.list_files()], } return cls(env) diff --git a/src/git_draft/store.py b/src/git_draft/store.py index 0b12e9c..614c246 100644 --- a/src/git_draft/store.py +++ b/src/git_draft/store.py @@ -8,6 +8,7 @@ from .common import ensure_state_home, package_root + sqlite3.register_adapter(datetime, lambda d: d.isoformat()) sqlite3.register_converter( "timestamp", lambda v: datetime.fromisoformat(v.decode()) diff --git a/src/git_draft/toolbox.py b/src/git_draft/toolbox.py new file mode 100644 index 0000000..0bc450f --- /dev/null +++ b/src/git_draft/toolbox.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import dataclasses +from datetime import datetime +from pathlib import PurePosixPath +import tempfile +from typing import Callable, Sequence, override + +import git + +from .common import JSONObject + + +class Toolbox: + """File-system intermediary + + Note that the toolbox is not thread-safe. Concurrent operations should be + serialized by the caller. + """ + + # TODO: Something similar to https://aider.chat/docs/repomap.html, + # including inferring the most important files, and allowing returning + # signature-only versions. + + # TODO: Support a diff-based edit method. + # https://gist.github.com/noporpoise/16e731849eb1231e86d78f9dfeca3abc + + def __init__(self, hook: OperationHook | None = None) -> None: + self.operations = list[Operation]() + self._operation_hook = hook + + def _record_operation( + self, reason: str | None, tool: str, **kwargs + ) -> None: + op = Operation( + tool=tool, details=kwargs, reason=reason, start=datetime.now() + ) + self.operations.append(op) + if self._operation_hook: + self._operation_hook(op) + + def list_files(self, reason: str | None = None) -> Sequence[PurePosixPath]: + self._record_operation(reason, "list_files") + return self._list() + + def read_file( + self, + path: PurePosixPath, + reason: str | None = None, + ) -> str | None: + self._record_operation(reason, "read_file", path=str(path)) + try: + return self._read(path) + except FileNotFoundError: + return None + + def write_file( + self, + path: PurePosixPath, + contents: str, + reason: str | None = None, + ) -> None: + self._record_operation( + reason, "write_file", path=str(path), size=len(contents) + ) + return self._write(path, contents) + + def delete_file( + self, + path: PurePosixPath, + reason: str | None = None, + ) -> None: + self._record_operation(reason, "delete_file", path=str(path)) + return self._delete(path) + + def _list(self) -> Sequence[PurePosixPath]: + raise NotImplementedError() + + def _read(self, path: PurePosixPath) -> str: + raise NotImplementedError() + + def _write(self, path: PurePosixPath, contents: str) -> None: + raise NotImplementedError() + + def _delete(self, path: PurePosixPath) -> None: + raise NotImplementedError() + + +@dataclasses.dataclass(frozen=True) +class Operation: + tool: str + details: JSONObject + reason: str | None + start: datetime + + +type OperationHook = Callable[[Operation], None] + + +class StagingToolbox(Toolbox): + """Git-index backed toolbox + + All files are directly read from and written to the index. This allows + concurrent editing without interference. + """ + + def __init__( + self, repo: git.Repo, hook: OperationHook | None = None + ) -> None: + super().__init__(hook) + self._repo = repo + self._written = set[str]() + + @override + def _list(self) -> Sequence[PurePosixPath]: + # Show staged files. + return self._repo.git.ls_files().splitlines() + + @override + def _read(self, path: PurePosixPath) -> str: + # Read the file from the index. + return self._repo.git.show(f":{path}") + + @override + def _write(self, path: PurePosixPath, contents: str) -> None: + self._written.add(str(path)) + # Update the index without touching the worktree. + # https://stackoverflow.com/a/25352119 + with tempfile.NamedTemporaryFile(delete_on_close=False) as temp: + temp.write(contents.encode("utf8")) + temp.close() + sha = self._repo.git.hash_object("-w", temp.name, path=path) + mode = 644 # TODO: Read from original file if it exists. + self._repo.git.update_index( + f"{mode},{sha},{path}", add=True, cacheinfo=True + ) + + def trim_index(self) -> None: + diff = self._repo.git.diff(name_only=True, cached=True) + untouched = [ + path + for path in diff.splitlines() + if path and path not in self._written + ] + if untouched: + self._repo.git.reset("--", *untouched) diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index fcf432c..c65f703 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -10,42 +10,6 @@ from git_draft.store import Store -class TestToolbox: - @pytest.fixture(autouse=True) - def setup(self, repo: git.Repo) -> None: - self._toolbox = sut._Toolbox(repo, None) - - def test_list_files(self, repo: git.Repo) -> None: - assert self._toolbox.list_files() == [] - names = set(["one.txt", "two.txt"]) - for name in names: - with open(Path(repo.working_dir, name), "w") as f: - f.write("ok") - repo.git.add(all=True) - assert set(self._toolbox.list_files()) == names - - def test_read_file(self, repo: git.Repo) -> None: - with open(Path(repo.working_dir, "one"), "w") as f: - f.write("ok") - - path = PurePosixPath("one") - with pytest.raises(git.GitCommandError): - assert self._toolbox.read_file(path) == "" - - repo.git.add(all=True) - assert self._toolbox.read_file(path) == "ok" - - def test_write_file(self, repo: git.Repo) -> None: - self._toolbox.write_file(PurePosixPath("one"), "hi") - - path = Path(repo.working_dir, "one") - assert not path.exists() - - repo.git.checkout_index(all=True) - with open(path) as f: - assert f.read() == "hi" - - class FakeBot(Bot): def act(self, goal: Goal, toolbox: Toolbox) -> Action: toolbox.write_file(PurePosixPath("PROMPT"), goal.prompt) diff --git a/tests/git_draft/prompt_test.py b/tests/git_draft/prompt_test.py index 0227cbd..e21ef2e 100644 --- a/tests/git_draft/prompt_test.py +++ b/tests/git_draft/prompt_test.py @@ -1,12 +1,14 @@ import pytest import git_draft.prompt as sut +from git_draft.toolbox import StagingToolbox class TestPromptRenderer: @pytest.fixture(autouse=True) def setup(self, repo) -> None: - self._renderer = sut.PromptRenderer.for_repo(repo) + toolbox = StagingToolbox(repo) + self._renderer = sut.PromptRenderer.for_toolbox(toolbox) def test_ok(self) -> None: prompt = sut.TemplatedPrompt.parse("add-test", "symbol=foo") diff --git a/tests/git_draft/toolbox_test.py b/tests/git_draft/toolbox_test.py new file mode 100644 index 0000000..4430263 --- /dev/null +++ b/tests/git_draft/toolbox_test.py @@ -0,0 +1,42 @@ +from pathlib import Path, PurePosixPath + +import git +import pytest + +import git_draft.toolbox as sut + + +class TestStagingToolbox: + @pytest.fixture(autouse=True) + def setup(self, repo: git.Repo) -> None: + self._toolbox = sut.StagingToolbox(repo) + + def test_list_files(self, repo: git.Repo) -> None: + assert self._toolbox.list_files() == [] + names = set(["one.txt", "two.txt"]) + for name in names: + with open(Path(repo.working_dir, name), "w") as f: + f.write("ok") + repo.git.add(all=True) + assert set(self._toolbox.list_files()) == names + + def test_read_file(self, repo: git.Repo) -> None: + with open(Path(repo.working_dir, "one"), "w") as f: + f.write("ok") + + path = PurePosixPath("one") + with pytest.raises(git.GitCommandError): + assert self._toolbox.read_file(path) == "" + + repo.git.add(all=True) + assert self._toolbox.read_file(path) == "ok" + + def test_write_file(self, repo: git.Repo) -> None: + self._toolbox.write_file(PurePosixPath("one"), "hi") + + path = Path(repo.working_dir, "one") + assert not path.exists() + + repo.git.checkout_index(all=True) + with open(path) as f: + assert f.read() == "hi"