Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .assistants import load_assistant
from .common import open_editor
from .manager import Manager, enclosing_repo
from .manager import Manager


logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -86,8 +86,7 @@ def callback(_option, _opt, _value, parser) -> None:
def main() -> None:
(opts, _args) = parser.parse_args()

repo = enclosing_repo()
manager = Manager(repo)
manager = Manager.enclosing()

command = getattr(opts, "command", "generate")
if command == "generate":
Expand Down
19 changes: 10 additions & 9 deletions src/git_draft/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,6 @@
_logger = logging.getLogger(__name__)


def enclosing_repo(path: str | None = None) -> git.Repo:
"""Returns the repository to which the given path belongs"""
return git.Repo(path, search_parent_directories=True)


class _Note:
"""Structured metadata attached to a commit"""

Expand Down Expand Up @@ -157,6 +152,10 @@ class Manager:
def __init__(self, repo: git.Repo) -> None:
self._repo = repo

@classmethod
def enclosing(cls, path: str | None = None) -> Manager:
return cls(git.Repo(path, search_parent_directories=True))

def generate_draft(
self, prompt: str, assistant: Assistant, checkout=False, reset=False
) -> None:
Expand Down Expand Up @@ -206,15 +205,17 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
if not apply and branch.needs_rebase(self._repo):
raise ValueError("Parent branch has moved, please rebase")

# https://stackoverflow.com/a/15993574
note = branch.init_note
# https://stackoverflow.com/a/15993574
self._repo.git.checkout("--detach")
if apply:
# We discard index (internal) changes
# We discard index (internal) changes.
self._repo.git.reset(note.origin_branch)
self._repo.git.checkout(note.origin_branch)
else:
self._repo.git.reset("--hard", note.sync_sha or note.origin_branch)
self._repo.git.checkout(note.origin_branch)
self._repo.git.reset("--hard", note.origin_branch)
if note.sync_sha:
self._repo.git.checkout(note.sync_sha, "--", ".")

if delete:
self._repo.git.branch("-D", branch.name)
40 changes: 40 additions & 0 deletions tests/git_draft/manager_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import dataclasses
import git
import os.path as osp
from pathlib import PurePosixPath
import pytest
import tempfile
Expand Down Expand Up @@ -64,3 +65,42 @@ def test_generate_then_discard_draft(self, repo: git.Repo) -> None:
manager.generate_draft("hello", _FakeAssistant())
manager.discard_draft()
assert len(list(repo.iter_commits())) == 1

def test_discard_restores_worktree(self, repo: git.Repo) -> None:
manager = sut.Manager(repo)

p1 = osp.join(repo.working_dir, "p1.txt")
with open(p1, "w") as writer:
writer.write("a1")
p2 = osp.join(repo.working_dir, "p2.txt")
with open(p2, "w") as writer:
writer.write("b1")

manager.generate_draft("hello", _FakeAssistant())
with open(p1, "w") as writer:
writer.write("a2")

manager.discard_draft()

with open(p1) as reader:
assert reader.read() == "a1"
with open(p2) as reader:
assert reader.read() == "b1"

def test_finalize_keeps_changes(self, repo: git.Repo) -> None:
manager = sut.Manager(repo)

p1 = osp.join(repo.working_dir, "p1.txt")
with open(p1, "w") as writer:
writer.write("a1")

manager.generate_draft("hello", _FakeAssistant(), checkout=True)
with open(p1, "w") as writer:
writer.write("a2")

manager.finalize_draft()

with open(p1) as reader:
assert reader.read() == "a2"
with open(osp.join(repo.working_dir, "PROMPT")) as reader:
assert reader.read() == "hello"