Skip to content

Commit 548c36c

Browse files
authored
feat: implement toolbox deletion (#46)
1 parent a3a1fa6 commit 548c36c

File tree

4 files changed

+148
-43
lines changed

4 files changed

+148
-43
lines changed

src/git_draft/__main__.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def new_parser() -> optparse.OptionParser:
3636
)
3737
parser.add_option(
3838
"--root",
39-
help="path used to locate repository",
39+
help="path used to locate repository root",
4040
dest="root",
4141
)
4242

@@ -64,8 +64,8 @@ def callback(_option, _opt, _value, parser) -> None:
6464
)
6565
parser.add_option(
6666
"-c",
67-
"--checkout",
68-
help="check out generated changes",
67+
"--clean",
68+
help="remove deleted files from work directory",
6969
action="store_true",
7070
)
7171
parser.add_option(
@@ -96,7 +96,7 @@ def callback(_option, _opt, _value, parser) -> None:
9696
return parser
9797

9898

99-
class _ToolPrinter(ToolVisitor):
99+
class ToolPrinter(ToolVisitor):
100100
def on_list_files(
101101
self, _paths: Sequence[PurePosixPath], _reason: str | None
102102
) -> None:
@@ -126,10 +126,7 @@ def main() -> None:
126126
return
127127
logging.basicConfig(level=config.log_level, filename=str(log_path))
128128

129-
drafter = Drafter.create(
130-
store=Store.persistent(),
131-
path=opts.root,
132-
)
129+
drafter = Drafter.create(store=Store.persistent(), path=opts.root)
133130
command = getattr(opts, "command", "generate")
134131
if command == "generate":
135132
bot_config = None
@@ -154,13 +151,12 @@ def main() -> None:
154151
name = drafter.generate_draft(
155152
prompt,
156153
bot,
157-
tool_visitors=[_ToolPrinter()],
158-
checkout=opts.checkout,
154+
tool_visitors=[ToolPrinter()],
159155
reset=opts.reset,
160156
)
161157
print(f"Generated {name}.")
162158
elif command == "finalize":
163-
name = drafter.finalize_draft(delete=opts.delete)
159+
name = drafter.finalize_draft(clean=opts.clean, delete=opts.delete)
164160
print(f"Finalized {name}.")
165161
elif command == "revert":
166162
name = drafter.revert_draft(delete=opts.delete)

src/git_draft/drafter.py

Lines changed: 73 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from datetime import datetime
55
import json
66
import logging
7+
import os
8+
import os.path as osp
79
from pathlib import PurePosixPath
810
import re
911
import textwrap
@@ -69,7 +71,6 @@ def generate_draft(
6971
prompt: str | TemplatedPrompt,
7072
bot: Bot,
7173
tool_visitors: Sequence[ToolVisitor] | None = None,
72-
checkout: bool = False,
7374
reset: bool = False,
7475
sync: bool = False,
7576
timeout: float | None = None,
@@ -107,11 +108,13 @@ def generate_draft(
107108
},
108109
)
109110

111+
_logger.debug("Running bot... [bot=%s]", bot)
110112
start_time = time.perf_counter()
111113
goal = Goal(prompt_contents, timeout)
112114
action = bot.act(goal, toolbox)
113115
end_time = time.perf_counter()
114116
walltime = end_time - start_time
117+
_logger.info("Completed bot action. [action=%s]", action)
115118

116119
toolbox.trim_index()
117120
title = action.title
@@ -145,16 +148,18 @@ def generate_draft(
145148
],
146149
)
147150

148-
_logger.info("Generated draft.")
149-
if checkout:
150-
self._repo.git.checkout("--", ".")
151+
_logger.info("Generated %s.", branch)
151152
return str(branch)
152153

153-
def finalize_draft(self, delete=False) -> str:
154-
return self._exit_draft(revert=False, delete=delete)
154+
def finalize_draft(self, clean=False, delete=False) -> str:
155+
name = self._exit_draft(revert=False, clean=clean, delete=delete)
156+
_logger.info("Finalized %s.", name)
157+
return name
155158

156159
def revert_draft(self, delete=False) -> str:
157-
return self._exit_draft(revert=True, delete=delete)
160+
name = self._exit_draft(revert=True, clean=False, delete=delete)
161+
_logger.info("Reverted %s.", name)
162+
return name
158163

159164
def _create_branch(self, sync: bool) -> _Branch:
160165
if self._repo.head.is_detached:
@@ -190,7 +195,7 @@ def _stage_changes(self, sync: bool) -> str | None:
190195
ref = self._repo.index.commit("draft! sync")
191196
return ref.hexsha
192197

193-
def _exit_draft(self, *, revert: bool, delete: bool) -> str:
198+
def _exit_draft(self, *, revert: bool, clean: bool, delete: bool) -> str:
194199
branch = _Branch.active(self._repo)
195200
if not branch:
196201
raise RuntimeError("Not currently on a draft branch")
@@ -200,15 +205,24 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> str:
200205
sql("get-branch-by-suffix"), {"suffix": branch.suffix}
201206
)
202207
if not rows:
203-
raise RuntimeError("Unrecognized branch")
208+
raise RuntimeError("Unrecognized draft branch")
204209
[(origin_branch, origin_sha, sync_sha)] = rows
205210

206211
if (
207212
revert
208213
and sync_sha
209214
and self._repo.commit(origin_branch).hexsha != origin_sha
210215
):
211-
raise RuntimeError("Parent branch has moved, please rebase")
216+
raise RuntimeError("Parent branch has moved, please rebase first")
217+
218+
if clean:
219+
# We delete files which have been deleted in the draft manually,
220+
# otherwise they would still show up as untracked.
221+
origin_delta = self._delta(f"{origin_branch}..{branch}")
222+
deleted = self._untracked() & origin_delta.deleted
223+
for path in deleted:
224+
os.remove(osp.join(self._repo.working_dir, path))
225+
_logger.info("Cleaned up files. [deleted=%s]", deleted)
212226

213227
# We do a small dance to move back to the original branch, keeping the
214228
# draft branch untouched. See https://stackoverflow.com/a/15993574 for
@@ -217,25 +231,60 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> str:
217231
self._repo.git.reset("-N", origin_branch)
218232
self._repo.git.checkout(origin_branch)
219233

220-
# Next, we revert the relevant files if needed. If a sync commit had
221-
# been created, we simply revert to it. Otherwise we compute which
222-
# files have changed due to draft commits and revert only those.
223234
if revert:
235+
# We revert the relevant files if needed. If a sync commit had been
236+
# created, we simply revert to it. Otherwise we compute which files
237+
# have changed due to draft commits and revert only those.
224238
if sync_sha:
225-
self._repo.git.checkout(sync_sha, "--", ".")
239+
delta = self._delta(sync_sha)
240+
if delta.changed:
241+
self._repo.git.checkout(sync_sha, "--", ".")
242+
_logger.info("Reverted to sync commit. [sha=%s]", sync_sha)
226243
else:
227-
diffed = set(self._changed_files(f"{origin_branch}..{branch}"))
228-
dirty = [p for p in self._changed_files("HEAD") if p in diffed]
229-
if dirty:
230-
self._repo.git.checkout("--", *dirty)
244+
origin_delta = self._delta(f"{origin_branch}..{branch}")
245+
head_delta = self._delta("HEAD")
246+
changed = head_delta.touched & origin_delta.changed
247+
if changed:
248+
self._repo.git.checkout("--", *changed)
249+
deleted = head_delta.touched & origin_delta.deleted
250+
if deleted:
251+
self._repo.git.rm("--", *deleted)
252+
_logger.info(
253+
"Reverted touched files. [changed=%s, deleted=%s]",
254+
changed,
255+
deleted,
256+
)
231257

232258
if delete:
233259
self._repo.git.branch("-D", branch.name)
260+
_logger.debug("Deleted branch %s.", branch)
234261

235262
return branch.name
236263

237-
def _changed_files(self, spec) -> Sequence[str]:
238-
return self._repo.git.diff(spec, name_only=True).splitlines()
264+
def _untracked(self) -> frozenset[str]:
265+
text = self._repo.git.ls_files(exclude_standard=True, others=True)
266+
return frozenset(text.splitlines())
267+
268+
def _delta(self, spec) -> _Delta:
269+
changed = list[str]()
270+
deleted = list[str]()
271+
for line in self._repo.git.diff(spec, name_status=True).splitlines():
272+
state, name = line.split(None, 1)
273+
if state == "D":
274+
deleted.append(name)
275+
else:
276+
changed.append(name)
277+
return _Delta(changed=frozenset(changed), deleted=frozenset(deleted))
278+
279+
280+
@dataclasses.dataclass(frozen=True)
281+
class _Delta:
282+
changed: frozenset[str]
283+
deleted: frozenset[str]
284+
285+
@property
286+
def touched(self) -> frozenset[str]:
287+
return self.changed | self.deleted
239288

240289

241290
class _OperationRecorder(ToolVisitor):
@@ -266,11 +315,11 @@ def on_delete_file(self, path: PurePosixPath, reason: str | None) -> None:
266315
self._record(reason, "delete_file", path=str(path))
267316

268317
def _record(self, reason: str | None, tool: str, **kwargs) -> None:
269-
self.operations.append(
270-
_Operation(
271-
tool=tool, details=kwargs, reason=reason, start=datetime.now()
272-
)
318+
op = _Operation(
319+
tool=tool, details=kwargs, reason=reason, start=datetime.now()
273320
)
321+
_logger.debug("Recorded operation. [op=%s]", op)
322+
self.operations.append(op)
274323

275324

276325
@dataclasses.dataclass(frozen=True)

src/git_draft/toolbox.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from __future__ import annotations
22

3+
import logging
34
from pathlib import PurePosixPath
45
import tempfile
56
from typing import Callable, Protocol, Sequence, override
67

78
import git
89

910

11+
_logger = logging.getLogger(__name__)
12+
13+
1014
class Toolbox:
1115
"""File-system intermediary
1216
@@ -58,7 +62,7 @@ def delete_file(
5862
self,
5963
path: PurePosixPath,
6064
reason: str | None = None,
61-
) -> None:
65+
) -> bool:
6266
self._dispatch(lambda v: v.on_delete_file(path, reason))
6367
return self._delete(path)
6468

@@ -71,7 +75,7 @@ def _read(self, path: PurePosixPath) -> str:
7175
def _write(self, path: PurePosixPath, contents: str) -> None:
7276
raise NotImplementedError()
7377

74-
def _delete(self, path: PurePosixPath) -> None:
78+
def _delete(self, path: PurePosixPath) -> bool:
7579
raise NotImplementedError()
7680

7781

@@ -94,7 +98,7 @@ def on_delete_file(
9498

9599

96100
class StagingToolbox(Toolbox):
97-
"""Git-index backed toolbox
101+
"""Git-index backed toolbox implementation
98102
99103
All files are directly read from and written to the index. This allows
100104
concurrent editing without interference with the working directory.
@@ -132,12 +136,18 @@ def _write(self, path: PurePosixPath, contents: str) -> None:
132136
)
133137

134138
@override
135-
def _delete(self, path: PurePosixPath) -> None:
136-
self._updated.add(str(path))
137-
raise NotImplementedError() # TODO
139+
def _delete(self, path: PurePosixPath) -> bool:
140+
try:
141+
self._repo.git.rm("--", str(path), cached=True)
142+
except git.GitCommandError as err:
143+
_logger.warning("Failed to delete file. [err=%r]", err)
144+
return False
145+
else:
146+
self._updated.add(str(path))
147+
return True
138148

139149
def trim_index(self) -> None:
140-
"""Unstage any files which have not been written to."""
150+
"""Unstage any files which have not been written to"""
141151
diff = self._repo.git.diff(name_only=True, cached=True)
142152
untouched = [
143153
path
@@ -146,3 +156,4 @@ def trim_index(self) -> None:
146156
]
147157
if untouched:
148158
self._repo.git.reset("--", *untouched)
159+
_logger.debug("Trimmed index. [reset_paths=%s]", untouched)

tests/git_draft/drafter_test.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from pathlib import Path, PurePosixPath
23
from typing import Sequence
34

@@ -36,6 +37,9 @@ def _write(self, name: str, contents="") -> None:
3637
with open(self._path(name), "w") as f:
3738
f.write(contents)
3839

40+
def _delete(self, name: str) -> None:
41+
os.remove(self._path(name))
42+
3943
def _commits(self) -> Sequence[git.Commit]:
4044
return list(self._repo.iter_commits())
4145

@@ -45,6 +49,9 @@ def _commit_files(self, ref: str) -> frozenset[str]:
4549
)
4650
return frozenset(text.splitlines())
4751

52+
def _checkout(self) -> None:
53+
self._repo.git.checkout("--", ".")
54+
4855
def test_generate_draft(self) -> None:
4956
self._drafter.generate_draft("hello", FakeBot())
5057
assert len(self._commits()) == 2
@@ -125,6 +132,47 @@ def act(self, _goal: Goal, _toolbox: Toolbox) -> Action:
125132
assert len(self._commits()) == 2 # init, prompt
126133
assert not self._commit_files("HEAD")
127134

135+
def test_delete_unknown_file(self) -> None:
136+
class CustomBot(Bot):
137+
def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
138+
toolbox.delete_file(PurePosixPath("p1"))
139+
return Action()
140+
141+
self._drafter.generate_draft("hello", CustomBot())
142+
143+
def test_sync_delete_revert(self) -> None:
144+
self._write("p1", "a")
145+
self._repo.git.add(all=True)
146+
self._repo.index.commit("advance")
147+
self._delete("p1")
148+
149+
class CustomBot(Bot):
150+
def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
151+
toolbox.write_file(PurePosixPath("p2"), "b")
152+
return Action()
153+
154+
self._drafter.generate_draft("hello", CustomBot(), sync=True)
155+
assert self._read("p1") is None
156+
157+
self._drafter.revert_draft()
158+
assert self._read("p1") is None
159+
160+
def test_generate_delete_finalize_clean(self) -> None:
161+
self._write("p1", "a")
162+
self._repo.git.add(all=True)
163+
self._repo.index.commit("advance")
164+
165+
class CustomBot(Bot):
166+
def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
167+
toolbox.delete_file(PurePosixPath("p1"))
168+
return Action()
169+
170+
self._drafter.generate_draft("hello", CustomBot())
171+
assert self._read("p1") == "a"
172+
173+
self._drafter.finalize_draft(clean=True)
174+
assert self._read("p1") is None
175+
128176
def test_revert_outside_draft(self) -> None:
129177
with pytest.raises(RuntimeError):
130178
self._drafter.revert_draft()
@@ -180,7 +228,8 @@ def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
180228

181229
def test_finalize_keeps_changes(self) -> None:
182230
self._write("p1.txt", "a1")
183-
self._drafter.generate_draft("hello", FakeBot(), checkout=True)
231+
self._drafter.generate_draft("hello", FakeBot())
232+
self._checkout()
184233
self._write("p1.txt", "a2")
185234
self._drafter.finalize_draft()
186235
assert self._read("p1.txt") == "a2"

0 commit comments

Comments
 (0)