diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index edf246f..5f91675 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -5,14 +5,17 @@ import importlib.metadata import logging import optparse +from pathlib import PurePosixPath import sys +from typing import Sequence -from .bots import Operation, load_bot +from .bots import load_bot from .common import PROGRAM, Config, UnreachableError, ensure_state_home from .drafter import Drafter from .editor import open_editor from .prompt import TemplatedPrompt from .store import Store +from .toolbox import ToolVisitor _logger = logging.getLogger(__name__) @@ -93,8 +96,24 @@ def callback(_option, _opt, _value, parser) -> None: return parser -def print_operation(op: Operation) -> None: - print(op) +class _ToolPrinter(ToolVisitor): + def on_list_files( + self, _paths: Sequence[PurePosixPath], _reason: str | None + ) -> None: + print("Listing available files...") + + def on_read_file( + self, path: PurePosixPath, _contents: str | None, _reason: str | None + ) -> None: + print(f"Reading {path}...") + + def on_write_file( + self, path: PurePosixPath, _contents: str, _reason: str | None + ) -> None: + print(f"Updated {path}.") + + def on_delete_file(self, path: PurePosixPath, _reason: str | None) -> None: + print(f"Deleted {path}.") def main() -> None: @@ -110,7 +129,6 @@ def main() -> None: drafter = Drafter.create( store=Store.persistent(), path=opts.root, - operation_hook=print_operation, ) command = getattr(opts, "command", "generate") if command == "generate": @@ -133,15 +151,20 @@ def main() -> None: else: prompt = sys.stdin.read() - drafter.generate_draft( - prompt, bot, checkout=opts.checkout, reset=opts.reset + name = drafter.generate_draft( + prompt, + bot, + tool_visitors=[_ToolPrinter()], + checkout=opts.checkout, + reset=opts.reset, ) + print(f"Generated {name}.") elif command == "finalize": name = drafter.finalize_draft(delete=opts.delete) - print(f"Finalized {name}") + print(f"Finalized {name}.") elif command == "revert": name = drafter.revert_draft(delete=opts.delete) - print(f"Reverted {name}") + print(f"Reverted {name}.") else: raise UnreachableError() diff --git a/src/git_draft/bots/__init__.py b/src/git_draft/bots/__init__.py index 85aacfa..9952ffd 100644 --- a/src/git_draft/bots/__init__.py +++ b/src/git_draft/bots/__init__.py @@ -8,7 +8,7 @@ import sys from ..common import BotConfig, reindent -from ..toolbox import Operation, OperationHook, Toolbox +from ..toolbox import Toolbox from .common import Action, Bot, Goal @@ -16,8 +16,6 @@ "Action", "Bot", "Goal", - "Operation", - "OperationHook", "Toolbox", ] diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index f973197..34ceee7 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -1,8 +1,10 @@ from __future__ import annotations import dataclasses +from datetime import datetime import json import logging +from pathlib import PurePosixPath import re import textwrap import time @@ -10,11 +12,11 @@ import git -from .bots import Bot, Goal, OperationHook -from .common import random_id +from .bots import Bot, Goal +from .common import JSONObject, random_id from .prompt import PromptRenderer, TemplatedPrompt from .store import Store, sql -from .toolbox import StagingToolbox +from .toolbox import StagingToolbox, ToolVisitor _logger = logging.getLogger(__name__) @@ -52,37 +54,26 @@ def new_suffix(): class Drafter: """Draft state orchestrator""" - def __init__( - self, store: Store, repo: git.Repo, hook: OperationHook | None = None - ) -> None: + def __init__(self, store: Store, repo: git.Repo) -> None: with store.cursor() as cursor: cursor.executescript(sql("create-tables")) self._store = store self._repo = repo - self._operation_hook = hook @classmethod - def create( - cls, - store: Store, - path: str | None = None, - operation_hook: OperationHook | None = None, - ) -> Drafter: - return cls( - store, - git.Repo(path, search_parent_directories=True), - operation_hook, - ) + def create(cls, store: Store, path: str | None = None) -> Drafter: + return cls(store, git.Repo(path, search_parent_directories=True)) def generate_draft( self, prompt: str | TemplatedPrompt, bot: Bot, + tool_visitors: Sequence[ToolVisitor] | None = None, checkout: bool = False, reset: bool = False, sync: bool = False, timeout: float | None = None, - ) -> None: + ) -> str: if isinstance(prompt, str) and not prompt.strip(): raise ValueError("Empty prompt") if self._repo.is_dirty(working_tree=False): @@ -98,7 +89,9 @@ def generate_draft( branch = self._create_branch(sync) _logger.debug("Created branch %s.", branch) - toolbox = StagingToolbox(self._repo, self._operation_hook) + operation_recorder = _OperationRecorder() + tool_visitors = [operation_recorder] + list(tool_visitors or []) + toolbox = StagingToolbox(self._repo, tool_visitors) if isinstance(prompt, TemplatedPrompt): renderer = PromptRenderer.for_toolbox(toolbox) prompt_contents = renderer.render(prompt) @@ -118,6 +111,7 @@ def generate_draft( goal = Goal(prompt_contents, timeout) action = bot.act(goal, toolbox) end_time = time.perf_counter() + walltime = end_time - start_time toolbox.trim_index() title = action.title @@ -134,7 +128,7 @@ def generate_draft( { "commit_sha": commit.hexsha, "prompt_id": prompt_id, - "walltime": end_time - start_time, + "walltime": walltime, }, ) cursor.executemany( @@ -147,13 +141,14 @@ def generate_draft( "details": json.dumps(o.details), "started_at": o.start, } - for o in toolbox.operations + for o in operation_recorder.operations ], ) - _logger.info("Generated draft.") + _logger.info("Generated draft.") if checkout: self._repo.git.checkout("--", ".") + return str(branch) def finalize_draft(self, delete=False) -> str: return self._exit_draft(revert=False, delete=delete) @@ -243,5 +238,48 @@ def _changed_files(self, spec) -> Sequence[str]: return self._repo.git.diff(spec, name_only=True).splitlines() +class _OperationRecorder(ToolVisitor): + def __init__(self) -> None: + self.operations = list[_Operation]() + + def on_list_files( + self, paths: Sequence[PurePosixPath], reason: str | None + ) -> None: + self._record(reason, "list_files", count=len(paths)) + + def on_read_file( + self, path: PurePosixPath, contents: str | None, reason: str | None + ) -> None: + self._record( + reason, + "read_file", + path=str(path), + size=-1 if contents is None else len(contents), + ) + + def on_write_file( + self, path: PurePosixPath, contents: str, reason: str | None + ) -> None: + self._record(reason, "write_file", path=str(path), size=len(contents)) + + def on_delete_file(self, path: PurePosixPath, reason: str | None) -> None: + self._record(reason, "delete_file", path=str(path)) + + def _record(self, reason: str | None, tool: str, **kwargs) -> None: + self.operations.append( + _Operation( + tool=tool, details=kwargs, reason=reason, start=datetime.now() + ) + ) + + +@dataclasses.dataclass(frozen=True) +class _Operation: + tool: str + details: JSONObject + reason: str | None + start: datetime + + def _default_title(prompt: str) -> str: return textwrap.shorten(prompt, break_on_hyphens=False, width=72) diff --git a/src/git_draft/toolbox.py b/src/git_draft/toolbox.py index 0bc450f..04ad846 100644 --- a/src/git_draft/toolbox.py +++ b/src/git_draft/toolbox.py @@ -1,15 +1,11 @@ from __future__ import annotations -import dataclasses -from datetime import datetime from pathlib import PurePosixPath import tempfile -from typing import Callable, Sequence, override +from typing import Callable, Protocol, Sequence, override import git -from .common import JSONObject - class Toolbox: """File-system intermediary @@ -25,34 +21,29 @@ class Toolbox: # TODO: Support a diff-based edit method. # https://gist.github.com/noporpoise/16e731849eb1231e86d78f9dfeca3abc - def __init__(self, hook: OperationHook | None = None) -> None: - self.operations = list[Operation]() - self._operation_hook = hook + def __init__(self, visitors: Sequence[ToolVisitor] | None = None) -> None: + self._visitors = visitors or [] - def _record_operation( - self, reason: str | None, tool: str, **kwargs - ) -> None: - op = Operation( - tool=tool, details=kwargs, reason=reason, start=datetime.now() - ) - self.operations.append(op) - if self._operation_hook: - self._operation_hook(op) + def _dispatch(self, effect: Callable[[ToolVisitor], None]) -> None: + for visitor in self._visitors: + effect(visitor) def list_files(self, reason: str | None = None) -> Sequence[PurePosixPath]: - self._record_operation(reason, "list_files") - return self._list() + paths = self._list() + self._dispatch(lambda v: v.on_list_files(paths, reason)) + return paths def read_file( self, path: PurePosixPath, reason: str | None = None, ) -> str | None: - self._record_operation(reason, "read_file", path=str(path)) try: - return self._read(path) + contents = self._read(path) except FileNotFoundError: - return None + contents = None + self._dispatch(lambda v: v.on_read_file(path, contents, reason)) + return contents def write_file( self, @@ -60,9 +51,7 @@ def write_file( contents: str, reason: str | None = None, ) -> None: - self._record_operation( - reason, "write_file", path=str(path), size=len(contents) - ) + self._dispatch(lambda v: v.on_write_file(path, contents, reason)) return self._write(path, contents) def delete_file( @@ -70,7 +59,7 @@ def delete_file( path: PurePosixPath, reason: str | None = None, ) -> None: - self._record_operation(reason, "delete_file", path=str(path)) + self._dispatch(lambda v: v.on_delete_file(path, reason)) return self._delete(path) def _list(self) -> Sequence[PurePosixPath]: @@ -86,30 +75,37 @@ def _delete(self, path: PurePosixPath) -> None: raise NotImplementedError() -@dataclasses.dataclass(frozen=True) -class Operation: - tool: str - details: JSONObject - reason: str | None - start: datetime +class ToolVisitor(Protocol): + def on_list_files( + self, paths: Sequence[PurePosixPath], reason: str | None + ) -> None: ... + + def on_read_file( + self, path: PurePosixPath, contents: str | None, reason: str | None + ) -> None: ... + def on_write_file( + self, path: PurePosixPath, contents: str, reason: str | None + ) -> None: ... -type OperationHook = Callable[[Operation], None] + def on_delete_file( + self, path: PurePosixPath, reason: str | None + ) -> None: ... class StagingToolbox(Toolbox): """Git-index backed toolbox All files are directly read from and written to the index. This allows - concurrent editing without interference. + concurrent editing without interference with the working directory. """ def __init__( - self, repo: git.Repo, hook: OperationHook | None = None + self, repo: git.Repo, visitors: Sequence[ToolVisitor] | None = None ) -> None: - super().__init__(hook) + super().__init__(visitors) self._repo = repo - self._written = set[str]() + self._updated = set[str]() @override def _list(self) -> Sequence[PurePosixPath]: @@ -123,7 +119,7 @@ def _read(self, path: PurePosixPath) -> str: @override def _write(self, path: PurePosixPath, contents: str) -> None: - self._written.add(str(path)) + self._updated.add(str(path)) # Update the index without touching the worktree. # https://stackoverflow.com/a/25352119 with tempfile.NamedTemporaryFile(delete_on_close=False) as temp: @@ -135,12 +131,18 @@ def _write(self, path: PurePosixPath, contents: str) -> None: f"{mode},{sha},{path}", add=True, cacheinfo=True ) + @override + def _delete(self, path: PurePosixPath) -> None: + self._updated.add(str(path)) + raise NotImplementedError() # TODO + def trim_index(self) -> None: + """Unstage any files which have not been written to.""" diff = self._repo.git.diff(name_only=True, cached=True) untouched = [ path for path in diff.splitlines() - if path and path not in self._written + if path and path not in self._updated ] if untouched: self._repo.git.reset("--", *untouched) diff --git a/tests/git_draft/bots/common_test.py b/tests/git_draft/bots/common_test.py index 052307a..3652388 100644 --- a/tests/git_draft/bots/common_test.py +++ b/tests/git_draft/bots/common_test.py @@ -1,54 +1,6 @@ -from pathlib import PurePosixPath -import unittest.mock - -import pytest - import git_draft.bots.common as sut -class FakeToolbox(sut.Toolbox): - def _list(self): - return [PurePosixPath("/mock/path")] - - def _read(self, path: PurePosixPath) -> str: - return "file contents" - - def _write(self, path: PurePosixPath, contents: str) -> None: - pass - - def _delete(self, path: PurePosixPath) -> None: - pass - - -class TestToolbox: - @pytest.fixture(autouse=True) - def setup(self) -> None: - self._hook = unittest.mock.MagicMock() - self._toolbox = FakeToolbox(hook=self._hook) - - def test_list_files(self): - result = self._toolbox.list_files() - assert result == [PurePosixPath("/mock/path")] - self._hook.assert_called_once() - assert self._toolbox.operations[0].tool == "list_files" - - def test_read_file(self): - content = self._toolbox.read_file(PurePosixPath("/mock/path")) - assert content == "file contents" - self._hook.assert_called_once() - assert self._toolbox.operations[0].tool == "read_file" - - def test_write_file(self): - self._toolbox.write_file(PurePosixPath("/mock/path"), "new contents") - self._hook.assert_called_once() - assert self._toolbox.operations[0].tool == "write_file" - - def test_delete_file(self): - self._toolbox.delete_file(PurePosixPath("/mock/path")) - self._hook.assert_called_once() - assert self._toolbox.operations[0].tool == "delete_file" - - class FakeBot(sut.Bot): pass