Skip to content

Commit 5ac4ec6

Browse files
authored
feat: improve draft discard logic (#41)
1 parent 91233fc commit 5ac4ec6

File tree

4 files changed

+72
-32
lines changed

4 files changed

+72
-32
lines changed

poetry.lock

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/git_draft/__main__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ def callback(_option, _opt, _value, parser) -> None:
4646
**kwargs,
4747
)
4848

49-
add_command("discard", help="discard the current draft")
5049
add_command("finalize", help="apply current draft to original branch")
5150
add_command("generate", help="start a new draft from a prompt")
51+
add_command("revert", help="discard the current draft")
5252

5353
parser.add_option(
5454
"-b",
@@ -135,8 +135,8 @@ def main() -> None:
135135
)
136136
elif command == "finalize":
137137
drafter.finalize_draft(delete=opts.delete)
138-
elif command == "discard":
139-
drafter.discard_draft(delete=opts.delete)
138+
elif command == "revert":
139+
drafter.revert_draft(delete=opts.delete)
140140
else:
141141
raise UnreachableError()
142142

src/git_draft/drafter.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
class _Branch:
2525
"""Draft branch"""
2626

27-
_name_pattern = re.compile(r"drafts/(.+)")
27+
_name_pattern = re.compile(r"draft/(.+)")
2828

2929
suffix: str
3030

3131
@property
3232
def name(self) -> str:
33-
return f"drafts/{self.suffix}"
33+
return f"draft/{self.suffix}"
3434

3535
def __str__(self) -> str:
3636
return self.name
@@ -202,10 +202,10 @@ def generate_draft(
202202
self._repo.git.checkout("--", ".")
203203

204204
def finalize_draft(self, delete=False) -> None:
205-
self._exit_draft(True, delete=delete)
205+
self._exit_draft(revert=False, delete=delete)
206206

207-
def discard_draft(self, delete=False) -> None:
208-
self._exit_draft(False, delete=delete)
207+
def revert_draft(self, delete=False) -> None:
208+
self._exit_draft(revert=True, delete=delete)
209209

210210
def _create_branch(self, sync: bool) -> _Branch:
211211
if self._repo.head.is_detached:
@@ -241,7 +241,7 @@ def _stage_changes(self, sync: bool) -> str | None:
241241
ref = self._repo.index.commit("draft! sync")
242242
return ref.hexsha
243243

244-
def _exit_draft(self, apply: bool, delete=False) -> None:
244+
def _exit_draft(self, *, revert: bool, delete: bool) -> None:
245245
branch = _Branch.active(self._repo)
246246
if not branch:
247247
raise RuntimeError("Not currently on a draft branch")
@@ -255,7 +255,7 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
255255
[(origin_branch, origin_sha, sync_sha)] = rows
256256

257257
if (
258-
not apply
258+
revert
259259
and sync_sha
260260
and self._repo.commit(origin_branch).hexsha != origin_sha
261261
):
@@ -265,14 +265,27 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
265265
# draft branch untouched. See https://stackoverflow.com/a/15993574 for
266266
# the inspiration.
267267
self._repo.git.checkout(detach=True)
268-
self._repo.git.reset("--mixed" if apply else "--hard", origin_branch)
268+
self._repo.git.reset("-N", origin_branch)
269269
self._repo.git.checkout(origin_branch)
270270

271-
if not apply and sync_sha:
272-
self._repo.git.checkout(sync_sha, "--", ".")
271+
# Finally, we revert the relevant files if needed. If a sync commit had
272+
# been created, we simply revert to it. Otherwise we compute which
273+
# files have changed due to draft commits and revert only those.
274+
if revert:
275+
if sync_sha:
276+
self._repo.git.checkout(sync_sha, "--", ".")
277+
else:
278+
diffed = set(self._changed_files(f"{origin_branch}..{branch}"))
279+
dirty = [p for p in self._changed_files("HEAD") if p in diffed]
280+
if dirty:
281+
self._repo.git.checkout("--", *dirty)
282+
273283
if delete:
274284
self._repo.git.branch("-D", branch.name)
275285

286+
def _changed_files(self, spec) -> Sequence[str]:
287+
return self._repo.git.diff(spec, name_only=True).splitlines()
288+
276289

277290
def _default_title(prompt: str) -> str:
278291
return textwrap.shorten(prompt, break_on_hyphens=False, width=72)

tests/git_draft/drafter_test.py

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,12 @@ def setup(self, repo: git.Repo) -> None:
6060
def _path(self, name: str) -> Path:
6161
return Path(self._repo.working_dir, name)
6262

63-
def _read(self, name: str) -> str:
64-
with open(self._path(name)) as f:
65-
return f.read()
63+
def _read(self, name: str) -> str | None:
64+
try:
65+
with open(self._path(name)) as f:
66+
return f.read()
67+
except FileNotFoundError:
68+
return None
6669

6770
def _write(self, name: str, contents="") -> None:
6871
with open(self._path(name), "w") as f:
@@ -95,9 +98,9 @@ def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
9598
self._drafter.generate_draft("hello", CustomBot())
9699
assert self._commit_files("HEAD") == set(["p2", "p3"])
97100

98-
def test_generate_then_discard_draft(self) -> None:
101+
def test_generate_then_revert_draft(self) -> None:
99102
self._drafter.generate_draft("hello", FakeBot())
100-
self._drafter.discard_draft()
103+
self._drafter.revert_draft()
101104
assert len(self._commits()) == 1
102105

103106
def test_generate_outside_branch(self) -> None:
@@ -129,7 +132,7 @@ def test_generate_clean_index_sync(self) -> None:
129132
prompt = TemplatedPrompt("add-test", {"symbol": "abc"})
130133
self._drafter.generate_draft(prompt, FakeBot(), sync=True)
131134
self._repo.git.checkout(".")
132-
assert "abc" in self._read("PROMPT")
135+
assert "abc" in (self._read("PROMPT") or "")
133136
assert len(self._commits()) == 2 # init, prompt
134137

135138
def test_generate_reuse_branch(self) -> None:
@@ -157,29 +160,53 @@ def act(self, _goal: Goal, _toolbox: Toolbox) -> Action:
157160
assert len(self._commits()) == 2 # init, prompt
158161
assert not self._commit_files("HEAD")
159162

160-
def test_discard_outside_draft(self) -> None:
163+
def test_revert_outside_draft(self) -> None:
161164
with pytest.raises(RuntimeError):
162-
self._drafter.discard_draft()
165+
self._drafter.revert_draft()
163166

164-
def test_discard_after_branch_move(self) -> None:
167+
def test_revert_after_branch_move(self) -> None:
165168
self._write("log", "11")
166169
self._drafter.generate_draft("hi", FakeBot(), sync=True)
167170
branch = self._repo.active_branch
168171
self._repo.git.checkout("main")
169172
self._repo.index.commit("advance")
170173
self._repo.git.checkout(branch)
171174
with pytest.raises(RuntimeError):
172-
self._drafter.discard_draft()
175+
self._drafter.revert_draft()
173176

174-
def test_discard_restores_worktree(self) -> None:
177+
def test_revert_restores_worktree(self) -> None:
175178
self._write("p1.txt", "a1")
176179
self._write("p2.txt", "b1")
177180
self._drafter.generate_draft("hello", FakeBot(), sync=True)
178181
self._write("p1.txt", "a2")
179-
self._drafter.discard_draft(delete=True)
182+
self._drafter.revert_draft(delete=True)
180183
assert self._read("p1.txt") == "a1"
181184
assert self._read("p2.txt") == "b1"
182185

186+
def test_revert_keeps_untouched_files(self) -> None:
187+
class CustomBot(Bot):
188+
def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
189+
toolbox.write_file(PurePosixPath("p2.txt"), "t2")
190+
toolbox.write_file(PurePosixPath("p4.txt"), "t2")
191+
return Action()
192+
193+
self._write("p1.txt", "t0")
194+
self._write("p2.txt", "t0")
195+
self._repo.git.add(all=True)
196+
self._repo.index.commit("update")
197+
self._write("p1.txt", "t1")
198+
self._write("p2.txt", "t1")
199+
self._write("p3.txt", "t1")
200+
self._drafter.generate_draft("hello", CustomBot())
201+
self._write("p1.txt", "t3")
202+
self._write("p2.txt", "t3")
203+
self._drafter.revert_draft()
204+
205+
assert self._read("p1.txt") == "t3"
206+
assert self._read("p2.txt") == "t0"
207+
assert self._read("p3.txt") == "t1"
208+
assert self._read("p4.txt") is None
209+
183210
def test_finalize_keeps_changes(self) -> None:
184211
self._write("p1.txt", "a1")
185212
self._drafter.generate_draft("hello", FakeBot(), checkout=True)

0 commit comments

Comments
 (0)