Skip to content

Commit a82f85f

Browse files
authored
test: add drafter tests (#32)
1 parent 2d08434 commit a82f85f

File tree

3 files changed

+148
-63
lines changed

3 files changed

+148
-63
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
dist/
22
docs/_*
3+
htmlcov/
34
.coverage

src/git_draft/drafter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __str__(self) -> str:
3838
@classmethod
3939
def active(cls, repo: git.Repo) -> _Branch | None:
4040
match: Match | None = None
41-
if repo.active_branch:
41+
if not repo.head.is_detached:
4242
match = cls._name_pattern.fullmatch(repo.active_branch.name)
4343
if not match:
4444
return None
@@ -63,7 +63,7 @@ def __init__(self, repo: git.Repo, hook: OperationHook | None) -> None:
6363
@override
6464
def _list(self) -> Sequence[PurePosixPath]:
6565
# Show staged files.
66-
return self._repo.git.ls_files()
66+
return self._repo.git.ls_files().splitlines()
6767

6868
@override
6969
def _read(self, path: PurePosixPath) -> str:
@@ -110,7 +110,7 @@ def create(
110110
)
111111

112112
def _create_branch(self, sync: bool) -> _Branch:
113-
if not self._repo.active_branch:
113+
if self._repo.head.is_detached:
114114
raise RuntimeError("No currently active branch")
115115
origin_branch = self._repo.active_branch.name
116116
origin_sha = self._repo.commit().hexsha
@@ -244,7 +244,7 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
244244
and sync_sha
245245
and self._repo.commit(origin_branch).hexsha != origin_sha
246246
):
247-
raise ValueError("Parent branch has moved, please rebase")
247+
raise RuntimeError("Parent branch has moved, please rebase")
248248

249249
# We do a small dance to move back to the original branch, keeping the
250250
# draft branch untouched. See https://stackoverflow.com/a/15993574 for

tests/git_draft/drafter_test.py

Lines changed: 143 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,158 @@
11
import git
2-
import os.path as osp
3-
from pathlib import PurePosixPath
2+
from pathlib import Path, PurePosixPath
43
import pytest
4+
from typing import Sequence
55

66
from git_draft.bots import Action, Bot, Toolbox
77
import git_draft.drafter as sut
8+
from git_draft.prompt import TemplatedPrompt
89
from git_draft.store import Store
910

1011

12+
class TestToolbox:
13+
@pytest.fixture(autouse=True)
14+
def setup(self, repo: git.Repo) -> None:
15+
self._toolbox = sut._Toolbox(repo, None)
16+
17+
def test_list_files(self, repo: git.Repo) -> None:
18+
assert self._toolbox.list_files() == []
19+
names = set(["one.txt", "two.txt"])
20+
for name in names:
21+
with open(Path(repo.working_dir, name), "w") as f:
22+
f.write("ok")
23+
repo.git.add(all=True)
24+
assert set(self._toolbox.list_files()) == names
25+
26+
def test_read_file(self, repo: git.Repo) -> None:
27+
with open(Path(repo.working_dir, "one"), "w") as f:
28+
f.write("ok")
29+
30+
path = PurePosixPath("one")
31+
with pytest.raises(git.GitCommandError):
32+
assert self._toolbox.read_file(path) == ""
33+
34+
repo.git.add(all=True)
35+
assert self._toolbox.read_file(path) == "ok"
36+
37+
def test_write_file(self, repo: git.Repo) -> None:
38+
self._toolbox.write_file(PurePosixPath("one"), "hi")
39+
40+
path = Path(repo.working_dir, "one")
41+
assert not path.exists()
42+
43+
repo.git.checkout_index(all=True)
44+
with open(path) as f:
45+
assert f.read() == "hi"
46+
47+
1148
class _FakeBot(Bot):
1249
def act(self, prompt: str, toolbox: Toolbox) -> Action:
1350
toolbox.write_file(PurePosixPath("PROMPT"), prompt)
1451
return Action()
1552

1653

17-
@pytest.fixture
18-
def drafter(repo: git.Repo) -> sut.Drafter:
19-
return sut.Drafter(Store.in_memory(), repo)
20-
21-
2254
class TestDrafter:
23-
def test_generate_draft(
24-
self, drafter: sut.Drafter, repo: git.Repo
25-
) -> None:
26-
drafter.generate_draft("hello", _FakeBot())
27-
commits = list(repo.iter_commits())
28-
assert len(commits) == 2
29-
30-
def test_generate_then_discard_draft(
31-
self, drafter: sut.Drafter, repo: git.Repo
32-
) -> None:
33-
drafter.generate_draft("hello", _FakeBot())
34-
drafter.discard_draft()
35-
assert len(list(repo.iter_commits())) == 1
36-
37-
def test_discard_restores_worktree(
38-
self, drafter: sut.Drafter, repo: git.Repo
39-
) -> None:
40-
p1 = osp.join(repo.working_dir, "p1.txt")
41-
with open(p1, "w") as writer:
42-
writer.write("a1")
43-
p2 = osp.join(repo.working_dir, "p2.txt")
44-
with open(p2, "w") as writer:
45-
writer.write("b1")
46-
47-
drafter.generate_draft("hello", _FakeBot(), sync=True)
48-
with open(p1, "w") as writer:
49-
writer.write("a2")
50-
51-
drafter.discard_draft()
52-
53-
with open(p1) as reader:
54-
assert reader.read() == "a1"
55-
with open(p2) as reader:
56-
assert reader.read() == "b1"
57-
58-
def test_finalize_keeps_changes(
59-
self, drafter: sut.Drafter, repo: git.Repo
60-
) -> None:
61-
p1 = osp.join(repo.working_dir, "p1.txt")
62-
with open(p1, "w") as writer:
63-
writer.write("a1")
64-
65-
drafter.generate_draft("hello", _FakeBot(), checkout=True)
66-
with open(p1, "w") as writer:
67-
writer.write("a2")
68-
69-
drafter.finalize_draft()
70-
71-
with open(p1) as reader:
72-
assert reader.read() == "a2"
73-
with open(osp.join(repo.working_dir, "PROMPT")) as reader:
74-
assert reader.read() == "hello"
55+
@pytest.fixture(autouse=True)
56+
def setup(self, repo: git.Repo) -> None:
57+
self._repo = repo
58+
self._drafter = sut.Drafter(Store.in_memory(), repo)
59+
60+
def _path(self, name: str) -> Path:
61+
return Path(self._repo.working_dir, name)
62+
63+
def _read(self, name: str) -> str:
64+
with open(self._path(name)) as f:
65+
return f.read()
66+
67+
def _write(self, name: str, contents="") -> None:
68+
with open(self._path(name), "w") as f:
69+
f.write(contents)
70+
71+
def _commits(self) -> Sequence[git.Commit]:
72+
return list(self._repo.iter_commits())
73+
74+
def test_generate_draft(self) -> None:
75+
self._drafter.generate_draft("hello", _FakeBot())
76+
assert len(self._commits()) == 2
77+
78+
def test_generate_then_discard_draft(self) -> None:
79+
self._drafter.generate_draft("hello", _FakeBot())
80+
self._drafter.discard_draft()
81+
assert len(self._commits()) == 1
82+
83+
def test_generate_outside_branch(self) -> None:
84+
self._repo.git.checkout("--detach")
85+
with pytest.raises(RuntimeError):
86+
self._drafter.generate_draft("ok", _FakeBot())
87+
88+
def test_generate_empty_prompt(self) -> None:
89+
with pytest.raises(ValueError):
90+
self._drafter.generate_draft("", _FakeBot())
91+
92+
def test_generate_dirty_index_no_reset(self) -> None:
93+
self._write("log")
94+
self._repo.git.add(all=True)
95+
with pytest.raises(ValueError):
96+
self._drafter.generate_draft("hi", _FakeBot())
97+
98+
def test_generate_dirty_index_reset_sync(self) -> None:
99+
self._write("log", "11")
100+
self._repo.git.add(all=True)
101+
self._drafter.generate_draft("hi", _FakeBot(), reset=True, sync=True)
102+
assert self._read("log") == "11"
103+
assert not self._path("PROMPT").exists()
104+
self._repo.git.checkout(".")
105+
assert self._read("PROMPT") == "hi"
106+
assert len(self._commits()) == 3 # init, sync, prompt
107+
108+
def test_generate_clean_index_sync(self) -> None:
109+
prompt = TemplatedPrompt("add-test", {"symbol": "abc"})
110+
self._drafter.generate_draft(prompt, _FakeBot(), sync=True)
111+
self._repo.git.checkout(".")
112+
assert "abc" in self._read("PROMPT")
113+
assert len(self._commits()) == 2 # init, prompt
114+
115+
def test_generate_reuse_branch(self) -> None:
116+
bot = _FakeBot()
117+
self._drafter.generate_draft("prompt1", bot)
118+
self._drafter.generate_draft("prompt2", bot)
119+
self._repo.git.checkout(".")
120+
assert self._read("PROMPT") == "prompt2"
121+
assert len(self._commits()) == 3 # init, prompt, prompt
122+
123+
def test_generate_reuse_branch_sync(self) -> None:
124+
bot = _FakeBot()
125+
self._drafter.generate_draft("prompt1", bot)
126+
self._drafter.generate_draft("prompt2", bot, sync=True)
127+
assert len(self._commits()) == 4 # init, prompt, sync, prompt
128+
129+
def test_discard_outside_draft(self) -> None:
130+
with pytest.raises(RuntimeError):
131+
self._drafter.discard_draft()
132+
133+
def test_discard_after_branch_move(self) -> None:
134+
self._write("log", "11")
135+
self._drafter.generate_draft("hi", _FakeBot(), sync=True)
136+
branch = self._repo.active_branch
137+
self._repo.git.checkout("main")
138+
self._repo.index.commit("advance")
139+
self._repo.git.checkout(branch)
140+
with pytest.raises(RuntimeError):
141+
self._drafter.discard_draft()
142+
143+
def test_discard_restores_worktree(self) -> None:
144+
self._write("p1.txt", "a1")
145+
self._write("p2.txt", "b1")
146+
self._drafter.generate_draft("hello", _FakeBot(), sync=True)
147+
self._write("p1.txt", "a2")
148+
self._drafter.discard_draft(delete=True)
149+
assert self._read("p1.txt") == "a1"
150+
assert self._read("p2.txt") == "b1"
151+
152+
def test_finalize_keeps_changes(self) -> None:
153+
self._write("p1.txt", "a1")
154+
self._drafter.generate_draft("hello", _FakeBot(), checkout=True)
155+
self._write("p1.txt", "a2")
156+
self._drafter.finalize_draft()
157+
assert self._read("p1.txt") == "a2"
158+
assert self._read("PROMPT") == "hello"

0 commit comments

Comments
 (0)