Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
* Change prompt CLI inputs to `[PROMPT] [--] [ARGS]`. If `PROMPT` does not
contain any spaces or `ARGS` (or `--`) is present, it will be interpreted as a
template name. Otherwise an inline prompt.
* Only include files that the bot has written in draft commits.
* Add `--generate` timeout option.
* Add a compatibility OpenAI bot version which does not use threads, so that it
can be used with tools only. Gemini only supports the latter.
21 changes: 17 additions & 4 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class _Toolbox(Toolbox):
def __init__(self, repo: git.Repo, hook: OperationHook | None) -> None:
super().__init__(hook)
self._repo = repo
self._written = set[str]()

@override
def _list(self) -> Sequence[PurePosixPath]:
Expand All @@ -72,17 +73,28 @@ def _read(self, path: PurePosixPath) -> str:

@override
def _write(self, path: PurePosixPath, contents: str) -> None:
self._written.add(str(path))
# Update the index without touching the worktree.
# https://stackoverflow.com/a/25352119
with tempfile.NamedTemporaryFile(delete_on_close=False) as temp:
temp.write(contents.encode("utf8"))
temp.close()
sha = self._repo.git.hash_object("-w", "--path", path, temp.name)
sha = self._repo.git.hash_object("-w", temp.name, path=path)
mode = 644 # TODO: Read from original file if it exists.
self._repo.git.update_index(
"--add", "--cacheinfo", f"{mode},{sha},{path}"
f"{mode},{sha},{path}", add=True, cacheinfo=True
)

def update_index(self) -> None:
diff = self._repo.git.diff(name_only=True, cached=True)
untouched = [
path
for path in diff.splitlines()
if path and path not in self._written
]
if untouched:
self._repo.git.reset("--", *untouched)


class Drafter:
"""Draft state orchestrator"""
Expand Down Expand Up @@ -151,6 +163,7 @@ def generate_draft(
action = bot.act(prompt_contents, toolbox)
end_time = time.perf_counter()

toolbox.update_index()
title = action.title
if not title:
title = _default_title(prompt_contents)
Expand Down Expand Up @@ -197,7 +210,7 @@ def _create_branch(self, sync: bool) -> _Branch:
origin_branch = self._repo.active_branch.name
origin_sha = self._repo.commit().hexsha

self._repo.git.checkout("--detach")
self._repo.git.checkout(detach=True)
sync_sha = self._stage_changes(sync)
suffix = _Branch.new_suffix()

Expand Down Expand Up @@ -248,7 +261,7 @@ def _exit_draft(self, apply: bool, delete=False) -> None:
# We do a small dance to move back to the original branch, keeping the
# draft branch untouched. See https://stackoverflow.com/a/15993574 for
# the inspiration.
self._repo.git.checkout("--detach")
self._repo.git.checkout(detach=True)
self._repo.git.reset("--mixed" if apply else "--hard", origin_branch)
self._repo.git.checkout(origin_branch)

Expand Down
21 changes: 15 additions & 6 deletions tests/git_draft/drafter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,29 @@ def _write(self, name: str, contents="") -> None:
def _commits(self) -> Sequence[git.Commit]:
return list(self._repo.iter_commits())

def _list_commit_files(self, ref: str) -> Sequence[str]:
def _commit_files(self, ref: str) -> frozenset[str]:
text = self._repo.git.diff_tree(
ref, no_commit_id=True, name_only=True, relative=True
)
return text.splitlines()
return frozenset(text.splitlines())

def test_generate_draft(self) -> None:
self._drafter.generate_draft("hello", FakeBot())
assert len(self._commits()) == 2

def test_generate_stages_worktree(self) -> None:
self._write("marker", "hi")
self._drafter.generate_draft("hello", FakeBot())
assert "marker" in self._list_commit_files("HEAD")
def test_generate_stages_then_resets_worktree(self) -> None:
self._write("p1", "a")
self._write("p2", "b")

class CustomBot(Bot):
def act(self, prompt: str, toolbox: Toolbox) -> Action:
assert toolbox.read_file(PurePosixPath("p1")) == "a"
toolbox.write_file(PurePosixPath("p2"), "B")
toolbox.write_file(PurePosixPath("p3"), "C")
return Action()

self._drafter.generate_draft("hello", CustomBot())
assert self._commit_files("HEAD") == set(["p2", "p3"])

def test_generate_then_discard_draft(self) -> None:
self._drafter.generate_draft("hello", FakeBot())
Expand Down