Skip to content

Commit 296ad81

Browse files
authored
feat: only include written files in draft commits (#36)
1 parent 2ef9e48 commit 296ad81

File tree

3 files changed

+32
-11
lines changed

3 files changed

+32
-11
lines changed

README.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
* Change prompt CLI inputs to `[PROMPT] [--] [ARGS]`. If `PROMPT` does not
1515
contain any spaces or `ARGS` (or `--`) is present, it will be interpreted as a
1616
template name. Otherwise an inline prompt.
17-
* Only include files that the bot has written in draft commits.
1817
* Add `--generate` timeout option.
1918
* Add a compatibility OpenAI bot version which does not use threads, so that it
2019
can be used with tools only. Gemini only supports the latter.

src/git_draft/drafter.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class _Toolbox(Toolbox):
5959
def __init__(self, repo: git.Repo, hook: OperationHook | None) -> None:
6060
super().__init__(hook)
6161
self._repo = repo
62+
self._written = set[str]()
6263

6364
@override
6465
def _list(self) -> Sequence[PurePosixPath]:
@@ -72,17 +73,28 @@ def _read(self, path: PurePosixPath) -> str:
7273

7374
@override
7475
def _write(self, path: PurePosixPath, contents: str) -> None:
76+
self._written.add(str(path))
7577
# Update the index without touching the worktree.
7678
# https://stackoverflow.com/a/25352119
7779
with tempfile.NamedTemporaryFile(delete_on_close=False) as temp:
7880
temp.write(contents.encode("utf8"))
7981
temp.close()
80-
sha = self._repo.git.hash_object("-w", "--path", path, temp.name)
82+
sha = self._repo.git.hash_object("-w", temp.name, path=path)
8183
mode = 644 # TODO: Read from original file if it exists.
8284
self._repo.git.update_index(
83-
"--add", "--cacheinfo", f"{mode},{sha},{path}"
85+
f"{mode},{sha},{path}", add=True, cacheinfo=True
8486
)
8587

88+
def update_index(self) -> None:
89+
diff = self._repo.git.diff(name_only=True, cached=True)
90+
untouched = [
91+
path
92+
for path in diff.splitlines()
93+
if path and path not in self._written
94+
]
95+
if untouched:
96+
self._repo.git.reset("--", *untouched)
97+
8698

8799
class Drafter:
88100
"""Draft state orchestrator"""
@@ -151,6 +163,7 @@ def generate_draft(
151163
action = bot.act(prompt_contents, toolbox)
152164
end_time = time.perf_counter()
153165

166+
toolbox.update_index()
154167
title = action.title
155168
if not title:
156169
title = _default_title(prompt_contents)
@@ -197,7 +210,7 @@ def _create_branch(self, sync: bool) -> _Branch:
197210
origin_branch = self._repo.active_branch.name
198211
origin_sha = self._repo.commit().hexsha
199212

200-
self._repo.git.checkout("--detach")
213+
self._repo.git.checkout(detach=True)
201214
sync_sha = self._stage_changes(sync)
202215
suffix = _Branch.new_suffix()
203216

@@ -248,7 +261,7 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
248261
# We do a small dance to move back to the original branch, keeping the
249262
# draft branch untouched. See https://stackoverflow.com/a/15993574 for
250263
# the inspiration.
251-
self._repo.git.checkout("--detach")
264+
self._repo.git.checkout(detach=True)
252265
self._repo.git.reset("--mixed" if apply else "--hard", origin_branch)
253266
self._repo.git.checkout(origin_branch)
254267

tests/git_draft/drafter_test.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,20 +71,29 @@ def _write(self, name: str, contents="") -> None:
7171
def _commits(self) -> Sequence[git.Commit]:
7272
return list(self._repo.iter_commits())
7373

74-
def _list_commit_files(self, ref: str) -> Sequence[str]:
74+
def _commit_files(self, ref: str) -> frozenset[str]:
7575
text = self._repo.git.diff_tree(
7676
ref, no_commit_id=True, name_only=True, relative=True
7777
)
78-
return text.splitlines()
78+
return frozenset(text.splitlines())
7979

8080
def test_generate_draft(self) -> None:
8181
self._drafter.generate_draft("hello", FakeBot())
8282
assert len(self._commits()) == 2
8383

84-
def test_generate_stages_worktree(self) -> None:
85-
self._write("marker", "hi")
86-
self._drafter.generate_draft("hello", FakeBot())
87-
assert "marker" in self._list_commit_files("HEAD")
84+
def test_generate_stages_then_resets_worktree(self) -> None:
85+
self._write("p1", "a")
86+
self._write("p2", "b")
87+
88+
class CustomBot(Bot):
89+
def act(self, prompt: str, toolbox: Toolbox) -> Action:
90+
assert toolbox.read_file(PurePosixPath("p1")) == "a"
91+
toolbox.write_file(PurePosixPath("p2"), "B")
92+
toolbox.write_file(PurePosixPath("p3"), "C")
93+
return Action()
94+
95+
self._drafter.generate_draft("hello", CustomBot())
96+
assert self._commit_files("HEAD") == set(["p2", "p3"])
8897

8998
def test_generate_then_discard_draft(self) -> None:
9099
self._drafter.generate_draft("hello", FakeBot())

0 commit comments

Comments
 (0)