diff --git a/README.md b/README.md index 3045b47..6db7a00 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,6 @@ * Change prompt CLI inputs to `[PROMPT] [--] [ARGS]`. If `PROMPT` does not contain any spaces or `ARGS` (or `--`) is present, it will be interpreted as a template name. Otherwise an inline prompt. -* Only include files that the bot has written in draft commits. * Add `--generate` timeout option. * Add a compatibility OpenAI bot version which does not use threads, so that it can be used with tools only. Gemini only supports the latter. diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index 29df65a..ad29375 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -59,6 +59,7 @@ class _Toolbox(Toolbox): def __init__(self, repo: git.Repo, hook: OperationHook | None) -> None: super().__init__(hook) self._repo = repo + self._written = set[str]() @override def _list(self) -> Sequence[PurePosixPath]: @@ -72,17 +73,28 @@ def _read(self, path: PurePosixPath) -> str: @override def _write(self, path: PurePosixPath, contents: str) -> None: + self._written.add(str(path)) # Update the index without touching the worktree. # https://stackoverflow.com/a/25352119 with tempfile.NamedTemporaryFile(delete_on_close=False) as temp: temp.write(contents.encode("utf8")) temp.close() - sha = self._repo.git.hash_object("-w", "--path", path, temp.name) + sha = self._repo.git.hash_object("-w", temp.name, path=path) mode = 644 # TODO: Read from original file if it exists. self._repo.git.update_index( - "--add", "--cacheinfo", f"{mode},{sha},{path}" + f"{mode},{sha},{path}", add=True, cacheinfo=True ) + def update_index(self) -> None: + diff = self._repo.git.diff(name_only=True, cached=True) + untouched = [ + path + for path in diff.splitlines() + if path and path not in self._written + ] + if untouched: + self._repo.git.reset("--", *untouched) + class Drafter: """Draft state orchestrator""" @@ -151,6 +163,7 @@ def generate_draft( action = bot.act(prompt_contents, toolbox) end_time = time.perf_counter() + toolbox.update_index() title = action.title if not title: title = _default_title(prompt_contents) @@ -197,7 +210,7 @@ def _create_branch(self, sync: bool) -> _Branch: origin_branch = self._repo.active_branch.name origin_sha = self._repo.commit().hexsha - self._repo.git.checkout("--detach") + self._repo.git.checkout(detach=True) sync_sha = self._stage_changes(sync) suffix = _Branch.new_suffix() @@ -248,7 +261,7 @@ def _exit_draft(self, apply: bool, delete=False) -> None: # We do a small dance to move back to the original branch, keeping the # draft branch untouched. See https://stackoverflow.com/a/15993574 for # the inspiration. - self._repo.git.checkout("--detach") + self._repo.git.checkout(detach=True) self._repo.git.reset("--mixed" if apply else "--hard", origin_branch) self._repo.git.checkout(origin_branch) diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index 4db31d0..059b044 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -71,20 +71,29 @@ def _write(self, name: str, contents="") -> None: def _commits(self) -> Sequence[git.Commit]: return list(self._repo.iter_commits()) - def _list_commit_files(self, ref: str) -> Sequence[str]: + def _commit_files(self, ref: str) -> frozenset[str]: text = self._repo.git.diff_tree( ref, no_commit_id=True, name_only=True, relative=True ) - return text.splitlines() + return frozenset(text.splitlines()) def test_generate_draft(self) -> None: self._drafter.generate_draft("hello", FakeBot()) assert len(self._commits()) == 2 - def test_generate_stages_worktree(self) -> None: - self._write("marker", "hi") - self._drafter.generate_draft("hello", FakeBot()) - assert "marker" in self._list_commit_files("HEAD") + def test_generate_stages_then_resets_worktree(self) -> None: + self._write("p1", "a") + self._write("p2", "b") + + class CustomBot(Bot): + def act(self, prompt: str, toolbox: Toolbox) -> Action: + assert toolbox.read_file(PurePosixPath("p1")) == "a" + toolbox.write_file(PurePosixPath("p2"), "B") + toolbox.write_file(PurePosixPath("p3"), "C") + return Action() + + self._drafter.generate_draft("hello", CustomBot()) + assert self._commit_files("HEAD") == set(["p2", "p3"]) def test_generate_then_discard_draft(self) -> None: self._drafter.generate_draft("hello", FakeBot())