diff --git a/README.md b/README.md index 623886f..6f118bf 100644 --- a/README.md +++ b/README.md @@ -23,15 +23,6 @@ pipx install git-draft[openai] * Mechanism for reporting feedback from a bot, and possibly allowing user to interactively respond. -* Add configuration option to auto sync and `--no-sync` flag. Similar to reset. - Also rename both options to `sync` and `reset`, this will make it more natural - to support a similar config option for `accept`. -* Add `--sync` `finalize` option which creates a additional commit when - finalizing if any changes were added to the bot's output. This could be useful - training data, showing what the bot did not get right. -* Convenience `--accept` functionality for simple cases: checkout option which - applies the changes, and finalizes the draft if specified multiple times. For - example `git draft -aa add-test symbol=foo` * Support file rename tool. * https://stackoverflow.com/q/49853177/1062617 * https://stackoverflow.com/q/6658313/1062617 diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 018d35f..cf7ee8d 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -59,7 +59,7 @@ def callback( ) add_command("finalize", help="apply current draft to original branch") - add_command("generate", help="start a new draft from a prompt") + add_command("generate", help="create or update draft from a prompt") add_command("show-drafts", short="D", help="show draft history") add_command("show-prompts", short="P", help="show prompt history") add_command("show-templates", short="T", help="show template information") @@ -67,7 +67,7 @@ def callback( parser.add_option( "-a", "--accept", - help="apply generated changes", + help="accept draft, may be repeated", action="count", ) parser.add_option( @@ -79,7 +79,7 @@ def callback( parser.add_option( "-d", "--delete", - help="delete draft after finalizing or discarding", + help="delete draft after finalizing", action="store_true", ) parser.add_option( @@ -101,12 +101,25 @@ def callback( action="store_true", ) + parser.add_option( + "--no-accept", + help="do not update worktree from draft", + dest="accept", + action="store_const", + const=0, + ) parser.add_option( "--no-reset", help="abort if there are any staged changes", dest="reset", action="store_false", ) + parser.add_option( + "--no-sync", + help="do not commit intermediate worktree changes", + dest="sync", + action="store_false", + ) parser.add_option( "--reset", help="reset index before generating a new draft", @@ -209,12 +222,15 @@ def main() -> None: # noqa: PLR0912 PLR0915 bot_name=opts.bot, prompt_transform=open_editor if editable else None, tool_visitors=[ToolPrinter()], - reset=config.auto_reset if opts.reset is None else opts.reset, - sync=opts.sync, + reset=config.reset if opts.reset is None else opts.reset, + sync=config.sync if opts.sync is None else opts.sync, ) print(f"Generated change in {name}.") case "finalize": - name = drafter.finalize_draft(delete=opts.delete) + name = drafter.finalize_draft( + delete=opts.delete, + sync=config.sync if opts.sync is None else opts.sync, + ) print(f"Finalized {name}.") case "show-drafts": table = drafter.history_table(args[0] if args else None) diff --git a/src/git_draft/common.py b/src/git_draft/common.py index 1535a8d..c8f9ed0 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -38,9 +38,10 @@ def ensure_state_home() -> Path: class Config: """Overall CLI configuration""" - log_level: int = logging.INFO - auto_reset: bool = True bots: Sequence[BotConfig] = dataclasses.field(default_factory=lambda: []) + log_level: int = logging.INFO + reset: bool = True + sync: bool = False @staticmethod def folder_path() -> Path: diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index 6715ded..a2c1f79 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -164,7 +164,7 @@ def generate_draft( # noqa: PLR0913 if delta and accept.value >= Accept.CHECKOUT.value: delta.apply() if accept.value >= Accept.FINALIZE.value: - self.finalize_draft(delete=accept == Accept.NO_REGRETS) + self.finalize_draft(delete=accept == Accept.NO_REGRETS, sync=sync) return str(branch) def _prepare_prompt( @@ -212,10 +212,13 @@ def _generate_change( commit.hexsha, timedelta(seconds=walltime), action, self._repo ) - def finalize_draft(self, *, delete: bool = False) -> str: + def finalize_draft( + self, *, delete: bool = False, sync: bool = False + ) -> str: branch = _Branch.active(self._repo) if not branch: raise RuntimeError("Not currently on a draft branch") + self._stage_repo(sync) with self._store.cursor() as cursor: rows = cursor.execute( @@ -223,7 +226,7 @@ def finalize_draft(self, *, delete: bool = False) -> str: ) if not rows: raise RuntimeError("Unrecognized draft branch") - [(origin_branch, origin_sha, sync_sha)] = rows + [(origin_branch, origin_sha)] = rows # We do a small dance to move back to the original branch, keeping the # draft branch untouched. See https://stackoverflow.com/a/15993574 for @@ -246,7 +249,7 @@ def _create_branch(self, sync: bool) -> _Branch: origin_sha = self._repo.commit().hexsha self._repo.git.checkout(detach=True) - sync_sha = self._stage_repo(sync) + self._stage_repo(sync) suffix = _Branch.new_suffix() with self._store.cursor() as cursor: @@ -257,7 +260,6 @@ def _create_branch(self, sync: bool) -> _Branch: "repo_path": self._repo.working_dir, "origin_branch": origin_branch, "origin_sha": origin_sha, - "sync_sha": sync_sha, }, ) diff --git a/src/git_draft/queries/add-branch.sql b/src/git_draft/queries/add-branch.sql index 5ee144f..4865cd6 100644 --- a/src/git_draft/queries/add-branch.sql +++ b/src/git_draft/queries/add-branch.sql @@ -1,2 +1,2 @@ -insert into branches (suffix, repo_path, origin_branch, origin_sha, sync_sha) - values (:suffix, :repo_path, :origin_branch, :origin_sha, :sync_sha); +insert into branches (suffix, repo_path, origin_branch, origin_sha) + values (:suffix, :repo_path, :origin_branch, :origin_sha); diff --git a/src/git_draft/queries/create-tables.sql b/src/git_draft/queries/create-tables.sql index 09506d3..5c6c58a 100644 --- a/src/git_draft/queries/create-tables.sql +++ b/src/git_draft/queries/create-tables.sql @@ -3,8 +3,7 @@ create table if not exists branches ( repo_path text not null, created_at timestamp default current_timestamp, origin_branch text not null, - origin_sha text not null, - sync_sha text + origin_sha text not null ) without rowid; create table if not exists prompts ( diff --git a/src/git_draft/queries/get-branch-by-suffix.sql b/src/git_draft/queries/get-branch-by-suffix.sql index 66d9c55..c4c34ef 100644 --- a/src/git_draft/queries/get-branch-by-suffix.sql +++ b/src/git_draft/queries/get-branch-by-suffix.sql @@ -1,3 +1,3 @@ -select origin_branch, origin_sha, sync_sha +select origin_branch, origin_sha from branches where suffix = :suffix; diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index ffea16d..d3c52f7 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -1,6 +1,7 @@ -from collections.abc import Sequence +from collections.abc import Callable, Mapping, Sequence import os from pathlib import Path, PurePosixPath +from typing import Self import git import pytest @@ -11,9 +12,30 @@ from git_draft.store import Store -class FakeBot(Bot): +class _SimpleBot(Bot): + """A simple bot which updates files to match a mapping""" + + def __init__( + self, contents: Mapping[str, str | None | Callable[[Goal], str]] + ) -> None: + self._contents = contents + + @classmethod + def noop(cls) -> Self: + return cls({}) + + @classmethod + def prompt(cls) -> Self: + return cls({"PROMPT": lambda goal: goal.prompt}) + def act(self, goal: Goal, toolbox: Toolbox) -> Action: - toolbox.write_file(PurePosixPath("PROMPT"), goal.prompt) + for key, value in self._contents.items(): + path = PurePosixPath(key) + if value is None: + toolbox.delete_file(path) + else: + contents = value if isinstance(value, str) else value(goal) + toolbox.write_file(path, contents) return Action() @@ -40,8 +62,8 @@ def _write(self, name: str, contents="") -> None: def _delete(self, name: str) -> None: os.remove(self._path(name)) - def _commits(self) -> Sequence[git.Commit]: - return list(self._repo.iter_commits()) + def _commits(self, ref: str | None = None) -> Sequence[git.Commit]: + return list(self._repo.iter_commits(rev=ref)) def _commit_files(self, ref: str) -> frozenset[str]: text = self._repo.git.diff_tree( @@ -53,7 +75,11 @@ def _checkout(self) -> None: self._repo.git.checkout("--", ".") def test_generate_draft(self) -> None: - self._drafter.generate_draft("hello", FakeBot()) + self._drafter.generate_draft("hello", _SimpleBot({"p1": "A"})) + assert len(self._commits()) == 2 + + def test_generate_empty_draft(self) -> None: + self._drafter.generate_draft("hello", _SimpleBot.noop()) assert len(self._commits()) == 2 def test_generate_stages_then_resets_worktree(self) -> None: @@ -73,22 +99,24 @@ def act(self, _goal: Goal, toolbox: Toolbox) -> Action: def test_generate_outside_branch(self) -> None: self._repo.git.checkout("--detach") with pytest.raises(RuntimeError): - self._drafter.generate_draft("ok", FakeBot()) + self._drafter.generate_draft("ok", _SimpleBot.noop()) def test_generate_empty_prompt(self) -> None: with pytest.raises(ValueError): - self._drafter.generate_draft("", FakeBot()) + self._drafter.generate_draft("", _SimpleBot.noop()) def test_generate_dirty_index_no_reset(self) -> None: self._write("log") self._repo.git.add(all=True) with pytest.raises(ValueError): - self._drafter.generate_draft("hi", FakeBot()) + self._drafter.generate_draft("hi", _SimpleBot.noop()) def test_generate_dirty_index_reset_sync(self) -> None: self._write("log", "11") self._repo.git.add(all=True) - self._drafter.generate_draft("hi", FakeBot(), reset=True, sync=True) + self._drafter.generate_draft( + "hi", _SimpleBot.prompt(), reset=True, sync=True + ) assert self._read("log") == "11" assert not self._path("PROMPT").exists() self._repo.git.checkout(".") @@ -97,48 +125,41 @@ def test_generate_dirty_index_reset_sync(self) -> None: def test_generate_clean_index_sync(self) -> None: prompt = TemplatedPrompt("add-test", {"symbol": "abc"}) - self._drafter.generate_draft(prompt, FakeBot(), sync=True) + self._drafter.generate_draft( + prompt, _SimpleBot({"p1": "abc"}), sync=True + ) self._repo.git.checkout(".") - assert "abc" in (self._read("PROMPT") or "") - assert len(self._commits()) == 2 # init, prompt + assert "abc" in (self._read("p1") or "") + assert len(self._commits()) == 2 # sync, prompt def test_generate_reuse_branch(self) -> None: - bot = FakeBot() + bot = _SimpleBot({"prompt": lambda goal: goal.prompt}) self._drafter.generate_draft("prompt1", bot) self._drafter.generate_draft("prompt2", bot) self._repo.git.checkout(".") - assert self._read("PROMPT") == "prompt2" + assert self._read("prompt") == "prompt2" assert len(self._commits()) == 3 # init, prompt, prompt def test_generate_reuse_branch_sync(self) -> None: - bot = FakeBot() + bot = _SimpleBot({"p1": "A"}) self._drafter.generate_draft("prompt1", bot) self._drafter.generate_draft("prompt2", bot, sync=True) assert len(self._commits()) == 4 # init, prompt, sync, prompt def test_generate_noop(self) -> None: self._write("unrelated", "a") - - class CustomBot(Bot): - def act(self, _goal: Goal, _toolbox: Toolbox) -> Action: - return Action() - - self._drafter.generate_draft("prompt", CustomBot()) + self._drafter.generate_draft("prompt", _SimpleBot.noop()) assert len(self._commits()) == 2 # init, prompt assert not self._commit_files("HEAD") def test_generate_accept_checkout(self) -> None: self._write("p1", "A") self._write("p2", "B") - - class CustomBot(Bot): - def act(self, _goal: Goal, toolbox: Toolbox) -> Action: - toolbox.write_file(PurePosixPath("p1"), "C") - toolbox.write_file(PurePosixPath("p3"), "D") - return Action() - self._drafter.generate_draft( - "hello", CustomBot(), accept=sut.Accept.CHECKOUT, sync=True + "hello", + _SimpleBot({"p1": "C", "p3": "D"}), + accept=sut.Accept.CHECKOUT, + sync=True, ) assert self._read("p1") == "C" assert self._read("p2") == "B" @@ -146,31 +167,20 @@ def act(self, _goal: Goal, toolbox: Toolbox) -> Action: def test_generate_accept_checkout_conflict(self) -> None: self._write("p1", "A") - - class CustomBot(Bot): - def act(self, _goal: Goal, toolbox: Toolbox) -> Action: - toolbox.write_file(PurePosixPath("p1"), "B") - toolbox.write_file(PurePosixPath("p2"), "C") - return Action() - with pytest.raises(sut.ConflictError): self._drafter.generate_draft( - "hello", CustomBot(), accept=sut.Accept.CHECKOUT + "hello", + _SimpleBot({"p1": "B", "p2": "C"}), + accept=sut.Accept.CHECKOUT ) assert """<<<<<<< ours\nA""" in (self._read("p1") or "") assert self._read("p2") == "C" def test_generate_accept_finalize(self) -> None: self._write("p1", "A") - - class CustomBot(Bot): - def act(self, _goal: Goal, toolbox: Toolbox) -> Action: - toolbox.write_file(PurePosixPath("p2"), "B") - return Action() - self._drafter.generate_draft( "hello", - CustomBot(), + _SimpleBot({"p2": "B"}), accept=sut.Accept.FINALIZE, ) assert self._read("p1") == "A" @@ -178,33 +188,41 @@ def act(self, _goal: Goal, toolbox: Toolbox) -> Action: assert self._repo.active_branch.name == "main" def test_delete_unknown_file(self) -> None: - class CustomBot(Bot): - def act(self, _goal: Goal, toolbox: Toolbox) -> Action: - toolbox.delete_file(PurePosixPath("p1")) - return Action() - - self._drafter.generate_draft("hello", CustomBot()) + self._drafter.generate_draft("hello", _SimpleBot({"p1": None})) def test_finalize_keeps_changes(self) -> None: self._write("p1.txt", "a1") - self._drafter.generate_draft("hello", FakeBot()) + self._drafter.generate_draft("hello", _SimpleBot.prompt()) self._checkout() self._write("p1.txt", "a2") self._drafter.finalize_draft() assert self._read("p1.txt") == "a2" assert self._read("PROMPT") == "hello" + def test_finalize_and_sync(self) -> None: + branch = self._drafter.generate_draft( + "hello", + _SimpleBot.prompt(), + accept=sut.Accept.CHECKOUT, + ) + self._write("PROMPT", "a2") + self._drafter.finalize_draft(sync=True) + assert self._read("PROMPT") == "a2" + commits = self._commits(branch) + assert len(commits) == 3 # init, prompt, sync + assert "sync" in commits[0].message + def test_history_table_empty(self) -> None: table = self._drafter.history_table() assert not table def test_history_table_active_draft(self) -> None: - self._drafter.generate_draft("hello", FakeBot()) + self._drafter.generate_draft("hello", _SimpleBot.noop()) table = self._drafter.history_table() assert table def test_latest_draft_prompt(self) -> None: - bot = FakeBot() + bot = _SimpleBot.noop() prompt1 = "First prompt" self._drafter.generate_draft(prompt1, bot)