Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ def callback(_option, _opt, _value, parser) -> None:
**kwargs,
)

add_command("discard", help="discard the current draft")
add_command("finalize", help="apply current draft to original branch")
add_command("generate", help="start a new draft from a prompt")
add_command("revert", help="discard the current draft")

parser.add_option(
"-b",
Expand Down Expand Up @@ -135,8 +135,8 @@ def main() -> None:
)
elif command == "finalize":
drafter.finalize_draft(delete=opts.delete)
elif command == "discard":
drafter.discard_draft(delete=opts.delete)
elif command == "revert":
drafter.revert_draft(delete=opts.delete)
else:
raise UnreachableError()

Expand Down
33 changes: 23 additions & 10 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
class _Branch:
"""Draft branch"""

_name_pattern = re.compile(r"drafts/(.+)")
_name_pattern = re.compile(r"draft/(.+)")

suffix: str

@property
def name(self) -> str:
return f"drafts/{self.suffix}"
return f"draft/{self.suffix}"

def __str__(self) -> str:
return self.name
Expand Down Expand Up @@ -202,10 +202,10 @@ def generate_draft(
self._repo.git.checkout("--", ".")

def finalize_draft(self, delete=False) -> None:
self._exit_draft(True, delete=delete)
self._exit_draft(revert=False, delete=delete)

def discard_draft(self, delete=False) -> None:
self._exit_draft(False, delete=delete)
def revert_draft(self, delete=False) -> None:
self._exit_draft(revert=True, delete=delete)

def _create_branch(self, sync: bool) -> _Branch:
if self._repo.head.is_detached:
Expand Down Expand Up @@ -241,7 +241,7 @@ def _stage_changes(self, sync: bool) -> str | None:
ref = self._repo.index.commit("draft! sync")
return ref.hexsha

def _exit_draft(self, apply: bool, delete=False) -> None:
def _exit_draft(self, *, revert: bool, delete: bool) -> None:
branch = _Branch.active(self._repo)
if not branch:
raise RuntimeError("Not currently on a draft branch")
Expand All @@ -255,7 +255,7 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
[(origin_branch, origin_sha, sync_sha)] = rows

if (
not apply
revert
and sync_sha
and self._repo.commit(origin_branch).hexsha != origin_sha
):
Expand All @@ -265,14 +265,27 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
# draft branch untouched. See https://stackoverflow.com/a/15993574 for
# the inspiration.
self._repo.git.checkout(detach=True)
self._repo.git.reset("--mixed" if apply else "--hard", origin_branch)
self._repo.git.reset("-N", origin_branch)
self._repo.git.checkout(origin_branch)

if not apply and sync_sha:
self._repo.git.checkout(sync_sha, "--", ".")
# Finally, we revert the relevant files if needed. If a sync commit had
# been created, we simply revert to it. Otherwise we compute which
# files have changed due to draft commits and revert only those.
if revert:
if sync_sha:
self._repo.git.checkout(sync_sha, "--", ".")
else:
diffed = set(self._changed_files(f"{origin_branch}..{branch}"))
dirty = [p for p in self._changed_files("HEAD") if p in diffed]
if dirty:
self._repo.git.checkout("--", *dirty)

if delete:
self._repo.git.branch("-D", branch.name)

def _changed_files(self, spec) -> Sequence[str]:
return self._repo.git.diff(spec, name_only=True).splitlines()


def _default_title(prompt: str) -> str:
return textwrap.shorten(prompt, break_on_hyphens=False, width=72)
51 changes: 39 additions & 12 deletions tests/git_draft/drafter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,12 @@ def setup(self, repo: git.Repo) -> None:
def _path(self, name: str) -> Path:
return Path(self._repo.working_dir, name)

def _read(self, name: str) -> str:
with open(self._path(name)) as f:
return f.read()
def _read(self, name: str) -> str | None:
try:
with open(self._path(name)) as f:
return f.read()
except FileNotFoundError:
return None

def _write(self, name: str, contents="") -> None:
with open(self._path(name), "w") as f:
Expand Down Expand Up @@ -95,9 +98,9 @@ def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
self._drafter.generate_draft("hello", CustomBot())
assert self._commit_files("HEAD") == set(["p2", "p3"])

def test_generate_then_discard_draft(self) -> None:
def test_generate_then_revert_draft(self) -> None:
self._drafter.generate_draft("hello", FakeBot())
self._drafter.discard_draft()
self._drafter.revert_draft()
assert len(self._commits()) == 1

def test_generate_outside_branch(self) -> None:
Expand Down Expand Up @@ -129,7 +132,7 @@ def test_generate_clean_index_sync(self) -> None:
prompt = TemplatedPrompt("add-test", {"symbol": "abc"})
self._drafter.generate_draft(prompt, FakeBot(), sync=True)
self._repo.git.checkout(".")
assert "abc" in self._read("PROMPT")
assert "abc" in (self._read("PROMPT") or "")
assert len(self._commits()) == 2 # init, prompt

def test_generate_reuse_branch(self) -> None:
Expand Down Expand Up @@ -157,29 +160,53 @@ def act(self, _goal: Goal, _toolbox: Toolbox) -> Action:
assert len(self._commits()) == 2 # init, prompt
assert not self._commit_files("HEAD")

def test_discard_outside_draft(self) -> None:
def test_revert_outside_draft(self) -> None:
with pytest.raises(RuntimeError):
self._drafter.discard_draft()
self._drafter.revert_draft()

def test_discard_after_branch_move(self) -> None:
def test_revert_after_branch_move(self) -> None:
self._write("log", "11")
self._drafter.generate_draft("hi", FakeBot(), sync=True)
branch = self._repo.active_branch
self._repo.git.checkout("main")
self._repo.index.commit("advance")
self._repo.git.checkout(branch)
with pytest.raises(RuntimeError):
self._drafter.discard_draft()
self._drafter.revert_draft()

def test_discard_restores_worktree(self) -> None:
def test_revert_restores_worktree(self) -> None:
self._write("p1.txt", "a1")
self._write("p2.txt", "b1")
self._drafter.generate_draft("hello", FakeBot(), sync=True)
self._write("p1.txt", "a2")
self._drafter.discard_draft(delete=True)
self._drafter.revert_draft(delete=True)
assert self._read("p1.txt") == "a1"
assert self._read("p2.txt") == "b1"

def test_revert_keeps_untouched_files(self) -> None:
class CustomBot(Bot):
def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
toolbox.write_file(PurePosixPath("p2.txt"), "t2")
toolbox.write_file(PurePosixPath("p4.txt"), "t2")
return Action()

self._write("p1.txt", "t0")
self._write("p2.txt", "t0")
self._repo.git.add(all=True)
self._repo.index.commit("update")
self._write("p1.txt", "t1")
self._write("p2.txt", "t1")
self._write("p3.txt", "t1")
self._drafter.generate_draft("hello", CustomBot())
self._write("p1.txt", "t3")
self._write("p2.txt", "t3")
self._drafter.revert_draft()

assert self._read("p1.txt") == "t3"
assert self._read("p2.txt") == "t0"
assert self._read("p3.txt") == "t1"
assert self._read("p4.txt") is None

def test_finalize_keeps_changes(self) -> None:
self._write("p1.txt", "a1")
self._drafter.generate_draft("hello", FakeBot(), checkout=True)
Expand Down