Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
template name. Otherwise an inline prompt.
* Only include files that the bot has written in draft commits.
* Add `--generate` timeout option.
* Add `Bot.state_folder` class method, returning a path to a folder specific to
the bot implementation (derived from the class' name) for storing arbitrary
data.
* Add URL and API key to `openai_bot`. Also add a compatibility version which
does not use threads, so that it can be used with tools only. Gemini only
supports the latter.
2 changes: 2 additions & 0 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""CLI entry point"""

from __future__ import annotations

import importlib.metadata
Expand Down
11 changes: 9 additions & 2 deletions src/git_draft/bots/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import dataclasses
from datetime import datetime
from pathlib import PurePosixPath
from pathlib import Path, PurePosixPath
from typing import Callable, Sequence

from ..common import JSONObject
from ..common import ensure_state_home, JSONObject


class Toolbox:
Expand Down Expand Up @@ -97,5 +97,12 @@ class Action:


class Bot:
@classmethod
def state_folder_path(cls) -> Path:
name = cls.__qualname__
if cls.__module__:
name = f"{cls.__module__}.{name}"
return ensure_state_home() / "bots" / name

def act(self, prompt: str, toolbox: Toolbox) -> Action:
raise NotImplementedError()
2 changes: 2 additions & 0 deletions src/git_draft/common.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Miscellaneous utilities"""

from __future__ import annotations

import dataclasses
Expand Down
7 changes: 7 additions & 0 deletions src/git_draft/editor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""CLI interactive editing utilities"""

import os
import shutil
import subprocess
Expand All @@ -24,6 +26,11 @@ def _get_tty_filename():


def open_editor(placeholder="", *, _open_tty=open) -> str:
"""Open an editor to edit a file and return its contents

The method returns once the editor is closed. It respects the `$EDITOR`
environment variable.
"""
with tempfile.NamedTemporaryFile(delete_on_close=False) as temp:
binpath = _guess_editor_binpath()
if not binpath:
Expand Down
2 changes: 2 additions & 0 deletions src/git_draft/prompt.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Prompt templating support"""

import dataclasses
import git
import jinja2
Expand Down
2 changes: 2 additions & 0 deletions src/git_draft/store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Persistent state storage"""

import contextlib
from datetime import datetime
import functools
Expand Down
9 changes: 9 additions & 0 deletions tests/git_draft/bots/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,12 @@ def test_delete_file(self):
self._toolbox.delete_file(PurePosixPath("/mock/path"))
self._hook.assert_called_once()
assert self._toolbox.operations[0].tool == "delete_file"


class FakeBot(sut.Bot):
pass


class TestBot:
def test_state_folder_path(self) -> None:
assert "bots.common_test.FakeBot" in str(FakeBot.state_folder_path())
26 changes: 13 additions & 13 deletions tests/git_draft/drafter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_write_file(self, repo: git.Repo) -> None:
assert f.read() == "hi"


class _FakeBot(Bot):
class FakeBot(Bot):
def act(self, prompt: str, toolbox: Toolbox) -> Action:
toolbox.write_file(PurePosixPath("PROMPT"), prompt)
return Action()
Expand All @@ -72,33 +72,33 @@ def _commits(self) -> Sequence[git.Commit]:
return list(self._repo.iter_commits())

def test_generate_draft(self) -> None:
self._drafter.generate_draft("hello", _FakeBot())
self._drafter.generate_draft("hello", FakeBot())
assert len(self._commits()) == 2

def test_generate_then_discard_draft(self) -> None:
self._drafter.generate_draft("hello", _FakeBot())
self._drafter.generate_draft("hello", FakeBot())
self._drafter.discard_draft()
assert len(self._commits()) == 1

def test_generate_outside_branch(self) -> None:
self._repo.git.checkout("--detach")
with pytest.raises(RuntimeError):
self._drafter.generate_draft("ok", _FakeBot())
self._drafter.generate_draft("ok", FakeBot())

def test_generate_empty_prompt(self) -> None:
with pytest.raises(ValueError):
self._drafter.generate_draft("", _FakeBot())
self._drafter.generate_draft("", FakeBot())

def test_generate_dirty_index_no_reset(self) -> None:
self._write("log")
self._repo.git.add(all=True)
with pytest.raises(ValueError):
self._drafter.generate_draft("hi", _FakeBot())
self._drafter.generate_draft("hi", FakeBot())

def test_generate_dirty_index_reset_sync(self) -> None:
self._write("log", "11")
self._repo.git.add(all=True)
self._drafter.generate_draft("hi", _FakeBot(), reset=True, sync=True)
self._drafter.generate_draft("hi", FakeBot(), reset=True, sync=True)
assert self._read("log") == "11"
assert not self._path("PROMPT").exists()
self._repo.git.checkout(".")
Expand All @@ -107,21 +107,21 @@ def test_generate_dirty_index_reset_sync(self) -> None:

def test_generate_clean_index_sync(self) -> None:
prompt = TemplatedPrompt("add-test", {"symbol": "abc"})
self._drafter.generate_draft(prompt, _FakeBot(), sync=True)
self._drafter.generate_draft(prompt, FakeBot(), sync=True)
self._repo.git.checkout(".")
assert "abc" in self._read("PROMPT")
assert len(self._commits()) == 2 # init, prompt

def test_generate_reuse_branch(self) -> None:
bot = _FakeBot()
bot = FakeBot()
self._drafter.generate_draft("prompt1", bot)
self._drafter.generate_draft("prompt2", bot)
self._repo.git.checkout(".")
assert self._read("PROMPT") == "prompt2"
assert len(self._commits()) == 3 # init, prompt, prompt

def test_generate_reuse_branch_sync(self) -> None:
bot = _FakeBot()
bot = FakeBot()
self._drafter.generate_draft("prompt1", bot)
self._drafter.generate_draft("prompt2", bot, sync=True)
assert len(self._commits()) == 4 # init, prompt, sync, prompt
Expand All @@ -132,7 +132,7 @@ def test_discard_outside_draft(self) -> None:

def test_discard_after_branch_move(self) -> None:
self._write("log", "11")
self._drafter.generate_draft("hi", _FakeBot(), sync=True)
self._drafter.generate_draft("hi", FakeBot(), sync=True)
branch = self._repo.active_branch
self._repo.git.checkout("main")
self._repo.index.commit("advance")
Expand All @@ -143,15 +143,15 @@ def test_discard_after_branch_move(self) -> None:
def test_discard_restores_worktree(self) -> None:
self._write("p1.txt", "a1")
self._write("p2.txt", "b1")
self._drafter.generate_draft("hello", _FakeBot(), sync=True)
self._drafter.generate_draft("hello", FakeBot(), sync=True)
self._write("p1.txt", "a2")
self._drafter.discard_draft(delete=True)
assert self._read("p1.txt") == "a1"
assert self._read("p2.txt") == "b1"

def test_finalize_keeps_changes(self) -> None:
self._write("p1.txt", "a1")
self._drafter.generate_draft("hello", _FakeBot(), checkout=True)
self._drafter.generate_draft("hello", FakeBot(), checkout=True)
self._write("p1.txt", "a2")
self._drafter.finalize_draft()
assert self._read("p1.txt") == "a2"
Expand Down