diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 5f91675..7ecaf5b 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -36,7 +36,7 @@ def new_parser() -> optparse.OptionParser: ) parser.add_option( "--root", - help="path used to locate repository", + help="path used to locate repository root", dest="root", ) @@ -64,8 +64,8 @@ def callback(_option, _opt, _value, parser) -> None: ) parser.add_option( "-c", - "--checkout", - help="check out generated changes", + "--clean", + help="remove deleted files from work directory", action="store_true", ) parser.add_option( @@ -96,7 +96,7 @@ def callback(_option, _opt, _value, parser) -> None: return parser -class _ToolPrinter(ToolVisitor): +class ToolPrinter(ToolVisitor): def on_list_files( self, _paths: Sequence[PurePosixPath], _reason: str | None ) -> None: @@ -126,10 +126,7 @@ def main() -> None: return logging.basicConfig(level=config.log_level, filename=str(log_path)) - drafter = Drafter.create( - store=Store.persistent(), - path=opts.root, - ) + drafter = Drafter.create(store=Store.persistent(), path=opts.root) command = getattr(opts, "command", "generate") if command == "generate": bot_config = None @@ -154,13 +151,12 @@ def main() -> None: name = drafter.generate_draft( prompt, bot, - tool_visitors=[_ToolPrinter()], - checkout=opts.checkout, + tool_visitors=[ToolPrinter()], reset=opts.reset, ) print(f"Generated {name}.") elif command == "finalize": - name = drafter.finalize_draft(delete=opts.delete) + name = drafter.finalize_draft(clean=opts.clean, delete=opts.delete) print(f"Finalized {name}.") elif command == "revert": name = drafter.revert_draft(delete=opts.delete) diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index 34ceee7..5d564d8 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -4,6 +4,8 @@ from datetime import datetime import json import logging +import os +import os.path as osp from pathlib import PurePosixPath import re import textwrap @@ -69,7 +71,6 @@ def generate_draft( prompt: str | TemplatedPrompt, bot: Bot, tool_visitors: Sequence[ToolVisitor] | None = None, - checkout: bool = False, reset: bool = False, sync: bool = False, timeout: float | None = None, @@ -107,11 +108,13 @@ def generate_draft( }, ) + _logger.debug("Running bot... [bot=%s]", bot) start_time = time.perf_counter() goal = Goal(prompt_contents, timeout) action = bot.act(goal, toolbox) end_time = time.perf_counter() walltime = end_time - start_time + _logger.info("Completed bot action. [action=%s]", action) toolbox.trim_index() title = action.title @@ -145,16 +148,18 @@ def generate_draft( ], ) - _logger.info("Generated draft.") - if checkout: - self._repo.git.checkout("--", ".") + _logger.info("Generated %s.", branch) return str(branch) - def finalize_draft(self, delete=False) -> str: - return self._exit_draft(revert=False, delete=delete) + def finalize_draft(self, clean=False, delete=False) -> str: + name = self._exit_draft(revert=False, clean=clean, delete=delete) + _logger.info("Finalized %s.", name) + return name def revert_draft(self, delete=False) -> str: - return self._exit_draft(revert=True, delete=delete) + name = self._exit_draft(revert=True, clean=False, delete=delete) + _logger.info("Reverted %s.", name) + return name def _create_branch(self, sync: bool) -> _Branch: if self._repo.head.is_detached: @@ -190,7 +195,7 @@ def _stage_changes(self, sync: bool) -> str | None: ref = self._repo.index.commit("draft! sync") return ref.hexsha - def _exit_draft(self, *, revert: bool, delete: bool) -> str: + def _exit_draft(self, *, revert: bool, clean: bool, delete: bool) -> str: branch = _Branch.active(self._repo) if not branch: raise RuntimeError("Not currently on a draft branch") @@ -200,7 +205,7 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> str: sql("get-branch-by-suffix"), {"suffix": branch.suffix} ) if not rows: - raise RuntimeError("Unrecognized branch") + raise RuntimeError("Unrecognized draft branch") [(origin_branch, origin_sha, sync_sha)] = rows if ( @@ -208,7 +213,16 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> str: and sync_sha and self._repo.commit(origin_branch).hexsha != origin_sha ): - raise RuntimeError("Parent branch has moved, please rebase") + raise RuntimeError("Parent branch has moved, please rebase first") + + if clean: + # We delete files which have been deleted in the draft manually, + # otherwise they would still show up as untracked. + origin_delta = self._delta(f"{origin_branch}..{branch}") + deleted = self._untracked() & origin_delta.deleted + for path in deleted: + os.remove(osp.join(self._repo.working_dir, path)) + _logger.info("Cleaned up files. [deleted=%s]", deleted) # We do a small dance to move back to the original branch, keeping the # draft branch untouched. See https://stackoverflow.com/a/15993574 for @@ -217,25 +231,60 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> str: self._repo.git.reset("-N", origin_branch) self._repo.git.checkout(origin_branch) - # Next, we revert the relevant files if needed. If a sync commit had - # been created, we simply revert to it. Otherwise we compute which - # files have changed due to draft commits and revert only those. if revert: + # We revert the relevant files if needed. If a sync commit had been + # created, we simply revert to it. Otherwise we compute which files + # have changed due to draft commits and revert only those. if sync_sha: - self._repo.git.checkout(sync_sha, "--", ".") + delta = self._delta(sync_sha) + if delta.changed: + self._repo.git.checkout(sync_sha, "--", ".") + _logger.info("Reverted to sync commit. [sha=%s]", sync_sha) else: - diffed = set(self._changed_files(f"{origin_branch}..{branch}")) - dirty = [p for p in self._changed_files("HEAD") if p in diffed] - if dirty: - self._repo.git.checkout("--", *dirty) + origin_delta = self._delta(f"{origin_branch}..{branch}") + head_delta = self._delta("HEAD") + changed = head_delta.touched & origin_delta.changed + if changed: + self._repo.git.checkout("--", *changed) + deleted = head_delta.touched & origin_delta.deleted + if deleted: + self._repo.git.rm("--", *deleted) + _logger.info( + "Reverted touched files. [changed=%s, deleted=%s]", + changed, + deleted, + ) if delete: self._repo.git.branch("-D", branch.name) + _logger.debug("Deleted branch %s.", branch) return branch.name - def _changed_files(self, spec) -> Sequence[str]: - return self._repo.git.diff(spec, name_only=True).splitlines() + def _untracked(self) -> frozenset[str]: + text = self._repo.git.ls_files(exclude_standard=True, others=True) + return frozenset(text.splitlines()) + + def _delta(self, spec) -> _Delta: + changed = list[str]() + deleted = list[str]() + for line in self._repo.git.diff(spec, name_status=True).splitlines(): + state, name = line.split(None, 1) + if state == "D": + deleted.append(name) + else: + changed.append(name) + return _Delta(changed=frozenset(changed), deleted=frozenset(deleted)) + + +@dataclasses.dataclass(frozen=True) +class _Delta: + changed: frozenset[str] + deleted: frozenset[str] + + @property + def touched(self) -> frozenset[str]: + return self.changed | self.deleted class _OperationRecorder(ToolVisitor): @@ -266,11 +315,11 @@ def on_delete_file(self, path: PurePosixPath, reason: str | None) -> None: self._record(reason, "delete_file", path=str(path)) def _record(self, reason: str | None, tool: str, **kwargs) -> None: - self.operations.append( - _Operation( - tool=tool, details=kwargs, reason=reason, start=datetime.now() - ) + op = _Operation( + tool=tool, details=kwargs, reason=reason, start=datetime.now() ) + _logger.debug("Recorded operation. [op=%s]", op) + self.operations.append(op) @dataclasses.dataclass(frozen=True) diff --git a/src/git_draft/toolbox.py b/src/git_draft/toolbox.py index 04ad846..9e2972d 100644 --- a/src/git_draft/toolbox.py +++ b/src/git_draft/toolbox.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging from pathlib import PurePosixPath import tempfile from typing import Callable, Protocol, Sequence, override @@ -7,6 +8,9 @@ import git +_logger = logging.getLogger(__name__) + + class Toolbox: """File-system intermediary @@ -58,7 +62,7 @@ def delete_file( self, path: PurePosixPath, reason: str | None = None, - ) -> None: + ) -> bool: self._dispatch(lambda v: v.on_delete_file(path, reason)) return self._delete(path) @@ -71,7 +75,7 @@ def _read(self, path: PurePosixPath) -> str: def _write(self, path: PurePosixPath, contents: str) -> None: raise NotImplementedError() - def _delete(self, path: PurePosixPath) -> None: + def _delete(self, path: PurePosixPath) -> bool: raise NotImplementedError() @@ -94,7 +98,7 @@ def on_delete_file( class StagingToolbox(Toolbox): - """Git-index backed toolbox + """Git-index backed toolbox implementation All files are directly read from and written to the index. This allows concurrent editing without interference with the working directory. @@ -132,12 +136,18 @@ def _write(self, path: PurePosixPath, contents: str) -> None: ) @override - def _delete(self, path: PurePosixPath) -> None: - self._updated.add(str(path)) - raise NotImplementedError() # TODO + def _delete(self, path: PurePosixPath) -> bool: + try: + self._repo.git.rm("--", str(path), cached=True) + except git.GitCommandError as err: + _logger.warning("Failed to delete file. [err=%r]", err) + return False + else: + self._updated.add(str(path)) + return True def trim_index(self) -> None: - """Unstage any files which have not been written to.""" + """Unstage any files which have not been written to""" diff = self._repo.git.diff(name_only=True, cached=True) untouched = [ path @@ -146,3 +156,4 @@ def trim_index(self) -> None: ] if untouched: self._repo.git.reset("--", *untouched) + _logger.debug("Trimmed index. [reset_paths=%s]", untouched) diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index c65f703..f51cdf3 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -1,3 +1,4 @@ +import os from pathlib import Path, PurePosixPath from typing import Sequence @@ -36,6 +37,9 @@ def _write(self, name: str, contents="") -> None: with open(self._path(name), "w") as f: f.write(contents) + def _delete(self, name: str) -> None: + os.remove(self._path(name)) + def _commits(self) -> Sequence[git.Commit]: return list(self._repo.iter_commits()) @@ -45,6 +49,9 @@ def _commit_files(self, ref: str) -> frozenset[str]: ) return frozenset(text.splitlines()) + def _checkout(self) -> None: + self._repo.git.checkout("--", ".") + def test_generate_draft(self) -> None: self._drafter.generate_draft("hello", FakeBot()) assert len(self._commits()) == 2 @@ -125,6 +132,47 @@ def act(self, _goal: Goal, _toolbox: Toolbox) -> Action: assert len(self._commits()) == 2 # init, prompt assert not self._commit_files("HEAD") + def test_delete_unknown_file(self) -> None: + class CustomBot(Bot): + def act(self, _goal: Goal, toolbox: Toolbox) -> Action: + toolbox.delete_file(PurePosixPath("p1")) + return Action() + + self._drafter.generate_draft("hello", CustomBot()) + + def test_sync_delete_revert(self) -> None: + self._write("p1", "a") + self._repo.git.add(all=True) + self._repo.index.commit("advance") + self._delete("p1") + + class CustomBot(Bot): + def act(self, _goal: Goal, toolbox: Toolbox) -> Action: + toolbox.write_file(PurePosixPath("p2"), "b") + return Action() + + self._drafter.generate_draft("hello", CustomBot(), sync=True) + assert self._read("p1") is None + + self._drafter.revert_draft() + assert self._read("p1") is None + + def test_generate_delete_finalize_clean(self) -> None: + self._write("p1", "a") + self._repo.git.add(all=True) + self._repo.index.commit("advance") + + class CustomBot(Bot): + def act(self, _goal: Goal, toolbox: Toolbox) -> Action: + toolbox.delete_file(PurePosixPath("p1")) + return Action() + + self._drafter.generate_draft("hello", CustomBot()) + assert self._read("p1") == "a" + + self._drafter.finalize_draft(clean=True) + assert self._read("p1") is None + def test_revert_outside_draft(self) -> None: with pytest.raises(RuntimeError): self._drafter.revert_draft() @@ -180,7 +228,8 @@ def act(self, _goal: Goal, toolbox: Toolbox) -> Action: def test_finalize_keeps_changes(self) -> None: self._write("p1.txt", "a1") - self._drafter.generate_draft("hello", FakeBot(), checkout=True) + self._drafter.generate_draft("hello", FakeBot()) + self._checkout() self._write("p1.txt", "a2") self._drafter.finalize_draft() assert self._read("p1.txt") == "a2"