diff --git a/.gitignore b/.gitignore index ad767ff..ddadbe4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ dist/ docs/_* +htmlcov/ .coverage diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index 8d42be3..d683fe5 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -38,7 +38,7 @@ def __str__(self) -> str: @classmethod def active(cls, repo: git.Repo) -> _Branch | None: match: Match | None = None - if repo.active_branch: + if not repo.head.is_detached: match = cls._name_pattern.fullmatch(repo.active_branch.name) if not match: return None @@ -63,7 +63,7 @@ def __init__(self, repo: git.Repo, hook: OperationHook | None) -> None: @override def _list(self) -> Sequence[PurePosixPath]: # Show staged files. - return self._repo.git.ls_files() + return self._repo.git.ls_files().splitlines() @override def _read(self, path: PurePosixPath) -> str: @@ -110,7 +110,7 @@ def create( ) def _create_branch(self, sync: bool) -> _Branch: - if not self._repo.active_branch: + if self._repo.head.is_detached: raise RuntimeError("No currently active branch") origin_branch = self._repo.active_branch.name origin_sha = self._repo.commit().hexsha @@ -244,7 +244,7 @@ def _exit_draft(self, apply: bool, delete=False) -> None: and sync_sha and self._repo.commit(origin_branch).hexsha != origin_sha ): - raise ValueError("Parent branch has moved, please rebase") + raise RuntimeError("Parent branch has moved, please rebase") # We do a small dance to move back to the original branch, keeping the # draft branch untouched. See https://stackoverflow.com/a/15993574 for diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index d2bc0cc..ffe282c 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -1,74 +1,158 @@ import git -import os.path as osp -from pathlib import PurePosixPath +from pathlib import Path, PurePosixPath import pytest +from typing import Sequence from git_draft.bots import Action, Bot, Toolbox import git_draft.drafter as sut +from git_draft.prompt import TemplatedPrompt from git_draft.store import Store +class TestToolbox: + @pytest.fixture(autouse=True) + def setup(self, repo: git.Repo) -> None: + self._toolbox = sut._Toolbox(repo, None) + + def test_list_files(self, repo: git.Repo) -> None: + assert self._toolbox.list_files() == [] + names = set(["one.txt", "two.txt"]) + for name in names: + with open(Path(repo.working_dir, name), "w") as f: + f.write("ok") + repo.git.add(all=True) + assert set(self._toolbox.list_files()) == names + + def test_read_file(self, repo: git.Repo) -> None: + with open(Path(repo.working_dir, "one"), "w") as f: + f.write("ok") + + path = PurePosixPath("one") + with pytest.raises(git.GitCommandError): + assert self._toolbox.read_file(path) == "" + + repo.git.add(all=True) + assert self._toolbox.read_file(path) == "ok" + + def test_write_file(self, repo: git.Repo) -> None: + self._toolbox.write_file(PurePosixPath("one"), "hi") + + path = Path(repo.working_dir, "one") + assert not path.exists() + + repo.git.checkout_index(all=True) + with open(path) as f: + assert f.read() == "hi" + + class _FakeBot(Bot): def act(self, prompt: str, toolbox: Toolbox) -> Action: toolbox.write_file(PurePosixPath("PROMPT"), prompt) return Action() -@pytest.fixture -def drafter(repo: git.Repo) -> sut.Drafter: - return sut.Drafter(Store.in_memory(), repo) - - class TestDrafter: - def test_generate_draft( - self, drafter: sut.Drafter, repo: git.Repo - ) -> None: - drafter.generate_draft("hello", _FakeBot()) - commits = list(repo.iter_commits()) - assert len(commits) == 2 - - def test_generate_then_discard_draft( - self, drafter: sut.Drafter, repo: git.Repo - ) -> None: - drafter.generate_draft("hello", _FakeBot()) - drafter.discard_draft() - assert len(list(repo.iter_commits())) == 1 - - def test_discard_restores_worktree( - self, drafter: sut.Drafter, repo: git.Repo - ) -> None: - p1 = osp.join(repo.working_dir, "p1.txt") - with open(p1, "w") as writer: - writer.write("a1") - p2 = osp.join(repo.working_dir, "p2.txt") - with open(p2, "w") as writer: - writer.write("b1") - - drafter.generate_draft("hello", _FakeBot(), sync=True) - with open(p1, "w") as writer: - writer.write("a2") - - drafter.discard_draft() - - with open(p1) as reader: - assert reader.read() == "a1" - with open(p2) as reader: - assert reader.read() == "b1" - - def test_finalize_keeps_changes( - self, drafter: sut.Drafter, repo: git.Repo - ) -> None: - p1 = osp.join(repo.working_dir, "p1.txt") - with open(p1, "w") as writer: - writer.write("a1") - - drafter.generate_draft("hello", _FakeBot(), checkout=True) - with open(p1, "w") as writer: - writer.write("a2") - - drafter.finalize_draft() - - with open(p1) as reader: - assert reader.read() == "a2" - with open(osp.join(repo.working_dir, "PROMPT")) as reader: - assert reader.read() == "hello" + @pytest.fixture(autouse=True) + def setup(self, repo: git.Repo) -> None: + self._repo = repo + self._drafter = sut.Drafter(Store.in_memory(), repo) + + def _path(self, name: str) -> Path: + return Path(self._repo.working_dir, name) + + def _read(self, name: str) -> str: + with open(self._path(name)) as f: + return f.read() + + def _write(self, name: str, contents="") -> None: + with open(self._path(name), "w") as f: + f.write(contents) + + def _commits(self) -> Sequence[git.Commit]: + return list(self._repo.iter_commits()) + + def test_generate_draft(self) -> None: + self._drafter.generate_draft("hello", _FakeBot()) + assert len(self._commits()) == 2 + + def test_generate_then_discard_draft(self) -> None: + self._drafter.generate_draft("hello", _FakeBot()) + self._drafter.discard_draft() + assert len(self._commits()) == 1 + + def test_generate_outside_branch(self) -> None: + self._repo.git.checkout("--detach") + with pytest.raises(RuntimeError): + self._drafter.generate_draft("ok", _FakeBot()) + + def test_generate_empty_prompt(self) -> None: + with pytest.raises(ValueError): + self._drafter.generate_draft("", _FakeBot()) + + def test_generate_dirty_index_no_reset(self) -> None: + self._write("log") + self._repo.git.add(all=True) + with pytest.raises(ValueError): + self._drafter.generate_draft("hi", _FakeBot()) + + def test_generate_dirty_index_reset_sync(self) -> None: + self._write("log", "11") + self._repo.git.add(all=True) + self._drafter.generate_draft("hi", _FakeBot(), reset=True, sync=True) + assert self._read("log") == "11" + assert not self._path("PROMPT").exists() + self._repo.git.checkout(".") + assert self._read("PROMPT") == "hi" + assert len(self._commits()) == 3 # init, sync, prompt + + def test_generate_clean_index_sync(self) -> None: + prompt = TemplatedPrompt("add-test", {"symbol": "abc"}) + self._drafter.generate_draft(prompt, _FakeBot(), sync=True) + self._repo.git.checkout(".") + assert "abc" in self._read("PROMPT") + assert len(self._commits()) == 2 # init, prompt + + def test_generate_reuse_branch(self) -> None: + bot = _FakeBot() + self._drafter.generate_draft("prompt1", bot) + self._drafter.generate_draft("prompt2", bot) + self._repo.git.checkout(".") + assert self._read("PROMPT") == "prompt2" + assert len(self._commits()) == 3 # init, prompt, prompt + + def test_generate_reuse_branch_sync(self) -> None: + bot = _FakeBot() + self._drafter.generate_draft("prompt1", bot) + self._drafter.generate_draft("prompt2", bot, sync=True) + assert len(self._commits()) == 4 # init, prompt, sync, prompt + + def test_discard_outside_draft(self) -> None: + with pytest.raises(RuntimeError): + self._drafter.discard_draft() + + def test_discard_after_branch_move(self) -> None: + self._write("log", "11") + self._drafter.generate_draft("hi", _FakeBot(), sync=True) + branch = self._repo.active_branch + self._repo.git.checkout("main") + self._repo.index.commit("advance") + self._repo.git.checkout(branch) + with pytest.raises(RuntimeError): + self._drafter.discard_draft() + + def test_discard_restores_worktree(self) -> None: + self._write("p1.txt", "a1") + self._write("p2.txt", "b1") + self._drafter.generate_draft("hello", _FakeBot(), sync=True) + self._write("p1.txt", "a2") + self._drafter.discard_draft(delete=True) + assert self._read("p1.txt") == "a1" + assert self._read("p2.txt") == "b1" + + def test_finalize_keeps_changes(self) -> None: + self._write("p1.txt", "a1") + self._drafter.generate_draft("hello", _FakeBot(), checkout=True) + self._write("p1.txt", "a2") + self._drafter.finalize_draft() + assert self._read("p1.txt") == "a2" + assert self._read("PROMPT") == "hello"