Skip to content

Commit 6eb7b6e

Browse files
authored
fix: keep initial sync contents on discard (#15)
1 parent 1b97279 commit 6eb7b6e

File tree

3 files changed

+52
-12
lines changed

3 files changed

+52
-12
lines changed

src/git_draft/__main__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .assistants import load_assistant
1010
from .common import open_editor
11-
from .manager import Manager, enclosing_repo
11+
from .manager import Manager
1212

1313

1414
logging.basicConfig(level=logging.INFO)
@@ -86,8 +86,7 @@ def callback(_option, _opt, _value, parser) -> None:
8686
def main() -> None:
8787
(opts, _args) = parser.parse_args()
8888

89-
repo = enclosing_repo()
90-
manager = Manager(repo)
89+
manager = Manager.enclosing()
9190

9291
command = getattr(opts, "command", "generate")
9392
if command == "generate":

src/git_draft/manager.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@
1616
_logger = logging.getLogger(__name__)
1717

1818

19-
def enclosing_repo(path: str | None = None) -> git.Repo:
20-
"""Returns the repository to which the given path belongs"""
21-
return git.Repo(path, search_parent_directories=True)
22-
23-
2419
class _Note:
2520
"""Structured metadata attached to a commit"""
2621

@@ -157,6 +152,10 @@ class Manager:
157152
def __init__(self, repo: git.Repo) -> None:
158153
self._repo = repo
159154

155+
@classmethod
156+
def enclosing(cls, path: str | None = None) -> Manager:
157+
return cls(git.Repo(path, search_parent_directories=True))
158+
160159
def generate_draft(
161160
self, prompt: str, assistant: Assistant, checkout=False, reset=False
162161
) -> None:
@@ -206,15 +205,17 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
206205
if not apply and branch.needs_rebase(self._repo):
207206
raise ValueError("Parent branch has moved, please rebase")
208207

209-
# https://stackoverflow.com/a/15993574
210208
note = branch.init_note
209+
# https://stackoverflow.com/a/15993574
211210
self._repo.git.checkout("--detach")
212211
if apply:
213-
# We discard index (internal) changes
212+
# We discard index (internal) changes.
214213
self._repo.git.reset(note.origin_branch)
214+
self._repo.git.checkout(note.origin_branch)
215215
else:
216-
self._repo.git.reset("--hard", note.sync_sha or note.origin_branch)
217-
self._repo.git.checkout(note.origin_branch)
216+
self._repo.git.reset("--hard", note.origin_branch)
217+
if note.sync_sha:
218+
self._repo.git.checkout(note.sync_sha, "--", ".")
218219

219220
if delete:
220221
self._repo.git.branch("-D", branch.name)

tests/git_draft/manager_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import dataclasses
22
import git
3+
import os.path as osp
34
from pathlib import PurePosixPath
45
import pytest
56
import tempfile
@@ -64,3 +65,42 @@ def test_generate_then_discard_draft(self, repo: git.Repo) -> None:
6465
manager.generate_draft("hello", _FakeAssistant())
6566
manager.discard_draft()
6667
assert len(list(repo.iter_commits())) == 1
68+
69+
def test_discard_restores_worktree(self, repo: git.Repo) -> None:
70+
manager = sut.Manager(repo)
71+
72+
p1 = osp.join(repo.working_dir, "p1.txt")
73+
with open(p1, "w") as writer:
74+
writer.write("a1")
75+
p2 = osp.join(repo.working_dir, "p2.txt")
76+
with open(p2, "w") as writer:
77+
writer.write("b1")
78+
79+
manager.generate_draft("hello", _FakeAssistant())
80+
with open(p1, "w") as writer:
81+
writer.write("a2")
82+
83+
manager.discard_draft()
84+
85+
with open(p1) as reader:
86+
assert reader.read() == "a1"
87+
with open(p2) as reader:
88+
assert reader.read() == "b1"
89+
90+
def test_finalize_keeps_changes(self, repo: git.Repo) -> None:
91+
manager = sut.Manager(repo)
92+
93+
p1 = osp.join(repo.working_dir, "p1.txt")
94+
with open(p1, "w") as writer:
95+
writer.write("a1")
96+
97+
manager.generate_draft("hello", _FakeAssistant(), checkout=True)
98+
with open(p1, "w") as writer:
99+
writer.write("a2")
100+
101+
manager.finalize_draft()
102+
103+
with open(p1) as reader:
104+
assert reader.read() == "a2"
105+
with open(osp.join(repo.working_dir, "PROMPT")) as reader:
106+
assert reader.read() == "hello"

0 commit comments

Comments
 (0)