Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
dist/
docs/_*
htmlcov/
.coverage
8 changes: 4 additions & 4 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __str__(self) -> str:
@classmethod
def active(cls, repo: git.Repo) -> _Branch | None:
match: Match | None = None
if repo.active_branch:
if not repo.head.is_detached:
match = cls._name_pattern.fullmatch(repo.active_branch.name)
if not match:
return None
Expand All @@ -63,7 +63,7 @@ def __init__(self, repo: git.Repo, hook: OperationHook | None) -> None:
@override
def _list(self) -> Sequence[PurePosixPath]:
# Show staged files.
return self._repo.git.ls_files()
return self._repo.git.ls_files().splitlines()

@override
def _read(self, path: PurePosixPath) -> str:
Expand Down Expand Up @@ -110,7 +110,7 @@ def create(
)

def _create_branch(self, sync: bool) -> _Branch:
if not self._repo.active_branch:
if self._repo.head.is_detached:
raise RuntimeError("No currently active branch")
origin_branch = self._repo.active_branch.name
origin_sha = self._repo.commit().hexsha
Expand Down Expand Up @@ -244,7 +244,7 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
and sync_sha
and self._repo.commit(origin_branch).hexsha != origin_sha
):
raise ValueError("Parent branch has moved, please rebase")
raise RuntimeError("Parent branch has moved, please rebase")

# We do a small dance to move back to the original branch, keeping the
# draft branch untouched. See https://stackoverflow.com/a/15993574 for
Expand Down
202 changes: 143 additions & 59 deletions tests/git_draft/drafter_test.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,158 @@
import git
import os.path as osp
from pathlib import PurePosixPath
from pathlib import Path, PurePosixPath
import pytest
from typing import Sequence

from git_draft.bots import Action, Bot, Toolbox
import git_draft.drafter as sut
from git_draft.prompt import TemplatedPrompt
from git_draft.store import Store


class TestToolbox:
@pytest.fixture(autouse=True)
def setup(self, repo: git.Repo) -> None:
self._toolbox = sut._Toolbox(repo, None)

def test_list_files(self, repo: git.Repo) -> None:
assert self._toolbox.list_files() == []
names = set(["one.txt", "two.txt"])
for name in names:
with open(Path(repo.working_dir, name), "w") as f:
f.write("ok")
repo.git.add(all=True)
assert set(self._toolbox.list_files()) == names

def test_read_file(self, repo: git.Repo) -> None:
with open(Path(repo.working_dir, "one"), "w") as f:
f.write("ok")

path = PurePosixPath("one")
with pytest.raises(git.GitCommandError):
assert self._toolbox.read_file(path) == ""

repo.git.add(all=True)
assert self._toolbox.read_file(path) == "ok"

def test_write_file(self, repo: git.Repo) -> None:
self._toolbox.write_file(PurePosixPath("one"), "hi")

path = Path(repo.working_dir, "one")
assert not path.exists()

repo.git.checkout_index(all=True)
with open(path) as f:
assert f.read() == "hi"


class _FakeBot(Bot):
def act(self, prompt: str, toolbox: Toolbox) -> Action:
toolbox.write_file(PurePosixPath("PROMPT"), prompt)
return Action()


@pytest.fixture
def drafter(repo: git.Repo) -> sut.Drafter:
return sut.Drafter(Store.in_memory(), repo)


class TestDrafter:
def test_generate_draft(
self, drafter: sut.Drafter, repo: git.Repo
) -> None:
drafter.generate_draft("hello", _FakeBot())
commits = list(repo.iter_commits())
assert len(commits) == 2

def test_generate_then_discard_draft(
self, drafter: sut.Drafter, repo: git.Repo
) -> None:
drafter.generate_draft("hello", _FakeBot())
drafter.discard_draft()
assert len(list(repo.iter_commits())) == 1

def test_discard_restores_worktree(
self, drafter: sut.Drafter, repo: git.Repo
) -> None:
p1 = osp.join(repo.working_dir, "p1.txt")
with open(p1, "w") as writer:
writer.write("a1")
p2 = osp.join(repo.working_dir, "p2.txt")
with open(p2, "w") as writer:
writer.write("b1")

drafter.generate_draft("hello", _FakeBot(), sync=True)
with open(p1, "w") as writer:
writer.write("a2")

drafter.discard_draft()

with open(p1) as reader:
assert reader.read() == "a1"
with open(p2) as reader:
assert reader.read() == "b1"

def test_finalize_keeps_changes(
self, drafter: sut.Drafter, repo: git.Repo
) -> None:
p1 = osp.join(repo.working_dir, "p1.txt")
with open(p1, "w") as writer:
writer.write("a1")

drafter.generate_draft("hello", _FakeBot(), checkout=True)
with open(p1, "w") as writer:
writer.write("a2")

drafter.finalize_draft()

with open(p1) as reader:
assert reader.read() == "a2"
with open(osp.join(repo.working_dir, "PROMPT")) as reader:
assert reader.read() == "hello"
@pytest.fixture(autouse=True)
def setup(self, repo: git.Repo) -> None:
self._repo = repo
self._drafter = sut.Drafter(Store.in_memory(), repo)

def _path(self, name: str) -> Path:
return Path(self._repo.working_dir, name)

def _read(self, name: str) -> str:
with open(self._path(name)) as f:
return f.read()

def _write(self, name: str, contents="") -> None:
with open(self._path(name), "w") as f:
f.write(contents)

def _commits(self) -> Sequence[git.Commit]:
return list(self._repo.iter_commits())

def test_generate_draft(self) -> None:
self._drafter.generate_draft("hello", _FakeBot())
assert len(self._commits()) == 2

def test_generate_then_discard_draft(self) -> None:
self._drafter.generate_draft("hello", _FakeBot())
self._drafter.discard_draft()
assert len(self._commits()) == 1

def test_generate_outside_branch(self) -> None:
self._repo.git.checkout("--detach")
with pytest.raises(RuntimeError):
self._drafter.generate_draft("ok", _FakeBot())

def test_generate_empty_prompt(self) -> None:
with pytest.raises(ValueError):
self._drafter.generate_draft("", _FakeBot())

def test_generate_dirty_index_no_reset(self) -> None:
self._write("log")
self._repo.git.add(all=True)
with pytest.raises(ValueError):
self._drafter.generate_draft("hi", _FakeBot())

def test_generate_dirty_index_reset_sync(self) -> None:
self._write("log", "11")
self._repo.git.add(all=True)
self._drafter.generate_draft("hi", _FakeBot(), reset=True, sync=True)
assert self._read("log") == "11"
assert not self._path("PROMPT").exists()
self._repo.git.checkout(".")
assert self._read("PROMPT") == "hi"
assert len(self._commits()) == 3 # init, sync, prompt

def test_generate_clean_index_sync(self) -> None:
prompt = TemplatedPrompt("add-test", {"symbol": "abc"})
self._drafter.generate_draft(prompt, _FakeBot(), sync=True)
self._repo.git.checkout(".")
assert "abc" in self._read("PROMPT")
assert len(self._commits()) == 2 # init, prompt

def test_generate_reuse_branch(self) -> None:
bot = _FakeBot()
self._drafter.generate_draft("prompt1", bot)
self._drafter.generate_draft("prompt2", bot)
self._repo.git.checkout(".")
assert self._read("PROMPT") == "prompt2"
assert len(self._commits()) == 3 # init, prompt, prompt

def test_generate_reuse_branch_sync(self) -> None:
bot = _FakeBot()
self._drafter.generate_draft("prompt1", bot)
self._drafter.generate_draft("prompt2", bot, sync=True)
assert len(self._commits()) == 4 # init, prompt, sync, prompt

def test_discard_outside_draft(self) -> None:
with pytest.raises(RuntimeError):
self._drafter.discard_draft()

def test_discard_after_branch_move(self) -> None:
self._write("log", "11")
self._drafter.generate_draft("hi", _FakeBot(), sync=True)
branch = self._repo.active_branch
self._repo.git.checkout("main")
self._repo.index.commit("advance")
self._repo.git.checkout(branch)
with pytest.raises(RuntimeError):
self._drafter.discard_draft()

def test_discard_restores_worktree(self) -> None:
self._write("p1.txt", "a1")
self._write("p2.txt", "b1")
self._drafter.generate_draft("hello", _FakeBot(), sync=True)
self._write("p1.txt", "a2")
self._drafter.discard_draft(delete=True)
assert self._read("p1.txt") == "a1"
assert self._read("p2.txt") == "b1"

def test_finalize_keeps_changes(self) -> None:
self._write("p1.txt", "a1")
self._drafter.generate_draft("hello", _FakeBot(), checkout=True)
self._write("p1.txt", "a2")
self._drafter.finalize_draft()
assert self._read("p1.txt") == "a2"
assert self._read("PROMPT") == "hello"