Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def new_parser() -> optparse.OptionParser:
)
parser.add_option(
"--root",
help="path used to locate repository",
help="path used to locate repository root",
dest="root",
)

Expand Down Expand Up @@ -64,8 +64,8 @@ def callback(_option, _opt, _value, parser) -> None:
)
parser.add_option(
"-c",
"--checkout",
help="check out generated changes",
"--clean",
help="remove deleted files from work directory",
action="store_true",
)
parser.add_option(
Expand Down Expand Up @@ -96,7 +96,7 @@ def callback(_option, _opt, _value, parser) -> None:
return parser


class _ToolPrinter(ToolVisitor):
class ToolPrinter(ToolVisitor):
def on_list_files(
self, _paths: Sequence[PurePosixPath], _reason: str | None
) -> None:
Expand Down Expand Up @@ -126,10 +126,7 @@ def main() -> None:
return
logging.basicConfig(level=config.log_level, filename=str(log_path))

drafter = Drafter.create(
store=Store.persistent(),
path=opts.root,
)
drafter = Drafter.create(store=Store.persistent(), path=opts.root)
command = getattr(opts, "command", "generate")
if command == "generate":
bot_config = None
Expand All @@ -154,13 +151,12 @@ def main() -> None:
name = drafter.generate_draft(
prompt,
bot,
tool_visitors=[_ToolPrinter()],
checkout=opts.checkout,
tool_visitors=[ToolPrinter()],
reset=opts.reset,
)
print(f"Generated {name}.")
elif command == "finalize":
name = drafter.finalize_draft(delete=opts.delete)
name = drafter.finalize_draft(clean=opts.clean, delete=opts.delete)
print(f"Finalized {name}.")
elif command == "revert":
name = drafter.revert_draft(delete=opts.delete)
Expand Down
97 changes: 73 additions & 24 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from datetime import datetime
import json
import logging
import os
import os.path as osp
from pathlib import PurePosixPath
import re
import textwrap
Expand Down Expand Up @@ -69,7 +71,6 @@ def generate_draft(
prompt: str | TemplatedPrompt,
bot: Bot,
tool_visitors: Sequence[ToolVisitor] | None = None,
checkout: bool = False,
reset: bool = False,
sync: bool = False,
timeout: float | None = None,
Expand Down Expand Up @@ -107,11 +108,13 @@ def generate_draft(
},
)

_logger.debug("Running bot... [bot=%s]", bot)
start_time = time.perf_counter()
goal = Goal(prompt_contents, timeout)
action = bot.act(goal, toolbox)
end_time = time.perf_counter()
walltime = end_time - start_time
_logger.info("Completed bot action. [action=%s]", action)

toolbox.trim_index()
title = action.title
Expand Down Expand Up @@ -145,16 +148,18 @@ def generate_draft(
],
)

_logger.info("Generated draft.")
if checkout:
self._repo.git.checkout("--", ".")
_logger.info("Generated %s.", branch)
return str(branch)

def finalize_draft(self, delete=False) -> str:
return self._exit_draft(revert=False, delete=delete)
def finalize_draft(self, clean=False, delete=False) -> str:
name = self._exit_draft(revert=False, clean=clean, delete=delete)
_logger.info("Finalized %s.", name)
return name

def revert_draft(self, delete=False) -> str:
return self._exit_draft(revert=True, delete=delete)
name = self._exit_draft(revert=True, clean=False, delete=delete)
_logger.info("Reverted %s.", name)
return name

def _create_branch(self, sync: bool) -> _Branch:
if self._repo.head.is_detached:
Expand Down Expand Up @@ -190,7 +195,7 @@ def _stage_changes(self, sync: bool) -> str | None:
ref = self._repo.index.commit("draft! sync")
return ref.hexsha

def _exit_draft(self, *, revert: bool, delete: bool) -> str:
def _exit_draft(self, *, revert: bool, clean: bool, delete: bool) -> str:
branch = _Branch.active(self._repo)
if not branch:
raise RuntimeError("Not currently on a draft branch")
Expand All @@ -200,15 +205,24 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> str:
sql("get-branch-by-suffix"), {"suffix": branch.suffix}
)
if not rows:
raise RuntimeError("Unrecognized branch")
raise RuntimeError("Unrecognized draft branch")
[(origin_branch, origin_sha, sync_sha)] = rows

if (
revert
and sync_sha
and self._repo.commit(origin_branch).hexsha != origin_sha
):
raise RuntimeError("Parent branch has moved, please rebase")
raise RuntimeError("Parent branch has moved, please rebase first")

if clean:
# We delete files which have been deleted in the draft manually,
# otherwise they would still show up as untracked.
origin_delta = self._delta(f"{origin_branch}..{branch}")
deleted = self._untracked() & origin_delta.deleted
for path in deleted:
os.remove(osp.join(self._repo.working_dir, path))
_logger.info("Cleaned up files. [deleted=%s]", deleted)

# We do a small dance to move back to the original branch, keeping the
# draft branch untouched. See https://stackoverflow.com/a/15993574 for
Expand All @@ -217,25 +231,60 @@ def _exit_draft(self, *, revert: bool, delete: bool) -> str:
self._repo.git.reset("-N", origin_branch)
self._repo.git.checkout(origin_branch)

# Next, we revert the relevant files if needed. If a sync commit had
# been created, we simply revert to it. Otherwise we compute which
# files have changed due to draft commits and revert only those.
if revert:
# We revert the relevant files if needed. If a sync commit had been
# created, we simply revert to it. Otherwise we compute which files
# have changed due to draft commits and revert only those.
if sync_sha:
self._repo.git.checkout(sync_sha, "--", ".")
delta = self._delta(sync_sha)
if delta.changed:
self._repo.git.checkout(sync_sha, "--", ".")
_logger.info("Reverted to sync commit. [sha=%s]", sync_sha)
else:
diffed = set(self._changed_files(f"{origin_branch}..{branch}"))
dirty = [p for p in self._changed_files("HEAD") if p in diffed]
if dirty:
self._repo.git.checkout("--", *dirty)
origin_delta = self._delta(f"{origin_branch}..{branch}")
head_delta = self._delta("HEAD")
changed = head_delta.touched & origin_delta.changed
if changed:
self._repo.git.checkout("--", *changed)
deleted = head_delta.touched & origin_delta.deleted
if deleted:
self._repo.git.rm("--", *deleted)
_logger.info(
"Reverted touched files. [changed=%s, deleted=%s]",
changed,
deleted,
)

if delete:
self._repo.git.branch("-D", branch.name)
_logger.debug("Deleted branch %s.", branch)

return branch.name

def _changed_files(self, spec) -> Sequence[str]:
return self._repo.git.diff(spec, name_only=True).splitlines()
def _untracked(self) -> frozenset[str]:
text = self._repo.git.ls_files(exclude_standard=True, others=True)
return frozenset(text.splitlines())

def _delta(self, spec) -> _Delta:
changed = list[str]()
deleted = list[str]()
for line in self._repo.git.diff(spec, name_status=True).splitlines():
state, name = line.split(None, 1)
if state == "D":
deleted.append(name)
else:
changed.append(name)
return _Delta(changed=frozenset(changed), deleted=frozenset(deleted))


@dataclasses.dataclass(frozen=True)
class _Delta:
changed: frozenset[str]
deleted: frozenset[str]

@property
def touched(self) -> frozenset[str]:
return self.changed | self.deleted


class _OperationRecorder(ToolVisitor):
Expand Down Expand Up @@ -266,11 +315,11 @@ def on_delete_file(self, path: PurePosixPath, reason: str | None) -> None:
self._record(reason, "delete_file", path=str(path))

def _record(self, reason: str | None, tool: str, **kwargs) -> None:
self.operations.append(
_Operation(
tool=tool, details=kwargs, reason=reason, start=datetime.now()
)
op = _Operation(
tool=tool, details=kwargs, reason=reason, start=datetime.now()
)
_logger.debug("Recorded operation. [op=%s]", op)
self.operations.append(op)


@dataclasses.dataclass(frozen=True)
Expand Down
25 changes: 18 additions & 7 deletions src/git_draft/toolbox.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

import logging
from pathlib import PurePosixPath
import tempfile
from typing import Callable, Protocol, Sequence, override

import git


_logger = logging.getLogger(__name__)


class Toolbox:
"""File-system intermediary

Expand Down Expand Up @@ -58,7 +62,7 @@ def delete_file(
self,
path: PurePosixPath,
reason: str | None = None,
) -> None:
) -> bool:
self._dispatch(lambda v: v.on_delete_file(path, reason))
return self._delete(path)

Expand All @@ -71,7 +75,7 @@ def _read(self, path: PurePosixPath) -> str:
def _write(self, path: PurePosixPath, contents: str) -> None:
raise NotImplementedError()

def _delete(self, path: PurePosixPath) -> None:
def _delete(self, path: PurePosixPath) -> bool:
raise NotImplementedError()


Expand All @@ -94,7 +98,7 @@ def on_delete_file(


class StagingToolbox(Toolbox):
"""Git-index backed toolbox
"""Git-index backed toolbox implementation

All files are directly read from and written to the index. This allows
concurrent editing without interference with the working directory.
Expand Down Expand Up @@ -132,12 +136,18 @@ def _write(self, path: PurePosixPath, contents: str) -> None:
)

@override
def _delete(self, path: PurePosixPath) -> None:
self._updated.add(str(path))
raise NotImplementedError() # TODO
def _delete(self, path: PurePosixPath) -> bool:
try:
self._repo.git.rm("--", str(path), cached=True)
except git.GitCommandError as err:
_logger.warning("Failed to delete file. [err=%r]", err)
return False
else:
self._updated.add(str(path))
return True

def trim_index(self) -> None:
"""Unstage any files which have not been written to."""
"""Unstage any files which have not been written to"""
diff = self._repo.git.diff(name_only=True, cached=True)
untouched = [
path
Expand All @@ -146,3 +156,4 @@ def trim_index(self) -> None:
]
if untouched:
self._repo.git.reset("--", *untouched)
_logger.debug("Trimmed index. [reset_paths=%s]", untouched)
51 changes: 50 additions & 1 deletion tests/git_draft/drafter_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from pathlib import Path, PurePosixPath
from typing import Sequence

Expand Down Expand Up @@ -36,6 +37,9 @@ def _write(self, name: str, contents="") -> None:
with open(self._path(name), "w") as f:
f.write(contents)

def _delete(self, name: str) -> None:
os.remove(self._path(name))

def _commits(self) -> Sequence[git.Commit]:
return list(self._repo.iter_commits())

Expand All @@ -45,6 +49,9 @@ def _commit_files(self, ref: str) -> frozenset[str]:
)
return frozenset(text.splitlines())

def _checkout(self) -> None:
self._repo.git.checkout("--", ".")

def test_generate_draft(self) -> None:
self._drafter.generate_draft("hello", FakeBot())
assert len(self._commits()) == 2
Expand Down Expand Up @@ -125,6 +132,47 @@ def act(self, _goal: Goal, _toolbox: Toolbox) -> Action:
assert len(self._commits()) == 2 # init, prompt
assert not self._commit_files("HEAD")

def test_delete_unknown_file(self) -> None:
class CustomBot(Bot):
def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
toolbox.delete_file(PurePosixPath("p1"))
return Action()

self._drafter.generate_draft("hello", CustomBot())

def test_sync_delete_revert(self) -> None:
self._write("p1", "a")
self._repo.git.add(all=True)
self._repo.index.commit("advance")
self._delete("p1")

class CustomBot(Bot):
def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
toolbox.write_file(PurePosixPath("p2"), "b")
return Action()

self._drafter.generate_draft("hello", CustomBot(), sync=True)
assert self._read("p1") is None

self._drafter.revert_draft()
assert self._read("p1") is None

def test_generate_delete_finalize_clean(self) -> None:
self._write("p1", "a")
self._repo.git.add(all=True)
self._repo.index.commit("advance")

class CustomBot(Bot):
def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
toolbox.delete_file(PurePosixPath("p1"))
return Action()

self._drafter.generate_draft("hello", CustomBot())
assert self._read("p1") == "a"

self._drafter.finalize_draft(clean=True)
assert self._read("p1") is None

def test_revert_outside_draft(self) -> None:
with pytest.raises(RuntimeError):
self._drafter.revert_draft()
Expand Down Expand Up @@ -180,7 +228,8 @@ def act(self, _goal: Goal, toolbox: Toolbox) -> Action:

def test_finalize_keeps_changes(self) -> None:
self._write("p1.txt", "a1")
self._drafter.generate_draft("hello", FakeBot(), checkout=True)
self._drafter.generate_draft("hello", FakeBot())
self._checkout()
self._write("p1.txt", "a2")
self._drafter.finalize_draft()
assert self._read("p1.txt") == "a2"
Expand Down