diff --git a/docs/git-draft.adoc b/docs/git-draft.adoc index e41ec9f..4587e0a 100644 --- a/docs/git-draft.adoc +++ b/docs/git-draft.adoc @@ -20,9 +20,7 @@ IMPORTANT: `git-draft` is WIP. [verse] git draft [options] [--generate] [--accept... | --no-accept] [--bot BOT] [--edit] [--reset | --no-reset] [TEMPLATE [VARIABLE...]] -git draft [options] --finalize [--delete] -git draft [options] --show-drafts [--json] -git draft [options] --show-prompts [--json] [PROMPT] +git draft [options] --finalize git draft [options] --show-templates [--json | [--edit] TEMPLATE] @@ -38,16 +36,11 @@ git draft [options] --show-templates [--json | [--edit] TEMPLATE] --no-accept:: Check out generated changes automatically. Can be repeated. - This may fail if you manually edit files that the bot updates during generation. -b BOT:: --bot=BOT:: Bot name. --d:: ---delete:: - Delete finalized branch. - -e:: --edit:: Enable interactive editing of prompts and templates. @@ -73,23 +66,9 @@ git draft [options] --show-templates [--json | [--edit] TEMPLATE] --log:: Show log path and exit. ---reset:: ---no-reset:: - Controls behavior when staged changes are present at the start of a generate command. - If enabled, these changes are automatically reset and combined with other working directory changes. - Otherwise an error is raised. - --root:: Repository search root. --D:: ---show-drafts:: - List recently created drafts. - --P:: ---show-prompts:: - Lists recently used prompts. - -T:: --show-templates:: Lists available templates. @@ -109,13 +88,8 @@ The workhorse command is `git draft --generate` which leverages AI to edit our c A prompt can be specified as standard input, for example `echo "Add a test for compute_offset in chart.py" | git draft --generate`. If no prompt is specified and stdin is a TTY, `$EDITOR` will be opened to enter the prompt. 
-If not on a draft branch, a new draft branch called `drafts/$parent/$hash` will be created (`$hash` is a random suffix used to guarantee uniqueness of branch names) and checked out. -By default any unstaged changes are then automatically added and committed (`draft! sync`). -This behavior can be disabled by passing in `--stash`, which will instead add them to the stash. -Staged changes are always committed. - -The prompt automatically gets augmented with information about the files in the repository, and give the AI access to tools for reading and writing files. -Once the response has been received and changes applied, a commit is created (`draft! prompt: a short summary of the change`). +By default, the prompt gets augmented with information about the files in the repository, and gives the AI access to tools for reading and writing files. +Once the response has been received and changes applied, a commit is created in a separate branch. The `--generate` step can be repeated as many times as needed. Once you are satisfied with the changes, run `git draft --finalize` to apply them. @@ -123,6 +97,66 @@ This will check out the branch used when creating the draft, adding the final st Note that you can come back to an existing draft anytime (by checking its branch out), but you will not be able to apply it if its origin branch has moved since the draft was created. +* Generate: create a new draft in the current folio, or create a new folio if none exists. +* Finalize + ** Apply: include changes into origin branch. + ** Discard: abandon folio. + ** Save: return to original branch. Q: how to load after? +* List templates + + +o Foo (main) +o Sync (drafts/123/pub) + +# After generate without accept +o Foo (main) +o Sync (drafts/123) +o draft! (drafts/123+, refs/drafts/123/1) + +# After generate with accept +o Foo (main) +o Sync +|\ +| o draft! prompt: (refs/drafts/123/1) +o | Sync +|/ +o Merge (drafts/123/pub) + + +o Foo (main) +o draft! sync +|\ +| o draft! 
prompt: (drafts/123+, refs/drafts/123/1) +o Something +o Also something (drafts/123) + + +o Foo (main) +o draft! sync +|\ +| o draft! prompt: (refs/drafts/123/1) +o Something +o Also something (drafts/123/pub) +o draft! prompt: (drafts/123+, refs/drafts/123/2) + + +o Foo (main) +o draft! sync (drafts/123/pub) +|\ +| o draft! prompt: (refs/drafts/123/1) + \ + o draft! prompt: (drafts/123+, refs/drafts/123/2) + +o Foo (main) +o draft! sync (drafts/123/pub) +|\ +| o draft! prompt: (refs/drafts/123/1) +|/ +o draft! sync + \ + o draft! prompt: (drafts/123+, refs/drafts/123/2) + + == See also `git(1)` diff --git a/src/git_draft/__init__.py b/src/git_draft/__init__.py index c7073f3..171476d 100644 --- a/src/git_draft/__init__.py +++ b/src/git_draft/__init__.py @@ -2,7 +2,8 @@ import logging -from .bots import Action, Bot, Goal, Toolbox +from .bots import Action, Bot, Goal +from .toolbox import Toolbox __all__ = [ diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 4c3fc99..6fc194a 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -60,8 +60,6 @@ def callback( add_command("finalize", help="apply current draft to original branch") add_command("generate", help="create or update draft from a prompt") - add_command("show-drafts", short="D", help="show draft history") - add_command("show-prompts", short="P", help="show prompt history") add_command("show-templates", short="T", help="show template information") parser.add_option( @@ -76,12 +74,6 @@ def callback( dest="bot", help="bot name", ) - parser.add_option( - "-d", - "--delete", - help="delete draft after finalizing", - action="store_true", - ) parser.add_option( "-e", "--edit", @@ -102,18 +94,6 @@ def callback( action="store_const", const=0, ) - parser.add_option( - "--no-reset", - help="abort if there are any staged changes", - dest="reset", - action="store_false", - ) - parser.add_option( - "--reset", - help="reset index before generating a new draft", - dest="reset", - 
action="store_true", - ) parser.add_option( "--timeout", dest="timeout", @@ -213,33 +193,26 @@ def main() -> None: # noqa: PLR0912 PLR0915 prompt = sys.stdin.read() accept = Accept(opts.accept or 0) - draft = drafter.generate_draft( + drafter.generate_draft( prompt, bot, accept=accept, bot_name=opts.bot, prompt_transform=open_editor if editable else None, tool_visitors=[ToolPrinter()], - reset=config.reset if opts.reset is None else opts.reset, ) match accept: case Accept.MANUAL: - print(f"Generated change in {draft.branch_name}.") - case Accept.CHECKOUT: - print(f"Applied change in {draft.branch_name}.") - case Accept.FINALIZE | Accept.NO_REGRETS: - print(f"Finalized change via {draft.branch_name}.") + print("Generated draft.") + case Accept.MERGE: + print("Merged draft.") + case Accept.FINALIZE: + print("Finalized draft.") case _: raise UnreachableError() case "finalize": - draft = drafter.finalize_draft(delete=opts.delete) - print(f"Finalized {draft.branch_name}.") - case "show-drafts": - table = drafter.history_table(args[0] if args else None) - if table: - print(table.to_json() if opts.json else table) - case "show-prompts": - raise NotImplementedError() # TODO: Implement + drafter.finalize_folio() + print("Finalized draft folio.") case "show-templates": if args: name = args[0] diff --git a/src/git_draft/common.py b/src/git_draft/common.py index 7af8f9c..2e573f3 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -8,9 +8,7 @@ import logging import os from pathlib import Path -import random import sqlite3 -import string import textwrap import tomllib from typing import Any, ClassVar, Self @@ -41,7 +39,6 @@ class Config: bots: Sequence[BotConfig] = dataclasses.field(default_factory=lambda: []) log_level: int = logging.INFO - reset: bool = True @staticmethod def folder_path() -> Path: @@ -73,31 +70,11 @@ class BotConfig: pythonpath: str | None = None -type RepoID = str - - -@dataclasses.dataclass(frozen=True) -class RepoConfig: # TODO: Use - 
"""Repository-specific config""" - - repo_id: str - bot_name: str | None = None - - def config_string(arg: str) -> str: """Dereferences environment value if the input starts with `$`""" return os.environ[arg[1:]] if arg and arg.startswith("$") else arg -_random = random.Random() -_alphabet = string.ascii_lowercase + string.digits - - -def random_id(n: int) -> str: - """Generates a random length n string of lowercase letters and digits""" - return "".join(_random.choices(_alphabet, k=n)) - - class UnreachableError(RuntimeError): """Indicates unreachable code was unexpectedly executed""" diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index f5828e8..6bb7b84 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -10,16 +10,15 @@ import logging from pathlib import Path, PurePosixPath import re -from re import Match import textwrap import time from .bots import Action, Bot, Goal -from .common import JSONObject, Table, qualified_class_name, random_id -from .git import Commit, Repo +from .common import JSONObject, qualified_class_name +from .git import SHA, Repo from .prompt import PromptRenderer, TemplatedPrompt from .store import Store, sql -from .toolbox import StagingToolbox, ToolVisitor +from .toolbox import RepoToolbox, ToolVisitor _logger = logging.getLogger(__name__) @@ -29,48 +28,57 @@ class Accept(enum.Enum): """Valid change accept mode""" MANUAL = 0 - CHECKOUT = enum.auto() + MERGE = enum.auto() FINALIZE = enum.auto() - NO_REGRETS = enum.auto() @dataclasses.dataclass(frozen=True) class Draft: - """Collection of generated changes""" + """Generated changes""" - branch_name: str + folio: Folio + seqno: int + is_noop: bool + walltime: timedelta + token_count: int | None + @property + def ref(self) -> str: + return _draft_ref(self.folio.id, self.seqno) -@dataclasses.dataclass(frozen=True) -class _Branch: - """Draft branch""" - _pattern = re.compile(r"drafts/(.+)") +def _draft_ref(folio_id: int, suffix: int | str) -> str: + return 
f"refs/drafts/{folio_id}/{suffix}" - folio_id: str - @property - def name(self) -> str: - return f"drafts/{self.folio_id}" +_FOLIO_BRANCH_NAMESPACE = "drafts" - def __str__(self) -> str: - return self.name +_FOLIO_UPSTREAM_BRANCH_SUFFIX = "+" + +_folio_branch_pattern = re.compile(_FOLIO_BRANCH_NAMESPACE + r"/(\d+)") + + +@dataclasses.dataclass(frozen=True) +class Folio: + """Collection of drafts""" + + id: int + + def branch_name(self) -> str: + return f"{_FOLIO_BRANCH_NAMESPACE}/{self.id}" + + def upstream_branch_name(self) -> str: + return self.branch_name() + _FOLIO_UPSTREAM_BRANCH_SUFFIX - @classmethod - def active(cls, repo: Repo, name: str | None = None) -> _Branch | None: - match: Match | None = None - active_branch = name or repo.active_branch() - if active_branch: - match = cls._pattern.fullmatch(active_branch) - if not match: - if name: - raise ValueError(f"Not a valid draft branch name: {name!r}") - return None - return _Branch(match[1]) - @staticmethod - def new_suffix() -> str: - return random_id(9) +def _active_folio(repo: Repo) -> Folio | None: + active_branch = repo.active_branch() + if not active_branch: + return None + match = _folio_branch_pattern.fullmatch(active_branch) + if not match: + return None + return Folio(int(match[1])) class Drafter: @@ -84,7 +92,8 @@ def __init__(self, store: Store, repo: Repo) -> None: @classmethod def create(cls, store: Store, path: str | None = None) -> Drafter: - return cls(store, Repo.enclosing(Path(path) if path else Path.cwd())) + repo = Repo.enclosing(Path(path) if path else Path.cwd()) + return cls(store, repo) def generate_draft( # noqa: PLR0913 self, @@ -93,34 +102,27 @@ def generate_draft( # noqa: PLR0913 accept: Accept = Accept.MANUAL, bot_name: str | None = None, prompt_transform: Callable[[str], str] | None = None, - reset: bool = False, timeout: float | None = None, tool_visitors: Sequence[ToolVisitor] | None = None, ) -> Draft: if timeout is not None: raise NotImplementedError() # TODO: Implement - 
if self._repo.has_staged_changes(): - if not reset: - raise ValueError("Please commit or reset any staged changes") - self._repo.git("reset") - - # Ensure that we are on a draft branch. - branch = _Branch.active(self._repo) - if branch: - self._stage_repo() - _logger.debug("Reusing active branch %s.", branch) - else: - branch = self._create_branch() - _logger.debug("Created branch %s.", branch) + # Ensure that we are in a folio. + folio = _active_folio(self._repo) + if not folio: + folio = self._create_folio() + toolbox, parent_commit_sha = self._synced_toolbox() # Handle prompt templating and editing. - prompt_contents = self._prepare_prompt(prompt, prompt_transform) + prompt_contents = self._prepare_prompt( + prompt, prompt_transform, toolbox + ) with self._store.cursor() as cursor: - [(prompt_id,)] = cursor.execute( + [(prompt_id, seqno)] = cursor.execute( sql("add-prompt"), { - "branch_suffix": branch.folio_id, + "folio_id": folio.id, "template": prompt.template if isinstance(prompt, TemplatedPrompt) else None, @@ -128,17 +130,28 @@ def generate_draft( # noqa: PLR0913 }, ) + # Run the bot to generate the change. operation_recorder = _OperationRecorder() change = self._generate_change( bot, Goal(prompt_contents, timeout), - [operation_recorder, *list(tool_visitors or [])], + toolbox.with_visitors( + [operation_recorder, *list(tool_visitors or [])], + ), ) + + # Update metadata. + commit_sha = self._record_change( + change, parent_commit_sha, folio, seqno + ) + # TODO: Trim commits (sync and prompt of files which have not been + # operated on). This will improve the UX by allowing fast-forward when + # other files are edited. 
with self._store.cursor() as cursor: cursor.execute( sql("add-action"), { - "commit_sha": change.commit, + "commit_sha": commit_sha, "prompt_id": prompt_id, "bot_name": bot_name, "bot_class": qualified_class_name(bot.__class__), @@ -151,7 +164,7 @@ def generate_draft( # noqa: PLR0913 sql("add-operation"), [ { - "commit_sha": change.commit, + "commit_sha": commit_sha, "tool": o.tool, "reason": o.reason, "details": json.dumps(o.details), @@ -160,22 +173,129 @@ def generate_draft( # noqa: PLR0913 for o in operation_recorder.operations ], ) - _logger.info("Created new change on %s.", branch) - - delta = change.delta() - if delta and accept.value >= Accept.CHECKOUT.value: - delta.apply() + _logger.info("Created new change in folio %s.", folio.id) + + if accept.value >= Accept.MERGE.value: + self._sync_head("merge") + self._repo.git( + "merge", + "--no-ff", + "-Xtheirs", + "-m", + "draft! merge", + commit_sha, + ) if accept.value >= Accept.FINALIZE.value: - self.finalize_draft(delete=accept == Accept.NO_REGRETS) - return Draft(str(branch)) + self.finalize_folio() + + return Draft( + folio=folio, + seqno=seqno, + is_noop=change.is_noop, + walltime=change.walltime, + token_count=change.action.token_count, + ) + + def finalize_folio(self) -> Folio: + folio = _active_folio(self._repo) + if not folio: + raise RuntimeError("Not currently on a draft branch") + + with self._store.cursor() as cursor: + rows = cursor.execute(sql("get-folio-by-id"), {"id": folio.id}) + if not rows: + raise RuntimeError("Unrecognized draft branch") + [(origin_branch,)] = rows + + # Check that the origin branch has not moved to avoid unexpected diffs. + check_call = self._repo.git( + "merge-base", + "--is-ancestor", + origin_branch, + "HEAD", + expect_codes=(), + ) + if check_call.code: + raise RuntimeError("Origin branch diverged, please rebase first") + + # Create a reference to the current state for later analysis. 
+ self._sync_head("finalize") + self._repo.git("update-ref", _draft_ref(folio.id, "@"), "HEAD") + + # Move back to the original branch, doing a little dance to keep the + # state. See https://stackoverflow.com/a/15993574 for the inspiration. + self._repo.git("checkout", "--detach") + self._repo.git("reset", "--soft", origin_branch) + self._repo.git("checkout", origin_branch) + + # Clean up folio branches. + self._repo.git( + "branch", + "-D", + folio.branch_name(), + folio.upstream_branch_name(), + ) + + _logger.info("Exited %s.", folio) + return folio + + def _create_folio(self) -> Folio: + origin_branch = self._repo.active_branch() + if origin_branch is None: + raise RuntimeError("No currently active branch") + + with self._store.cursor() as cursor: + [(folio_id,)] = cursor.execute( + sql("add-folio"), + { + "repo_uuid": str(self._repo.uuid), + "origin_branch": origin_branch, + }, + ) + folio = Folio(folio_id) + + self._repo.git("checkout", "--detach") + upstream_branch = folio.upstream_branch_name() + self._repo.git("branch", upstream_branch) + live_branch = folio.branch_name() + self._repo.git("branch", "--track", live_branch, upstream_branch) + self._repo.git("checkout", live_branch) + + return folio + + def _synced_toolbox(self) -> tuple[RepoToolbox, SHA]: + toolbox, dirty = RepoToolbox.for_working_dir(self._repo) + if dirty: + upstream_rev = self._commit_tree( + toolbox.tree_sha(), "HEAD", "sync(prompt)" + ) + _logger.debug("Created sync commit. [sha=%r]", upstream_rev) + else: + upstream_rev = "HEAD" + return toolbox, upstream_rev + + def _sync_head(self, scope: str) -> None: + self._repo.git("add", "-A") + index_call = self._repo.git( + "diff-index", + "--quiet", + "--cached", + "HEAD", + expect_codes=(), + ) + if index_call.code: + self._repo.git( + "commit", "--no-verify", "-m", f"draft! 
sync({scope})" + ) def _prepare_prompt( self, prompt: str | TemplatedPrompt, prompt_transform: Callable[[str], str] | None, + toolbox: RepoToolbox, ) -> str: if isinstance(prompt, TemplatedPrompt): - renderer = PromptRenderer.for_toolbox(StagingToolbox(self._repo)) + renderer = PromptRenderer.for_toolbox(toolbox) contents = renderer.render(prompt) else: contents = prompt @@ -189,167 +309,91 @@ def _generate_change( self, bot: Bot, goal: Goal, - tool_visitors: Sequence[ToolVisitor], + toolbox: RepoToolbox, ) -> _Change: - # Trigger code generation. - _logger.debug("Running bot... [bot=%s]", bot) - toolbox = StagingToolbox(self._repo, tool_visitors) + old_tree_sha = toolbox.tree_sha() + start_time = time.perf_counter() + _logger.debug("Running bot... [bot=%s]", bot) action = bot.act(goal, toolbox) - end_time = time.perf_counter() - walltime = end_time - start_time _logger.info("Completed bot action. [action=%s]", action) + end_time = time.perf_counter() - # Generate an appropriate commit. - toolbox.trim_index() - title = action.title - if not title: - title = _default_title(goal.prompt) - commit = self._repo.create_commit( - f"draft! 
{title}\n\n{goal.prompt}", - skip_hooks=True, - ) - + walltime = end_time - start_time + title = action.title or _default_title(goal.prompt) + new_tree_sha = toolbox.tree_sha() return _Change( - commit.sha, timedelta(seconds=walltime), action, self._repo + walltime=timedelta(seconds=walltime), + action=action, + commit_message=f"prompt: {title}\n\n{goal.prompt}", + tree_sha=new_tree_sha, + is_noop=new_tree_sha == old_tree_sha, ) - def finalize_draft(self, *, delete: bool = False) -> Draft: - branch = _Branch.active(self._repo) - if not branch: - raise RuntimeError("Not currently on a draft branch") - self._stage_repo() - - with self._store.cursor() as cursor: - rows = cursor.execute( - sql("get-branch-by-suffix"), {"suffix": branch.folio_id} - ) - if not rows: - raise RuntimeError("Unrecognized draft branch") - [(origin_branch, origin_sha)] = rows - - # We do a small dance to move back to the original branch, keeping the - # draft branch untouched. See https://stackoverflow.com/a/15993574 for - # the inspiration. 
- self._repo.git("checkout", "--detach") - self._repo.git("reset", "-N", origin_branch) - self._repo.git("checkout", origin_branch) - - if delete: - self._repo.git("branch", "-D", branch.name) - _logger.debug("Deleted branch %s.", branch) - - _logger.info("Exited %s.", branch) - return Draft(branch.name) - - def _create_branch(self) -> _Branch: - if self._repo.active_branch() is None: - raise RuntimeError("No currently active branch") - origin_branch = self._repo.active_branch() - origin_sha = self._repo.head_commit().sha - - self._repo.git("checkout", "--detach") - self._stage_repo() - suffix = _Branch.new_suffix() - - with self._store.cursor() as cursor: - cursor.execute( - sql("add-branch"), - { - "suffix": suffix, - "repo_path": str(self._repo.working_dir), - "origin_branch": origin_branch, - "origin_sha": origin_sha, - }, - ) - - branch = _Branch(suffix) - self._repo.checkout_new_branch(branch.name) - return branch - - def _stage_repo(self) -> Commit | None: - self._repo.git("add", "--all") - if not self._repo.has_staged_changes(): - return None - return self._repo.create_commit("draft! sync") - - def history_table(self, branch_name: str | None = None) -> Table: - path = self._repo.working_dir - branch = _Branch.active(self._repo, branch_name) - with self._store.cursor() as cursor: - if branch: - results = cursor.execute( - sql("list-prompts"), - { - "repo_path": str(path), - "branch_suffix": branch.folio_id, - }, - ) - else: - results = cursor.execute( - sql("list-drafts"), {"repo_path": str(path)} - ) - return Table.from_cursor(results) + def _record_change( + self, change: _Change, parent_commit_sha: SHA, folio: Folio, seqno: int + ) -> SHA: + commit_sha = self._commit_tree( + change.tree_sha, parent_commit_sha, change.commit_message + ) + _logger.debug("Created prompt commit. 
[sha=%r]", commit_sha) + self._repo.git( + "update-ref", + f"refs/heads/{folio.upstream_branch_name()}", + commit_sha, + ) + # We also add a reference to the commit so that it doesn't get GC'ed + # when the upstream branch moves. This also makes it easy to visualize + # the change using `git diff refs/drafts/xx/yy`. + self._repo.git("update-ref", _draft_ref(folio.id, seqno), commit_sha) + return commit_sha + + def _commit_tree( + self, tree_sha: SHA, parent_rev: str, message: str + ) -> SHA: + return self._repo.git( + "commit-tree", + "-p", + parent_rev, + "-m", + f"draft! {message}", + tree_sha, + ).stdout def latest_draft_prompt(self) -> str | None: """Returns the latest prompt for the current draft""" - branch = _Branch.active(self._repo) - if not branch: + folio = _active_folio(self._repo) + if not folio: return None with self._store.cursor() as cursor: result = cursor.execute( - sql("get-latest-prompt"), + sql("get-latest-folio-prompt"), { - "repo_path": str(self._repo.working_dir), - "branch_suffix": branch.folio_id, + "repo_uuid": str(self._repo.uuid), + "folio_id": folio.id, }, ).fetchone() return result[0] if result else None -type _CommitSHA = str - - @dataclasses.dataclass(frozen=True) class _Change: """A bot-generated draft, may be a no-op""" - commit: _CommitSHA - walltime: timedelta action: Action - repo: Repo = dataclasses.field(repr=False) - - def delta(self) -> _Delta | None: - diff = self.repo.git("diff-tree", "--patch", self.commit).stdout - return _Delta(diff, self.repo) if diff else None - - -@dataclasses.dataclass(frozen=True) -class _Delta: - """A change's effects, guaranteed non-empty""" - - diff: str - repo: Repo = dataclasses.field(repr=False) - - def apply(self) -> None: - # For patch applcation to work as expected (adding conflict markers as - # needed), files in the patch must exist in the index. 
- self.repo.git("add", "--all") - git = self.repo.git( - "apply", "--3way", "-", stdin=self.diff, expect_codes=() - ) - if "with conflicts" in git.stderr: - raise ConflictError() - if git.code != 0: - raise NotImplementedError() # TODO: Raise better error - self.repo.git("reset") + walltime: timedelta + commit_message: str + tree_sha: SHA + is_noop: bool -class ConflictError(Exception): - """A change could not be applied cleanly""" +class _OperationRecorder(ToolVisitor): + """Visitor which keeps track of which operations have been performed + This is useful to store a summary of each change in our database for later + analysis. + """ -class _OperationRecorder(ToolVisitor): def __init__(self) -> None: self.operations = list[_Operation]() @@ -399,6 +443,8 @@ def _record(self, reason: str | None, tool: str, **kwargs) -> None: @dataclasses.dataclass(frozen=True) class _Operation: + """Tool usage record""" + tool: str details: JSONObject reason: str | None @@ -406,4 +452,4 @@ class _Operation: def _default_title(prompt: str) -> str: - return textwrap.shorten(prompt, break_on_hyphens=False, width=72) + return textwrap.shorten(prompt, break_on_hyphens=False, width=55) diff --git a/src/git_draft/git.py b/src/git_draft/git.py index 3fe80d0..dc3476b 100644 --- a/src/git_draft/git.py +++ b/src/git_draft/git.py @@ -2,78 +2,24 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import Iterator, Sequence import dataclasses +import enum import logging from pathlib import Path import subprocess from typing import Self +import uuid _logger = logging.getLogger(__name__) -@dataclasses.dataclass(frozen=True) -class Commit: - """Commit newtype""" - - sha: str - - def __str__(self) -> str: - return self.sha - - -class Repo: - """Git repository""" - - def __init__(self, working_dir: Path) -> None: - self.working_dir = working_dir - - @classmethod - def enclosing(cls, path: Path) -> Self: - git = Git.run("-C", str(path), "rev-parse", 
"--show-toplevel") - return cls(Path(git.stdout)) - - def git( - self, - cmd: str, - *args: str, - stdin: str | None = None, - expect_codes: Sequence[int] = (0,), - ) -> Git: - return Git.run( - "-C", - str(self.working_dir), - cmd, - *args, - stdin=stdin, - expect_codes=expect_codes, - ) - - def active_branch(self) -> str | None: - return self.git("branch", "--show-current").stdout or None - - def checkout_new_branch(self, name: str) -> None: - self.git("checkout", "-b", name) - - def has_staged_changes(self) -> bool: - git = self.git("diff", "--quiet", "--staged", expect_codes=()) - return git.code != 0 - - def head_commit(self) -> Commit: - sha = self.git("rev-parse", "HEAD").stdout - return Commit(sha) - - def create_commit(self, message: str, skip_hooks: bool = False) -> Commit: - args = ["commit", "--allow-empty", "-m", message] - if skip_hooks: - args.append("--no-verify") - self.git(*args) - return self.head_commit() +type SHA = str @dataclasses.dataclass(frozen=True) -class Git: +class GitCall: """Git command execution result""" code: int @@ -81,27 +27,104 @@ class Git: stderr: str @classmethod - def run( + def sync( cls, *args: str, stdin: str | None = None, executable: str = "git", expect_codes: Sequence[int] = (0,), + working_dir: Path | None = None, ) -> Self: - _logger.debug("Running git command. [args=%r]", args) + """Run a git command synchronously""" + _logger.debug( + "Running git command. 
[args=%r, cwd=%r]", args, working_dir + ) popen = subprocess.Popen( [executable, *args], encoding="utf8", stdin=None if stdin is None else subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, + cwd=working_dir, ) stdout, stderr = popen.communicate(input=stdin) code = popen.returncode if expect_codes and code not in expect_codes: - raise GitError(f"Git command failed with code {code}: {stderr}") + raise GitError( + f"Git command failed with code {code}\n{stderr}\n{stdout}" + ) return cls(code, stdout.rstrip(), stderr.rstrip()) class GitError(Exception): """Git command execution error""" + + +class _ConfigKey(enum.StrEnum): + REPO_UUID = "repouuid" + DEFAULT_BOT = "bot" # TODO: Use + + @property + def fullname(self) -> str: + return f"draft.{self.value}" + + +class Repo: + """Git repository""" + + def __init__(self, working_dir: Path, uuid: uuid.UUID) -> None: + self.working_dir = working_dir + self.uuid = uuid + + @classmethod + def enclosing(cls, path: Path) -> Self: + """Returns the repo enclosing the given path""" + call = GitCall.sync("rev-parse", "--show-toplevel", working_dir=path) + working_dir = Path(call.stdout) + uuid = _ensure_repo_uuid(working_dir) + return cls(working_dir, uuid) + + def git( + self, + cmd: str, + *args: str, + stdin: str | None = None, + expect_codes: Sequence[int] = (0,), + ) -> GitCall: + """Runs a git command inside this repo""" + return GitCall.sync( + cmd, + *args, + stdin=stdin, + expect_codes=expect_codes, + working_dir=self.working_dir, + ) + + def active_branch(self) -> str | None: + return self.git("branch", "--show-current").stdout or None + + +def _ensure_repo_uuid(working_dir: Path) -> uuid.UUID: + call = GitCall.sync( + "config", + "get", + _ConfigKey.REPO_UUID.fullname, + working_dir=working_dir, + expect_codes=(), + ) + if call.code == 0: + return uuid.UUID(call.stdout) + repo_uuid = uuid.uuid4() + GitCall.sync( + "config", + "set", + _ConfigKey.REPO_UUID.fullname, + str(repo_uuid), + working_dir=working_dir, 
+ ) + _logger.debug("Set repo UUID. [uuid=%s]", repo_uuid) + return repo_uuid + + +def null_delimited(arg: str) -> Iterator[str]: + return (item for item in arg.split("\x00") if item) diff --git a/src/git_draft/queries/add-branch.sql b/src/git_draft/queries/add-branch.sql deleted file mode 100644 index 4865cd6..0000000 --- a/src/git_draft/queries/add-branch.sql +++ /dev/null @@ -1,2 +0,0 @@ -insert into branches (suffix, repo_path, origin_branch, origin_sha) - values (:suffix, :repo_path, :origin_branch, :origin_sha); diff --git a/src/git_draft/queries/add-folio.sql b/src/git_draft/queries/add-folio.sql new file mode 100644 index 0000000..52aecc8 --- /dev/null +++ b/src/git_draft/queries/add-folio.sql @@ -0,0 +1,3 @@ +insert into folios (repo_uuid, origin_branch) + values (:repo_uuid, :origin_branch) + returning id; diff --git a/src/git_draft/queries/add-prompt.sql b/src/git_draft/queries/add-prompt.sql index b2893f8..9f74a2d 100644 --- a/src/git_draft/queries/add-prompt.sql +++ b/src/git_draft/queries/add-prompt.sql @@ -1,3 +1,9 @@ -insert into prompts (branch_suffix, template, contents) - values (:branch_suffix, :template, :contents) - returning id; +with + seqnos as ( + select coalesce(count(*), 0) + 1 as seqno + from prompts + where folio_id = :folio_id) +insert into prompts (seqno, folio_id, template, contents) + select seqno, :folio_id, :template, :contents + from seqnos + returning id, seqno diff --git a/src/git_draft/queries/create-tables.sql b/src/git_draft/queries/create-tables.sql index 5c6c58a..0b4ac3c 100644 --- a/src/git_draft/queries/create-tables.sql +++ b/src/git_draft/queries/create-tables.sql @@ -1,20 +1,24 @@ -create table if not exists branches ( - suffix text primary key, - repo_path text not null, +create table if not exists folios ( + id integer primary key, + repo_uuid text not null, created_at timestamp default current_timestamp, - origin_branch text not null, - origin_sha text not null -) without rowid; + origin_branch text not null +); + 
+create index if not exists folios_by_repo on folios (repo_uuid); create table if not exists prompts ( id integer primary key, + folio_id integer not null, + seqno integer not null, created_at timestamp default current_timestamp, - branch_suffix text not null, template text, contents text not null, - foreign key (branch_suffix) references branches(suffix) + foreign key (folio_id) references folios(id) ); +create unique index if not exists prompts_by_folio_seqno on prompts (folio_id, seqno); + create table if not exists actions ( commit_sha text primary key, created_at timestamp default current_timestamp, @@ -24,7 +28,7 @@ create table if not exists actions ( walltime_seconds real not null, request_count int, token_count int, - foreign key (prompt_id) references prompts(id) on delete cascade + foreign key (prompt_id) references prompts (id) on delete cascade ) without rowid; create table if not exists operations ( @@ -34,5 +38,5 @@ create table if not exists operations ( reason text, details text not null, started_at timestamp not null, - foreign key (action_commit_sha) references actions(commit_sha) on delete cascade + foreign key (action_commit_sha) references actions (commit_sha) on delete cascade ); diff --git a/src/git_draft/queries/get-branch-by-suffix.sql b/src/git_draft/queries/get-branch-by-suffix.sql deleted file mode 100644 index c4c34ef..0000000 --- a/src/git_draft/queries/get-branch-by-suffix.sql +++ /dev/null @@ -1,3 +0,0 @@ -select origin_branch, origin_sha - from branches - where suffix = :suffix; diff --git a/src/git_draft/queries/get-folio-by-id.sql b/src/git_draft/queries/get-folio-by-id.sql new file mode 100644 index 0000000..8f747a3 --- /dev/null +++ b/src/git_draft/queries/get-folio-by-id.sql @@ -0,0 +1,3 @@ +select origin_branch + from folios + where id = :id; diff --git a/src/git_draft/queries/get-latest-folio-prompt.sql b/src/git_draft/queries/get-latest-folio-prompt.sql new file mode 100644 index 0000000..a1e5a49 --- /dev/null +++ 
b/src/git_draft/queries/get-latest-folio-prompt.sql @@ -0,0 +1,6 @@ +select p.contents + from prompts as p + join folios as f on p.folio_id = f.id + where f.id = :folio_id + order by p.id desc + limit 1; diff --git a/src/git_draft/queries/get-latest-prompt.sql b/src/git_draft/queries/get-latest-prompt.sql deleted file mode 100644 index 6a66082..0000000 --- a/src/git_draft/queries/get-latest-prompt.sql +++ /dev/null @@ -1,6 +0,0 @@ -select p.contents - from prompts as p - join branches as b on p.branch_suffix = b.suffix - where b.repo_path = :repo_path and b.suffix = :branch_suffix - order by p.id desc - limit 1; diff --git a/src/git_draft/queries/list-drafts.sql b/src/git_draft/queries/list-drafts.sql deleted file mode 100644 index 1b0747a..0000000 --- a/src/git_draft/queries/list-drafts.sql +++ /dev/null @@ -1,12 +0,0 @@ -select - datetime(min(p.created_at), 'localtime') as created, - 'draft/' || b.suffix as branch, - min(b.origin_branch) as origin, - count(p.id) as prompts, - sum(a.token_count) as tokens - from branches as b - join prompts as p on b.suffix = p.branch_suffix - join actions as a on p.id = a.prompt_id - where b.repo_path = :repo_path - group by b.suffix - order by created desc; diff --git a/src/git_draft/queries/list-prompts.sql b/src/git_draft/queries/list-folio-prompts.sql similarity index 78% rename from src/git_draft/queries/list-prompts.sql rename to src/git_draft/queries/list-folio-prompts.sql index 43021b1..52d8910 100644 --- a/src/git_draft/queries/list-prompts.sql +++ b/src/git_draft/queries/list-folio-prompts.sql @@ -5,9 +5,9 @@ select coalesce(round(sum(a.walltime_seconds), 1), 0) as walltime, count(o.id) as ops from prompts as p - join branches as b on p.branch_suffix = b.suffix + join folios as f on p.folio_id = f.id left join actions as a on p.id = a.prompt_id left join operations as o on a.commit_sha = o.action_commit_sha - where b.repo_path = :repo_path and b.suffix = :branch_suffix + where f.id = :folio_id group by p.id order by 
created desc; diff --git a/src/git_draft/queries/list-folios.sql b/src/git_draft/queries/list-folios.sql new file mode 100644 index 0000000..5c5643b --- /dev/null +++ b/src/git_draft/queries/list-folios.sql @@ -0,0 +1,12 @@ +select + datetime(min(p.created_at), 'localtime') as created, + f.id as id, + min(f.origin_branch) as origin, + count(p.id) as prompts, + sum(a.token_count) as tokens + from folios as f + join prompts as p on f.id = p.folio_id + join actions as a on p.id = a.prompt_id + where f.repo_uuid = :repo_uuid + group by f.id + order by created desc; diff --git a/src/git_draft/store.py b/src/git_draft/store.py index 49c7b66..e2de64e 100644 --- a/src/git_draft/store.py +++ b/src/git_draft/store.py @@ -19,7 +19,7 @@ class Store: """Lightweight sqlite wrapper""" - _name = "v1.sqlite3" + _name = "v2.sqlite3" def __init__(self, conn: sqlite3.Connection) -> None: self._connection = conn diff --git a/src/git_draft/toolbox.py b/src/git_draft/toolbox.py index f38c957..32a1854 100644 --- a/src/git_draft/toolbox.py +++ b/src/git_draft/toolbox.py @@ -2,13 +2,16 @@ from __future__ import annotations +import collections from collections.abc import Callable, Sequence +import dataclasses import logging from pathlib import PurePosixPath import tempfile -from typing import Protocol, override +from typing import Protocol, Self, override -from .git import GitError, Repo +from .common import UnreachableError +from .git import SHA, GitError, Repo, null_delimited _logger = logging.getLogger(__name__) @@ -65,9 +68,9 @@ def delete_file( self, path: PurePosixPath, reason: str | None = None, - ) -> bool: + ) -> None: self._dispatch(lambda v: v.on_delete_file(path, reason)) - return self._delete(path) + self._delete(path) def rename_file( self, @@ -89,7 +92,7 @@ def _write( ) -> None: # pragma: no cover raise NotImplementedError() - def _delete(self, path: PurePosixPath) -> bool: # pragma: no cover + def _delete(self, path: PurePosixPath) -> None: # pragma: no cover raise 
NotImplementedError() def _rename( @@ -128,68 +131,154 @@ def on_rename_file( ) -> None: ... # pragma: no cover -class StagingToolbox(Toolbox): - """Git-index backed toolbox implementation +class RepoToolbox(Toolbox): + """Git-repo backed toolbox implementation + + All files are directly read from and written to an standalone tree. This + allows concurrent editing without interference with the working directory + or index. - All files are directly read from and written to the index. This allows - concurrent editing without interference with the working directory. + This toolbox is not thread-safe. """ def __init__( - self, repo: Repo, visitors: Sequence[ToolVisitor] | None = None + self, + repo: Repo, + start_rev: SHA, + visitors: Sequence[ToolVisitor] | None = None, ) -> None: super().__init__(visitors) + call = repo.git("rev-parse", "--verify", f"{start_rev}^{{tree}}") + self._tree_sha = call.stdout + self._tree_updates = list[_TreeUpdate]() self._repo = repo - self._updated = set[str]() + + @classmethod + def for_working_dir(cls, repo: Repo) -> tuple[Self, bool]: + toolbox = cls(repo, "HEAD") + head_tree_sha = toolbox.tree_sha() + + # Apply any changes from the working directory. 
+ deleted = set[SHA]() + for path in null_delimited(repo.git("ls-files", "-dz").stdout): + deleted.add(path) + toolbox._delete(PurePosixPath(path)) + for path in null_delimited( + repo.git("ls-files", "-moz", "--exclude-standard").stdout + ): + if path in deleted: + continue # Deleted files also show up as modified + toolbox._write_from_disk(PurePosixPath(path), path) + + return toolbox, toolbox.tree_sha() != head_tree_sha + + def with_visitors(self, visitors: Sequence[ToolVisitor]) -> Self: + return self.__class__(self._repo, self.tree_sha(), visitors) + + def tree_sha(self) -> SHA: + if updates := self._tree_updates: + self._tree_sha = _update_tree(self._tree_sha, updates, self._repo) + updates.clear() + return self._tree_sha @override def _list(self) -> Sequence[PurePosixPath]: - # Show staged files. - return [ - PurePosixPath(p) - for p in self._repo.git("ls-files").stdout.splitlines() - ] + call = self._repo.git("ls-tree", "-rz", "--name-only", self.tree_sha()) + return [PurePosixPath(p) for p in null_delimited(call.stdout)] @override def _read(self, path: PurePosixPath) -> str: - # Read the file from the index. - return self._repo.git("show", f":{path}").stdout + try: + return self._repo.git("show", f"{self.tree_sha()}:{path}").stdout + except GitError as exc: + msg = str(exc) + if "does not exist in" in msg or "exists on disk, but not" in msg: + raise FileNotFoundError(f"{path} does not exist") + raise @override def _write(self, path: PurePosixPath, contents: str) -> None: - self._updated.add(str(path)) # Update the index without touching the worktree. # https://stackoverflow.com/a/25352119 with tempfile.NamedTemporaryFile(delete_on_close=False) as temp: temp.write(contents.encode("utf8")) temp.close() - sha = self._repo.git( - "hash-object", "-w", temp.name, "--path", str(path) - ).stdout - mode = 644 # TODO: Read from original file if it exists. 
- self._repo.git( - "update-index", "--add", "--cacheinfo", f"{mode},{sha},{path}" - ) + self._write_from_disk(path, temp.name) + + def _write_from_disk( + self, path: PurePosixPath, contents_path: str + ) -> None: + blob_sha = self._repo.git( + "hash-object", + "-w", + "--path", + str(path), + contents_path, + ).stdout + self._tree_updates.append(_WriteBlob(path, blob_sha)) @override - def _delete(self, path: PurePosixPath) -> bool: - try: - self._repo.git("rm", "--cached", "--", str(path)) - except GitError as err: - _logger.warning("Failed to delete file. [err=%r]", err) - return False - else: - self._updated.add(str(path)) - return True - - def trim_index(self) -> None: - """Unstage any files which have not been written to""" - git = self._repo.git("diff", "--name-only", "--cached") - untouched = [ - path - for path in git.stdout.splitlines() - if path and path not in self._updated - ] - if untouched: - self._repo.git("reset", "--", *untouched) - _logger.debug("Trimmed index. [reset_paths=%s]", untouched) + def _delete(self, path: PurePosixPath) -> None: + self._tree_updates.append(_DeleteBlob(path)) + + +class _TreeUpdate: + """Generic tree update""" + + +@dataclasses.dataclass(frozen=True) +class _WriteBlob(_TreeUpdate): + path: PurePosixPath + blob_sha: SHA + + +@dataclasses.dataclass(frozen=True) +class _DeleteBlob(_TreeUpdate): + path: PurePosixPath + + +def _update_tree(sha: SHA, updates: Sequence[_TreeUpdate], repo: Repo) -> SHA: + if not updates: + return sha + + blob_shas = collections.defaultdict[PurePosixPath, dict[str, str]](dict) + for update in updates: + match update: + case _WriteBlob(path, blob_sha): + blob_shas[path.parent][path.name] = blob_sha + case _DeleteBlob(path): + blob_shas[path.parent][path.name] = "" + case _: + raise UnreachableError() + + def visit_tree(sha: SHA, path: PurePosixPath) -> SHA: + old_lines = null_delimited(repo.git("ls-tree", "-z", sha).stdout) + new_blob_shas = blob_shas[path] + + new_lines = list[str]() + for line 
in old_lines: + old_prefix, name = line.split("\t", 1) + mode, otype, old_sha = old_prefix.split(" ") + match otype: + case "blob": + new_sha = new_blob_shas.pop(name, old_sha) + if new_sha: + new_lines.append(f"{mode} blob {new_sha}\t{name}") + case "tree": + new_sha = visit_tree(old_sha, path / name) + new_lines.append(f"040000 tree {new_sha}\t{name}") + case _: + raise UnreachableError() + + for name, blob_sha in new_blob_shas.items(): + if blob_sha: + new_lines.append(f"100644 blob {blob_sha}\t{name}") + else: + _logger.warning("Unmatched deletion. [path=%s]", path / name) + + if new_lines == old_lines: + return sha + + return repo.git("mktree", "-z", stdin="\x00".join(new_lines)).stdout + + return visit_tree(sha, PurePosixPath(".")) diff --git a/tests/git_draft/common_test.py b/tests/git_draft/common_test.py index cc158a9..0dec8ef 100644 --- a/tests/git_draft/common_test.py +++ b/tests/git_draft/common_test.py @@ -12,17 +12,6 @@ def test_ensure_state_home() -> None: assert path.exists() -class TestRandomId: - def test_length(self) -> None: - length = 10 - result = sut.random_id(length) - assert len(result) == length - - def test_content(self) -> None: - result = sut.random_id(1000) - assert set(result).issubset(sut._alphabet) - - class TestConfig: @pytest.fixture(autouse=True) def config_home(self, monkeypatch, tmp_path) -> Path: diff --git a/tests/git_draft/conftest.py b/tests/git_draft/conftest.py index 1841bb4..1229f48 100644 --- a/tests/git_draft/conftest.py +++ b/tests/git_draft/conftest.py @@ -1,16 +1,17 @@ from collections.abc import Iterator +import os from pathlib import Path import pytest -from git_draft.git import Git, Repo +from git_draft.git import GitCall, Repo @pytest.fixture def repo(tmp_path: Path) -> Iterator[Repo]: path = tmp_path / "repo" path.mkdir() - Git.run("-C", str(path), "init", "-b", "main") + GitCall.sync("init", "-b", "main", working_dir=path) repo = Repo.enclosing(path) repo.git("commit", "-m", "init", "--allow-empty") yield repo 
@@ -20,3 +21,35 @@ def repo(tmp_path: Path) -> Iterator[Repo]: def state_home(monkeypatch, tmp_path) -> None: path = tmp_path / "state" monkeypatch.setenv("XDG_STATE_HOME", str(path)) + + +class RepoFS: + def __init__(self, repo: Repo) -> None: + self._repo = repo + + def path(self, name: str) -> Path: + return Path(self._repo.working_dir, name) + + def read(self, name: str) -> str | None: + try: + with open(self.path(name)) as f: + return f.read() + except FileNotFoundError: + return None + + def write(self, name: str, contents="") -> None: + with open(self.path(name), "w") as f: + f.write(contents) + + def delete(self, name: str) -> None: + os.remove(self.path(name)) + + def flush(self, message: str = "flush") -> str: + self._repo.git("add", "-A") + self._repo.git("commit", "-m", message) + return self._repo.git("rev-parse", "HEAD").stdout + + +@pytest.fixture +def repo_fs(repo: Repo) -> RepoFS: + return RepoFS(repo) diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index 5ac7adf..8e62489 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -1,16 +1,16 @@ from collections.abc import Callable, Mapping, Sequence -import os -from pathlib import Path, PurePosixPath +from pathlib import PurePosixPath from typing import Self import pytest from git_draft.bots import Action, Bot, Goal, Toolbox import git_draft.drafter as sut -from git_draft.git import Commit, Repo -from git_draft.prompt import TemplatedPrompt +from git_draft.git import SHA, Repo from git_draft.store import Store +from .conftest import RepoFS + class _SimpleBot(Bot): """A simple bot which updates files to match a mapping""" @@ -41,30 +41,14 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action: class TestDrafter: @pytest.fixture(autouse=True) - def setup(self, repo: Repo) -> None: + def setup(self, repo: Repo, repo_fs: RepoFS) -> None: self._repo = repo + self._fs = repo_fs self._drafter = sut.Drafter(Store.in_memory(), repo) - def 
_path(self, name: str) -> Path: - return Path(self._repo.working_dir, name) - - def _read(self, name: str) -> str | None: - try: - with open(self._path(name)) as f: - return f.read() - except FileNotFoundError: - return None - - def _write(self, name: str, contents="") -> None: - with open(self._path(name), "w") as f: - f.write(contents) - - def _delete(self, name: str) -> None: - os.remove(self._path(name)) - - def _commits(self, ref: str | None = None) -> Sequence[Commit]: + def _commits(self, ref: str | None = None) -> Sequence[SHA]: git = self._repo.git("log", "--pretty=format:%H", ref or "HEAD") - return [Commit(i) for i in git.stdout.splitlines()] + return git.stdout.splitlines() def _commit_files(self, ref: str) -> frozenset[str]: git = self._repo.git( @@ -76,26 +60,38 @@ def _checkout(self) -> None: self._repo.git("checkout", "--", ".") def test_generate_draft(self) -> None: + self._fs.write("p1", "a") self._drafter.generate_draft("hello", _SimpleBot({"p1": "A"})) - assert len(self._commits()) == 2 + assert len(self._commits()) == 1 + assert len(self._commits("@{u}")) == 3 + assert self._fs.read("p1") == "a" def test_generate_empty_draft(self) -> None: self._drafter.generate_draft("hello", _SimpleBot.noop()) - assert len(self._commits()) == 2 - - def test_generate_stages_then_resets_worktree(self) -> None: - self._write("p1", "a") - self._write("p2", "b") + assert len(self._commits()) == 1 + assert len(self._commits("@{u}")) == 2 - class CustomBot(Bot): - def act(self, _goal: Goal, toolbox: Toolbox) -> Action: - assert toolbox.read_file(PurePosixPath("p1")) == "a" - toolbox.write_file(PurePosixPath("p2"), "B") - toolbox.write_file(PurePosixPath("p3"), "C") - return Action() + def test_generate_draft_accept_merge(self) -> None: + self._fs.write("p1", "a") + self._drafter.generate_draft( + "hello", + _SimpleBot({"p2": "b"}), + accept=sut.Accept.MERGE, + ) + assert len(self._commits()) == 5 # init, sync, prompt, sync, merge + assert self._fs.read("p1") == "a" 
+ assert self._fs.read("p2") == "b" - self._drafter.generate_draft("hello", CustomBot()) - assert self._commit_files("HEAD") == set(["p2", "p3"]) + def test_generate_draft_accept_finalize(self) -> None: + self._fs.write("p1", "a") + self._drafter.generate_draft( + "hello", + _SimpleBot({"p1": "A", "p2": "b"}), + accept=sut.Accept.FINALIZE, + ) + assert len(self._commits()) == 1 # init + assert self._fs.read("p1") == "A" + assert self._fs.read("p2") == "b" def test_generate_outside_branch(self) -> None: self._repo.git("checkout", "--detach") @@ -106,124 +102,24 @@ def test_generate_empty_prompt(self) -> None: with pytest.raises(ValueError): self._drafter.generate_draft("", _SimpleBot.noop()) - def test_generate_dirty_index_no_reset(self) -> None: - self._write("log") - self._repo.git("add", "--all") - with pytest.raises(ValueError): - self._drafter.generate_draft("hi", _SimpleBot.noop()) - - def test_generate_dirty_index_reset_sync(self) -> None: - self._write("log", "11") - self._repo.git("add", "--all") - self._drafter.generate_draft("hi", _SimpleBot.prompt(), reset=True) - assert self._read("log") == "11" - assert not self._path("PROMPT").exists() - self._repo.git("checkout", ".") - assert self._read("PROMPT") == "hi" - assert len(self._commits()) == 3 # init, sync, prompt - - def test_generate_clean_index_sync(self) -> None: - prompt = TemplatedPrompt("add-test", {"symbol": "abc"}) - self._drafter.generate_draft(prompt, _SimpleBot({"p1": "abc"})) - self._repo.git("checkout", ".") - assert "abc" in (self._read("p1") or "") - assert len(self._commits()) == 2 # sync, prompt - def test_generate_reuse_branch(self) -> None: bot = _SimpleBot({"prompt": lambda goal: goal.prompt}) - self._drafter.generate_draft("prompt1", bot) - self._drafter.generate_draft("prompt2", bot) - self._repo.git("checkout", ".") - assert self._read("prompt") == "prompt2" - assert len(self._commits()) == 4 # init, sync, prompt, prompt - - def test_generate_reuse_branch_sync(self) -> None: - bot 
= _SimpleBot({"p1": "A"}) - self._drafter.generate_draft("prompt1", bot) - self._drafter.generate_draft("prompt2", bot) - assert len(self._commits()) == 4 # init, prompt, sync, prompt - - def test_generate_noop(self) -> None: - self._write("unrelated", "a") - self._drafter.generate_draft("prompt", _SimpleBot.noop()) - assert len(self._commits()) == 3 # init, sync, prompt - assert not self._commit_files("HEAD") - - def test_generate_accept_checkout(self) -> None: - self._write("p1", "A") - self._write("p2", "B") - self._write("p4", "E") - self._drafter.generate_draft( - "hello", - _SimpleBot({"p1": "C", "p3": "D", "p4": None}), - accept=sut.Accept.CHECKOUT, - ) - assert self._read("p1") == "C" - assert self._read("p2") == "B" - assert self._read("p3") == "D" - assert self._read("p4") is None - - @pytest.mark.skip(reason="conflict resolution in flux") - def test_generate_accept_checkout_conflict(self) -> None: - self._write("p1", "A") - with pytest.raises(sut.ConflictError): - self._drafter.generate_draft( - "hello", - _SimpleBot({"p1": "B", "p2": "C"}), - accept=sut.Accept.CHECKOUT, - ) - assert """<<<<<<< ours\nA""" in (self._read("p1") or "") - assert self._read("p2") == "C" - - def test_generate_accept_finalize(self) -> None: - self._write("p1", "A") - self._drafter.generate_draft( - "hello", - _SimpleBot({"p2": "B"}), - accept=sut.Accept.FINALIZE, - ) - assert self._read("p1") == "A" - assert self._read("p2") == "B" - assert self._repo.active_branch() == "main" + self._drafter.generate_draft("prompt1", bot, sut.Accept.MERGE) + self._drafter.generate_draft("prompt2", bot, sut.Accept.MERGE) + assert self._fs.read("prompt") == "prompt2" def test_delete_unknown_file(self) -> None: self._drafter.generate_draft("hello", _SimpleBot({"p1": None})) def test_finalize_keeps_changes(self) -> None: - self._write("p1.txt", "a1") - self._drafter.generate_draft("hello", _SimpleBot.prompt()) - self._checkout() - self._write("p1.txt", "a2") - self._drafter.finalize_draft() - 
assert self._read("p1.txt") == "a2" - assert self._read("PROMPT") == "hello" - - def test_finalize_and_sync(self) -> None: - draft = self._drafter.generate_draft( - "hello", - _SimpleBot.prompt(), - accept=sut.Accept.CHECKOUT, - ) - self._write("PROMPT", "a2") - self._drafter.finalize_draft() - assert self._read("PROMPT") == "a2" - commits = self._commits(draft.branch_name) - assert len(commits) == 3 # init, prompt, sync - assert ( - "sync" - in self._repo.git( - "log", "--format=%B", "-n1", commits[0].sha - ).stdout + self._fs.write("p1.txt", "a1") + self._drafter.generate_draft( + "hello", _SimpleBot.prompt(), sut.Accept.MERGE ) - - def test_history_table_empty(self) -> None: - table = self._drafter.history_table() - assert not table - - def test_history_table_active_draft(self) -> None: - self._drafter.generate_draft("hello", _SimpleBot.noop()) - table = self._drafter.history_table() - assert table + self._fs.write("p1.txt", "a2") + self._drafter.finalize_folio() + assert self._fs.read("p1.txt") == "a2" + assert self._fs.read("PROMPT") == "hello" def test_latest_draft_prompt(self) -> None: bot = _SimpleBot.noop() diff --git a/tests/git_draft/prompt_test.py b/tests/git_draft/prompt_test.py index 1d5357c..90685ac 100644 --- a/tests/git_draft/prompt_test.py +++ b/tests/git_draft/prompt_test.py @@ -1,13 +1,13 @@ import pytest import git_draft.prompt as sut -from git_draft.toolbox import StagingToolbox +from git_draft.toolbox import RepoToolbox class TestPromptRenderer: @pytest.fixture(autouse=True) def setup(self, repo) -> None: - toolbox = StagingToolbox(repo) + toolbox = RepoToolbox(repo, "HEAD") self._renderer = sut.PromptRenderer.for_toolbox(toolbox) def test_ok(self) -> None: diff --git a/tests/git_draft/toolbox_test.py b/tests/git_draft/toolbox_test.py index ada65fc..085cadc 100644 --- a/tests/git_draft/toolbox_test.py +++ b/tests/git_draft/toolbox_test.py @@ -1,51 +1,69 @@ -from pathlib import Path, PurePosixPath +from pathlib import PurePosixPath import 
pytest -from git_draft.git import GitError, Repo +from git_draft.git import Repo import git_draft.toolbox as sut +from .conftest import RepoFS -class TestStagingToolbox: + +PPP = PurePosixPath + + +class TestRepoToolbox: @pytest.fixture(autouse=True) - def setup(self, repo: Repo) -> None: - self._toolbox = sut.StagingToolbox(repo) - - def test_list_files(self, repo: Repo) -> None: - assert self._toolbox.list_files() == [] - names = set(["one.txt", "two.txt"]) - for name in names: - with Path(repo.working_dir, name).open("w") as f: - f.write("ok") - repo.git("add", "--all") - assert set(str(p) for p in self._toolbox.list_files()) == names - - def test_read_file(self, repo: Repo) -> None: - with Path(repo.working_dir, "one").open("w") as f: - f.write("ok") - - path = PurePosixPath("one") - with pytest.raises(GitError): - self._toolbox.read_file(path) - - repo.git("add", "--all") - assert self._toolbox.read_file(path) == "ok" - - def test_write_file(self, repo: Repo) -> None: - self._toolbox.write_file(PurePosixPath("one"), "hi") - - path = Path(repo.working_dir, "one") - assert not path.exists() - - repo.git("checkout-index", "--all") - with path.open() as f: - assert f.read() == "hi" - - def test_rename_file(self, repo: Repo) -> None: - self._toolbox.write_file(PurePosixPath("one"), "hi") - self._toolbox.rename_file(PurePosixPath("one"), PurePosixPath("two")) - - repo.git("checkout-index", "--all") - assert not Path(repo.working_dir, "one").exists() - with Path(repo.working_dir, "two").open() as f: - assert f.read() == "hi" + def setup(self, repo: Repo, repo_fs: RepoFS) -> None: + self._repo = repo + self._fs = repo_fs + + def test_list_files(self) -> None: + self._fs.write("f1", "a") + self._fs.write("f2", "b") + self._fs.flush() + + toolbox = sut.RepoToolbox(self._repo, "HEAD") + self._fs.delete("f2") + self._fs.write("f3", "c") + assert set(str(p) for p in toolbox.list_files()) == {"f1", "f2"} + + def test_read_file(self) -> None: + self._fs.write("f1", "a") + 
sha = self._fs.flush() + self._fs.write("f1", "aa") + self._fs.flush() + self._fs.write("f2", "b") + + toolbox = sut.RepoToolbox(self._repo, sha) + assert toolbox.read_file(PPP("f1")) == "a" + assert toolbox.read_file(PPP("f2")) is None + assert toolbox.read_file(PPP("f3")) is None + + def test_write_file(self) -> None: + self._fs.write("f1", "a") + self._fs.write("f2", "b") + sha = self._fs.flush() + self._fs.write("f1", "aa") + self._fs.flush() + + toolbox = sut.RepoToolbox(self._repo, sha) + toolbox.write_file(PPP("f1"), "aaa") + toolbox.write_file(PPP("f3"), "c") + assert toolbox.read_file(PPP("f1")) == "aaa" + assert toolbox.read_file(PPP("f3")) == "c" + assert self._fs.read("f1") == "aa" + assert self._fs.read("f3") is None + + def test_for_working_dir_dirty(self) -> None: + self._fs.write("f1", "a") + self._fs.write("f2", "b") + self._fs.write("f3", "c") + self._fs.flush() + self._fs.write("f1", "aa") + self._fs.delete("f2") + + toolbox, dirty = sut.RepoToolbox.for_working_dir(self._repo) + assert dirty + assert toolbox.read_file(PPP("f1")) == "aa" + assert toolbox.read_file(PPP("f2")) is None + assert toolbox.read_file(PPP("f3")) == "c"