diff --git a/poetry.lock b/poetry.lock index 18a83b3..e5290ac 100644 --- a/poetry.lock +++ b/poetry.lock @@ -371,14 +371,14 @@ files = [ [[package]] name = "jinja2" -version = "3.1.5" +version = "3.1.6" description = "A very fast and expressive template engine." optional = false python-versions = ">=3.7" groups = ["main"] files = [ - {file = "jinja2-3.1.5-py3-none-any.whl", hash = "sha256:aba0f4dc9ed8013c424088f68a5c226f7d6097ed89b246d7749c2ec4175c6adb"}, - {file = "jinja2-3.1.5.tar.gz", hash = "sha256:8fefff8dc3034e27bb80d67c671eb8a9bc424c0ef4c0826edbff304cceff43bb"}, + {file = "jinja2-3.1.6-py3-none-any.whl", hash = "sha256:85ece4451f492d0c13c5dd7c13a64681a86afae63a5f347908daf103ce6d2f67"}, + {file = "jinja2-3.1.6.tar.gz", hash = "sha256:0137fb05990d35f1275a587e9aee6d56da821fc83491a0fb838183be43f66d6d"}, ] [package.dependencies] @@ -624,15 +624,15 @@ files = [ [[package]] name = "openai" -version = "1.64.0" +version = "1.65.4" description = "The official Python library for the openai API" optional = true python-versions = ">=3.8" groups = ["main"] markers = "extra == \"openai\"" files = [ - {file = "openai-1.64.0-py3-none-any.whl", hash = "sha256:20f85cde9e95e9fbb416e3cb5a6d3119c0b28308afd6e3cc47bf100623dac623"}, - {file = "openai-1.64.0.tar.gz", hash = "sha256:2861053538704d61340da56e2f176853d19f1dc5704bc306b7597155f850d57a"}, + {file = "openai-1.65.4-py3-none-any.whl", hash = "sha256:15566d46574b94eae3d18efc2f9a4ebd1366d1d44bfc1bdafeea7a5cf8271bcb"}, + {file = "openai-1.65.4.tar.gz", hash = "sha256:0b08c58625d556f5c6654701af1023689c173eb0989ce8f73c7fd0eb22203c76"}, ] [package.dependencies] @@ -921,7 +921,7 @@ testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "no [[package]] name = "smmap" version = "5.0.2" -description = "A pure Python implementation of a sliding window memory map drafter" +description = "A pure Python implementation of a sliding window memory map manager" optional = false python-versions = ">=3.7" groups = ["main"] diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index d943b44..e0103e8 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -46,9 +46,9 @@ def callback(_option, _opt, _value, parser) -> None: **kwargs, ) - add_command("discard", help="discard the current draft") add_command("finalize", help="apply current draft to original branch") add_command("generate", help="start a new draft from a prompt") + add_command("revert", help="discard the current draft") parser.add_option( "-b", @@ -135,8 +135,8 @@ def main() -> None: ) elif command == "finalize": drafter.finalize_draft(delete=opts.delete) - elif command == "discard": - drafter.discard_draft(delete=opts.delete) + elif command == "revert": + drafter.revert_draft(delete=opts.delete) else: raise UnreachableError() diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index 4f15bcd..a4f60af 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -24,13 +24,13 @@ class _Branch: """Draft branch""" - _name_pattern = re.compile(r"drafts/(.+)") + _name_pattern = re.compile(r"draft/(.+)") suffix: str @property def name(self) -> str: - return f"drafts/{self.suffix}" + return f"draft/{self.suffix}" def __str__(self) -> str: return self.name @@ -202,10 +202,10 @@ def generate_draft( self._repo.git.checkout("--", ".") def finalize_draft(self, delete=False) -> None: - self._exit_draft(True, delete=delete) + self._exit_draft(revert=False, delete=delete) - def discard_draft(self, delete=False) -> None: - self._exit_draft(False, delete=delete) + def revert_draft(self, delete=False) -> None: + self._exit_draft(revert=True, delete=delete) def _create_branch(self, sync: bool) -> _Branch: if self._repo.head.is_detached: @@ -241,7 +241,7 @@ def _stage_changes(self, sync: bool) -> str | None: ref = self._repo.index.commit("draft! sync") return ref.hexsha - def _exit_draft(self, apply: bool, delete=False) -> None: + def _exit_draft(self, *, revert: bool, delete: bool) -> None: branch = _Branch.active(self._repo) if not branch: raise RuntimeError("Not currently on a draft branch") @@ -255,7 +255,7 @@ def _exit_draft(self, apply: bool, delete=False) -> None: [(origin_branch, origin_sha, sync_sha)] = rows if ( - not apply + revert and sync_sha and self._repo.commit(origin_branch).hexsha != origin_sha ): @@ -265,14 +265,27 @@ def _exit_draft(self, apply: bool, delete=False) -> None: # draft branch untouched. See https://stackoverflow.com/a/15993574 for # the inspiration. self._repo.git.checkout(detach=True) - self._repo.git.reset("--mixed" if apply else "--hard", origin_branch) + self._repo.git.reset("-N", origin_branch) self._repo.git.checkout(origin_branch) - if not apply and sync_sha: - self._repo.git.checkout(sync_sha, "--", ".") + # Finally, we revert the relevant files if needed. If a sync commit had + # been created, we simply revert to it. Otherwise we compute which + # files have changed due to draft commits and revert only those. + if revert: + if sync_sha: + self._repo.git.checkout(sync_sha, "--", ".") + else: + diffed = set(self._changed_files(f"{origin_branch}..{branch}")) + dirty = [p for p in self._changed_files("HEAD") if p in diffed] + if dirty: + self._repo.git.checkout("--", *dirty) + if delete: self._repo.git.branch("-D", branch.name) + def _changed_files(self, spec) -> Sequence[str]: + return self._repo.git.diff(spec, name_only=True).splitlines() + def _default_title(prompt: str) -> str: return textwrap.shorten(prompt, break_on_hyphens=False, width=72) diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index d272a7e..db2e988 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -60,9 +60,12 @@ def setup(self, repo: git.Repo) -> None: def _path(self, name: str) -> Path: return Path(self._repo.working_dir, name) - def _read(self, name: str) -> str: - with open(self._path(name)) as f: - return f.read() + def _read(self, name: str) -> str | None: + try: + with open(self._path(name)) as f: + return f.read() + except FileNotFoundError: + return None def _write(self, name: str, contents="") -> None: with open(self._path(name), "w") as f: @@ -95,9 +98,9 @@ def act(self, _goal: Goal, toolbox: Toolbox) -> Action: self._drafter.generate_draft("hello", CustomBot()) assert self._commit_files("HEAD") == set(["p2", "p3"]) - def test_generate_then_discard_draft(self) -> None: + def test_generate_then_revert_draft(self) -> None: self._drafter.generate_draft("hello", FakeBot()) - self._drafter.discard_draft() + self._drafter.revert_draft() assert len(self._commits()) == 1 def test_generate_outside_branch(self) -> None: @@ -129,7 +132,7 @@ def test_generate_clean_index_sync(self) -> None: prompt = TemplatedPrompt("add-test", {"symbol": "abc"}) self._drafter.generate_draft(prompt, FakeBot(), sync=True) self._repo.git.checkout(".") - assert "abc" in self._read("PROMPT") + assert "abc" in (self._read("PROMPT") or "") assert len(self._commits()) == 2 # init, prompt def test_generate_reuse_branch(self) -> None: @@ -157,11 +160,11 @@ def act(self, _goal: Goal, _toolbox: Toolbox) -> Action: assert len(self._commits()) == 2 # init, prompt assert not self._commit_files("HEAD") - def test_discard_outside_draft(self) -> None: + def test_revert_outside_draft(self) -> None: with pytest.raises(RuntimeError): - self._drafter.discard_draft() + self._drafter.revert_draft() - def test_discard_after_branch_move(self) -> None: + def test_revert_after_branch_move(self) -> None: self._write("log", "11") self._drafter.generate_draft("hi", FakeBot(), sync=True) branch = self._repo.active_branch @@ -169,17 +172,41 @@ def test_discard_after_branch_move(self) -> None: self._repo.index.commit("advance") self._repo.git.checkout(branch) with pytest.raises(RuntimeError): - self._drafter.discard_draft() + self._drafter.revert_draft() - def test_discard_restores_worktree(self) -> None: + def test_revert_restores_worktree(self) -> None: self._write("p1.txt", "a1") self._write("p2.txt", "b1") self._drafter.generate_draft("hello", FakeBot(), sync=True) self._write("p1.txt", "a2") - self._drafter.discard_draft(delete=True) + self._drafter.revert_draft(delete=True) assert self._read("p1.txt") == "a1" assert self._read("p2.txt") == "b1" + def test_revert_keeps_untouched_files(self) -> None: + class CustomBot(Bot): + def act(self, _goal: Goal, toolbox: Toolbox) -> Action: + toolbox.write_file(PurePosixPath("p2.txt"), "t2") + toolbox.write_file(PurePosixPath("p4.txt"), "t2") + return Action() + + self._write("p1.txt", "t0") + self._write("p2.txt", "t0") + self._repo.git.add(all=True) + self._repo.index.commit("update") + self._write("p1.txt", "t1") + self._write("p2.txt", "t1") + self._write("p3.txt", "t1") + self._drafter.generate_draft("hello", CustomBot()) + self._write("p1.txt", "t3") + self._write("p2.txt", "t3") + self._drafter.revert_draft() + + assert self._read("p1.txt") == "t3" + assert self._read("p2.txt") == "t0" + assert self._read("p3.txt") == "t1" + assert self._read("p4.txt") is None + def test_finalize_keeps_changes(self) -> None: self._write("p1.txt", "a1") self._drafter.generate_draft("hello", FakeBot(), checkout=True)