diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 1450bf5..ca2ae18 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -8,7 +8,7 @@ from .assistants import load_assistant from .common import open_editor -from .manager import Manager, enclosing_repo +from .manager import Manager logging.basicConfig(level=logging.INFO) @@ -86,8 +86,7 @@ def callback(_option, _opt, _value, parser) -> None: def main() -> None: (opts, _args) = parser.parse_args() - repo = enclosing_repo() - manager = Manager(repo) + manager = Manager.enclosing() command = getattr(opts, "command", "generate") if command == "generate": diff --git a/src/git_draft/manager.py b/src/git_draft/manager.py index bf717fe..706d179 100644 --- a/src/git_draft/manager.py +++ b/src/git_draft/manager.py @@ -16,11 +16,6 @@ _logger = logging.getLogger(__name__) -def enclosing_repo(path: str | None = None) -> git.Repo: - """Returns the repository to which the given path belongs""" - return git.Repo(path, search_parent_directories=True) - - class _Note: """Structured metadata attached to a commit""" @@ -157,6 +152,10 @@ class Manager: def __init__(self, repo: git.Repo) -> None: self._repo = repo + @classmethod + def enclosing(cls, path: str | None = None) -> Manager: + return cls(git.Repo(path, search_parent_directories=True)) + def generate_draft( self, prompt: str, assistant: Assistant, checkout=False, reset=False ) -> None: @@ -206,15 +205,17 @@ def _exit_draft(self, apply: bool, delete=False) -> None: if not apply and branch.needs_rebase(self._repo): raise ValueError("Parent branch has moved, please rebase") - # https://stackoverflow.com/a/15993574 note = branch.init_note + # https://stackoverflow.com/a/15993574 self._repo.git.checkout("--detach") if apply: - # We discard index (internal) changes + # We discard index (internal) changes. self._repo.git.reset(note.origin_branch) + self._repo.git.checkout(note.origin_branch) else: - self._repo.git.reset("--hard", note.sync_sha or note.origin_branch) - self._repo.git.checkout(note.origin_branch) + self._repo.git.reset("--hard", note.origin_branch) + if note.sync_sha: + self._repo.git.checkout(note.sync_sha, "--", ".") if delete: self._repo.git.branch("-D", branch.name) diff --git a/tests/git_draft/manager_test.py b/tests/git_draft/manager_test.py index 475679b..c6c551a 100644 --- a/tests/git_draft/manager_test.py +++ b/tests/git_draft/manager_test.py @@ -1,5 +1,6 @@ import dataclasses import git +import os.path as osp from pathlib import PurePosixPath import pytest import tempfile @@ -64,3 +65,42 @@ def test_generate_then_discard_draft(self, repo: git.Repo) -> None: manager.generate_draft("hello", _FakeAssistant()) manager.discard_draft() assert len(list(repo.iter_commits())) == 1 + + def test_discard_restores_worktree(self, repo: git.Repo) -> None: + manager = sut.Manager(repo) + + p1 = osp.join(repo.working_dir, "p1.txt") + with open(p1, "w") as writer: + writer.write("a1") + p2 = osp.join(repo.working_dir, "p2.txt") + with open(p2, "w") as writer: + writer.write("b1") + + manager.generate_draft("hello", _FakeAssistant()) + with open(p1, "w") as writer: + writer.write("a2") + + manager.discard_draft() + + with open(p1) as reader: + assert reader.read() == "a1" + with open(p2) as reader: + assert reader.read() == "b1" + + def test_finalize_keeps_changes(self, repo: git.Repo) -> None: + manager = sut.Manager(repo) + + p1 = osp.join(repo.working_dir, "p1.txt") + with open(p1, "w") as writer: + writer.write("a1") + + manager.generate_draft("hello", _FakeAssistant(), checkout=True) + with open(p1, "w") as writer: + writer.write("a2") + + manager.finalize_draft() + + with open(p1) as reader: + assert reader.read() == "a2" + with open(osp.join(repo.working_dir, "PROMPT")) as reader: + assert reader.read() == "hello"