diff --git a/README.md b/README.md index dee1a39..623886f 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ * Concurrent edits. By default `git-draft` does not touch the working directory. * Customizable prompt templates. * Extensible bot API. +* Local data collection for privacy-friendly analytics. ## Installation @@ -23,8 +24,14 @@ pipx install git-draft[openai] * Mechanism for reporting feedback from a bot, and possibly allowing user to interactively respond. * Add configuration option to auto sync and `--no-sync` flag. Similar to reset. -* Add "amend" commit when finalizing. This could be useful training data, - showing what the bot did not get right. + Also rename both options to `sync` and `reset`; this will make it more natural + to support a similar config option for `accept`. +* Add `--sync` `finalize` option which creates an additional commit when + finalizing if any changes were added to the bot's output. This could be useful + training data, showing what the bot did not get right. * Convenience `--accept` functionality for simple cases: checkout option which applies the changes, and finalizes the draft if specified multiple times. For example `git draft -aa add-test symbol=foo` +* Support file rename tool. + * https://stackoverflow.com/q/49853177/1062617 + * https://stackoverflow.com/q/6658313/1062617 diff --git a/docs/git-draft.adoc b/docs/git-draft.adoc index 5154750..d32f545 100644 --- a/docs/git-draft.adoc +++ b/docs/git-draft.adoc @@ -18,8 +18,8 @@ IMPORTANT: `git-draft` is WIP. == Synopsis [verse] -git draft [options] [--generate] [--bot BOT] [--edit] [--reset | --no-reset] [--sync] [TEMPLATE [VARIABLE...]] -git draft [options] --finalize [--delete] +git draft [options] [--generate] [--accept... 
| --no-accept] [--bot BOT] [--edit] [--reset | --no-reset] [--sync | --no-sync] [TEMPLATE [VARIABLE...]] +git draft [options] --finalize [--delete] [--sync | --no-sync] git draft [options] --show-drafts [--json] git draft [options] --show-prompts [--json] [PROMPT] git draft [options] --show-templates [--json | [--edit] TEMPLATE] @@ -27,11 +27,18 @@ git draft [options] --show-templates [--json | [--edit] TEMPLATE] == Description -`git-draft` is a git-centric way to edit code using AI. +`git-draft` is a git-centric way to develop using AI. == Options +-a:: +--accept:: +--no-accept:: + Check out generated changes automatically. + Can be repeated. + This may fail if you manually edit files that the bot updates during generation. + -b BOT:: --bot=BOT:: Bot name. @@ -42,15 +49,17 @@ git draft [options] --show-templates [--json | [--edit] TEMPLATE] -e:: --edit:: - Edit. + Enable interactive editing of prompts and templates. + See `--generate` and `--show-templates` for details. -F:: --finalize:: - TODO + Go back to the draft's origin branch with the current working directory. -G:: --generate:: - TODO + Add an AI-generated commit. + If the `--edit` option is set, an interactive editor will be opened with the rendered prompt to allow modification before it is forwarded to the bot. -h:: --help:: @@ -65,22 +74,25 @@ git draft [options] --show-templates [--json | [--edit] TEMPLATE] --reset:: --no-reset:: - TODO + Controls behavior when staged changes are present at the start of a generate command. + If enabled, these changes are automatically reset and combined with other working directory changes. + Otherwise an error is raised. --root:: Repository search root. -D:: --show-drafts:: - TODO + List recently created drafts. -P:: --show-prompts:: - TODO + Lists recently used prompts. -T:: --show-templates:: - TODO + Lists available templates. + With a template name argument, displays the corresponding template's contents or, if the `--edit` option is set, opens an interactive editor. 
-s:: --sync:: diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 216d9f9..018d35f 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -11,7 +11,7 @@ from .bots import load_bot from .common import PROGRAM, Config, UnreachableError, ensure_state_home -from .drafter import Drafter +from .drafter import Accept, Drafter from .editor import open_editor from .prompt import Template, TemplatedPrompt, find_template, templates_table from .store import Store @@ -64,6 +64,12 @@ def callback( add_command("show-prompts", short="P", help="show prompt history") add_command("show-templates", short="T", help="show template information") + parser.add_option( + "-a", + "--accept", + help="apply generated changes", + action="count", + ) parser.add_option( "-b", "--bot", @@ -171,67 +177,69 @@ def main() -> None: # noqa: PLR0912 PLR0915 logging.basicConfig(level=config.log_level, filename=str(log_path)) drafter = Drafter.create(store=Store.persistent(), path=opts.root) - command = getattr(opts, "command", "generate") - if command == "generate": - bot_config = None - if opts.bot: - bot_configs = [c for c in config.bots if c.name == opts.bot] - if len(bot_configs) != 1: - raise ValueError(f"Found {len(bot_configs)} matching bots") - bot_config = bot_configs[0] - elif config.bots: - bot_config = config.bots[0] - bot = load_bot(bot_config) - - prompt: str | TemplatedPrompt - editable = opts.edit - if args: - prompt = TemplatedPrompt.parse(args[0], *args[1:]) - elif opts.edit: - editable = False - prompt = edit( - text=drafter.latest_draft_prompt() or _PROMPT_PLACEHOLDER + match getattr(opts, "command", "generate"): + case "generate": + bot_config = None + if opts.bot: + bot_configs = [c for c in config.bots if c.name == opts.bot] + if len(bot_configs) != 1: + raise ValueError(f"Found {len(bot_configs)} matching bots") + bot_config = bot_configs[0] + elif config.bots: + bot_config = config.bots[0] + bot = load_bot(bot_config) + + prompt: str | 
TemplatedPrompt + editable = opts.edit + if args: + prompt = TemplatedPrompt.parse(args[0], *args[1:]) + elif opts.edit: + editable = False + prompt = edit( + text=drafter.latest_draft_prompt() or _PROMPT_PLACEHOLDER + ) + else: + prompt = sys.stdin.read() + + accept = Accept(opts.accept or 0) + name = drafter.generate_draft( + prompt, + bot, + accept=accept, + bot_name=opts.bot, + prompt_transform=open_editor if editable else None, + tool_visitors=[ToolPrinter()], + reset=config.auto_reset if opts.reset is None else opts.reset, + sync=opts.sync, ) - else: - prompt = sys.stdin.read() - - name = drafter.generate_draft( - prompt, - bot, - bot_name=opts.bot, - prompt_transform=open_editor if editable else None, - tool_visitors=[ToolPrinter()], - reset=config.auto_reset if opts.reset is None else opts.reset, - sync=opts.sync, - ) - print(f"Refined {name}.") - elif command == "finalize": - name = drafter.finalize_draft(delete=opts.delete) - print(f"Finalized {name}.") - elif command == "show-drafts": - table = drafter.history_table(args[0] if args else None) - if table: - print(table.to_json() if opts.json else table) - elif command == "show-prompts": - raise NotImplementedError() # TODO: Implement - elif command == "show-templates": - if args: - name = args[0] - tpl = find_template(name) - if opts.edit: - if tpl: - edit(path=tpl.local_path(), text=tpl.source) + print(f"Generated change in {name}.") + case "finalize": + name = drafter.finalize_draft(delete=opts.delete) + print(f"Finalized {name}.") + case "show-drafts": + table = drafter.history_table(args[0] if args else None) + if table: + print(table.to_json() if opts.json else table) + case "show-prompts": + raise NotImplementedError() # TODO: Implement + case "show-templates": + if args: + name = args[0] + tpl = find_template(name) + if opts.edit: + if tpl: + edit(path=tpl.local_path(), text=tpl.source) + else: + edit(path=Template.local_path_for(name)) else: - edit(path=Template.local_path_for(name)) + if not tpl: 
+ raise ValueError(f"No template named {name!r}") + print(tpl.source) else: - if not tpl: - raise ValueError(f"No template named {name!r}") - print(tpl.source) - else: - table = templates_table() - print(table.to_json() if opts.json else table) - else: - raise UnreachableError() + table = templates_table() + print(table.to_json() if opts.json else table) + case _: + raise UnreachableError() if __name__ == "__main__": diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index c0468b8..6715ded 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -4,18 +4,20 @@ from collections.abc import Callable, Sequence import dataclasses -from datetime import datetime +from datetime import datetime, timedelta +import enum import json import logging from pathlib import PurePosixPath import re from re import Match +import tempfile import textwrap import time import git -from .bots import Bot, Goal +from .bots import Action, Bot, Goal from .common import JSONObject, Table, qualified_class_name, random_id from .prompt import PromptRenderer, TemplatedPrompt from .store import Store, sql @@ -25,6 +27,15 @@ _logger = logging.getLogger(__name__) +class Accept(enum.Enum): + """Valid change accept mode""" + + MANUAL = enum.auto() + CHECKOUT = enum.auto() + FINALIZE = enum.auto() + NO_REGRETS = enum.auto() + + @dataclasses.dataclass(frozen=True) class _Branch: """Draft branch""" @@ -76,12 +87,13 @@ def generate_draft( # noqa: PLR0913 self, prompt: str | TemplatedPrompt, bot: Bot, + accept: Accept = Accept.MANUAL, bot_name: str | None = None, - tool_visitors: Sequence[ToolVisitor] | None = None, prompt_transform: Callable[[str], str] | None = None, reset: bool = False, sync: bool = False, timeout: float | None = None, + tool_visitors: Sequence[ToolVisitor] | None = None, ) -> str: if timeout is not None: raise NotImplementedError() # TODO: Implement @@ -94,73 +106,50 @@ def generate_draft( # noqa: PLR0913 # Ensure that we are on a draft branch. 
branch = _Branch.active(self._repo) if branch: - self._stage_changes(sync) + self._stage_repo(sync) _logger.debug("Reusing active branch %s.", branch) else: branch = self._create_branch(sync) _logger.debug("Created branch %s.", branch) # Handle prompt templating and editing. - if isinstance(prompt, TemplatedPrompt): - template: str | None = prompt.template - renderer = PromptRenderer.for_toolbox(StagingToolbox(self._repo)) - prompt_contents = renderer.render(prompt) - else: - template = None - prompt_contents = prompt - if prompt_transform: - prompt_contents = prompt_transform(prompt_contents) - if not prompt_contents.strip(): - raise ValueError("Aborting: empty prompt") + prompt_contents = self._prepare_prompt(prompt, prompt_transform) with self._store.cursor() as cursor: [(prompt_id,)] = cursor.execute( sql("add-prompt"), { "branch_suffix": branch.suffix, - "template": template, + "template": prompt.template + if isinstance(prompt, TemplatedPrompt) + else None, "contents": prompt_contents, }, ) - # Trigger code generation. - _logger.debug("Running bot... [bot=%s]", bot) operation_recorder = _OperationRecorder() - tool_visitors = [operation_recorder, *list(tool_visitors or [])] - toolbox = StagingToolbox(self._repo, tool_visitors) - start_time = time.perf_counter() - goal = Goal(prompt_contents, timeout) - action = bot.act(goal, toolbox) - end_time = time.perf_counter() - walltime = end_time - start_time - _logger.info("Completed bot action. [action=%s]", action) - - # Generate an appropriate commit and update our database. - toolbox.trim_index() - title = action.title - if not title: - title = _default_title(prompt_contents) - commit = self._repo.index.commit( - f"draft! 
{title}\n\n{prompt_contents}", - skip_hooks=True, + change = self._generate_change( + bot, + Goal(prompt_contents, timeout), + [operation_recorder, *list(tool_visitors or [])], ) with self._store.cursor() as cursor: cursor.execute( sql("add-action"), { - "commit_sha": commit.hexsha, + "commit_sha": change.commit, "prompt_id": prompt_id, "bot_name": bot_name, "bot_class": qualified_class_name(bot.__class__), - "walltime": walltime, - "request_count": action.request_count, - "token_count": action.token_count, + "walltime_seconds": change.walltime.total_seconds(), + "request_count": change.action.request_count, + "token_count": change.action.token_count, }, ) cursor.executemany( sql("add-operation"), [ { - "commit_sha": commit.hexsha, + "commit_sha": change.commit, "tool": o.tool, "reason": o.reason, "details": json.dumps(o.details), @@ -169,10 +158,60 @@ def generate_draft( # noqa: PLR0913 for o in operation_recorder.operations ], ) + _logger.info("Created new change on %s.", branch) - _logger.info("Completed generation for %s.", branch) + delta = change.delta() + if delta and accept.value >= Accept.CHECKOUT.value: + delta.apply() + if accept.value >= Accept.FINALIZE.value: + self.finalize_draft(delete=accept == Accept.NO_REGRETS) return str(branch) + def _prepare_prompt( + self, + prompt: str | TemplatedPrompt, + prompt_transform: Callable[[str], str] | None, + ) -> str: + if isinstance(prompt, TemplatedPrompt): + renderer = PromptRenderer.for_toolbox(StagingToolbox(self._repo)) + contents = renderer.render(prompt) + else: + contents = prompt + if prompt_transform: + contents = prompt_transform(contents) + if not contents.strip(): + raise ValueError("Empty prompt") + return contents + + def _generate_change( + self, + bot: Bot, + goal: Goal, + tool_visitors: Sequence[ToolVisitor], + ) -> _Change: + # Trigger code generation. + _logger.debug("Running bot... 
[bot=%s]", bot) + toolbox = StagingToolbox(self._repo, tool_visitors) + start_time = time.perf_counter() + action = bot.act(goal, toolbox) + end_time = time.perf_counter() + walltime = end_time - start_time + _logger.info("Completed bot action. [action=%s]", action) + + # Generate an appropriate commit. + toolbox.trim_index() + title = action.title + if not title: + title = _default_title(goal.prompt) + commit = self._repo.index.commit( + f"draft! {title}\n\n{goal.prompt}", + skip_hooks=True, + ) + + return _Change( + commit.hexsha, timedelta(seconds=walltime), action, self._repo + ) + def finalize_draft(self, *, delete: bool = False) -> str: branch = _Branch.active(self._repo) if not branch: @@ -200,6 +239,40 @@ def finalize_draft(self, *, delete: bool = False) -> str: _logger.info("Exited %s.", branch) return branch.name + def _create_branch(self, sync: bool) -> _Branch: + if self._repo.head.is_detached: + raise RuntimeError("No currently active branch") + origin_branch = self._repo.active_branch.name + origin_sha = self._repo.commit().hexsha + + self._repo.git.checkout(detach=True) + sync_sha = self._stage_repo(sync) + suffix = _Branch.new_suffix() + + with self._store.cursor() as cursor: + cursor.execute( + sql("add-branch"), + { + "suffix": suffix, + "repo_path": self._repo.working_dir, + "origin_branch": origin_branch, + "origin_sha": origin_sha, + "sync_sha": sync_sha, + }, + ) + + branch = _Branch(suffix) + branch_ref = self._repo.create_head(branch.name) + self._repo.git.checkout(branch_ref) + return branch + + def _stage_repo(self, sync: bool) -> str | None: + self._repo.git.add(all=True) + if not sync or not self._repo.is_dirty(untracked_files=True): + return None + ref = self._repo.index.commit("draft! 
sync") + return ref.hexsha + def history_table(self, branch_name: str | None = None) -> Table: path = self._repo.working_dir branch = _Branch.active(self._repo, branch_name) @@ -233,64 +306,50 @@ def latest_draft_prompt(self) -> str | None: ).fetchone() return result[0] if result else None - def _create_branch(self, sync: bool) -> _Branch: - if self._repo.head.is_detached: - raise RuntimeError("No currently active branch") - origin_branch = self._repo.active_branch.name - origin_sha = self._repo.commit().hexsha - self._repo.git.checkout(detach=True) - sync_sha = self._stage_changes(sync) - suffix = _Branch.new_suffix() +type _CommitSHA = str - with self._store.cursor() as cursor: - cursor.execute( - sql("add-branch"), - { - "suffix": suffix, - "repo_path": self._repo.working_dir, - "origin_branch": origin_branch, - "origin_sha": origin_sha, - "sync_sha": sync_sha, - }, - ) - branch = _Branch(suffix) - branch_ref = self._repo.create_head(branch.name) - self._repo.git.checkout(branch_ref) - return branch +@dataclasses.dataclass(frozen=True) +class _Change: + """A bot-generated draft, may be a no-op""" - def _stage_changes(self, sync: bool) -> str | None: - self._repo.git.add(all=True) - if not sync or not self._repo.is_dirty(untracked_files=True): - return None - ref = self._repo.index.commit("draft! 
sync") - return ref.hexsha + commit: _CommitSHA + walltime: timedelta + action: Action + repo: git.Repo = dataclasses.field(repr=False) - def _untracked(self) -> frozenset[str]: - text = self._repo.git.ls_files(exclude_standard=True, others=True) - return frozenset(text.splitlines()) - - def _delta(self, spec: str) -> _Delta: - changed = list[str]() - deleted = list[str]() - for line in self._repo.git.diff(spec, name_status=True).splitlines(): - state, name = line.split(None, 1) - if state == "D": - deleted.append(name) - else: - changed.append(name) - return _Delta(changed=frozenset(changed), deleted=frozenset(deleted)) + def delta(self) -> _Delta | None: + diff = self.repo.git.diff_tree(self.commit, patch=True) + return _Delta(diff, self.repo) if diff else None @dataclasses.dataclass(frozen=True) class _Delta: - changed: frozenset[str] - deleted: frozenset[str] - - @property - def touched(self) -> frozenset[str]: - return self.changed | self.deleted + """A change's effects, guaranteed non-empty""" + + diff: str + repo: git.Repo = dataclasses.field(repr=False) + + def apply(self) -> None: + # For patch applcation to work as expected (adding conflict markers as + # needed), files in the patch must exist in the index. 
+ self.repo.git.add(all=True) + with tempfile.TemporaryFile() as temp: + temp.write(self.diff.encode("utf8")) + temp.seek(0) + try: + self.repo.git.apply("--3way", "-", istream=temp) + except git.CommandError as exc: + if "with conflicts" in exc.stderr: + raise ConflictError() + raise exc + finally: + self.repo.git.reset() + + +class ConflictError(Exception): + """A change could not be applied cleanly""" class _OperationRecorder(ToolVisitor): diff --git a/src/git_draft/queries/add-action.sql b/src/git_draft/queries/add-action.sql index 2614c9e..990994b 100644 --- a/src/git_draft/queries/add-action.sql +++ b/src/git_draft/queries/add-action.sql @@ -3,7 +3,7 @@ insert into actions ( prompt_id, bot_name, bot_class, - walltime, + walltime_seconds, request_count, token_count) values ( @@ -11,6 +11,6 @@ insert into actions ( :prompt_id, :bot_name, :bot_class, - :walltime, + :walltime_seconds, :request_count, :token_count); diff --git a/src/git_draft/queries/create-tables.sql b/src/git_draft/queries/create-tables.sql index 8c85ae9..09506d3 100644 --- a/src/git_draft/queries/create-tables.sql +++ b/src/git_draft/queries/create-tables.sql @@ -22,7 +22,7 @@ create table if not exists actions ( prompt_id integer not null, bot_name text, bot_class text not null, - walltime real not null, + walltime_seconds real not null, request_count int, token_count int, foreign key (prompt_id) references prompts(id) on delete cascade diff --git a/src/git_draft/queries/list-prompts.sql b/src/git_draft/queries/list-prompts.sql index 3cf7afa..43021b1 100644 --- a/src/git_draft/queries/list-prompts.sql +++ b/src/git_draft/queries/list-prompts.sql @@ -2,7 +2,7 @@ select datetime(min(p.created_at), 'localtime') as created, coalesce(min(template), '-') as template, coalesce(min(a.bot_name), '-') as bot, - coalesce(round(sum(a.walltime), 1), 0) as walltime, + coalesce(round(sum(a.walltime_seconds), 1), 0) as walltime, count(o.id) as ops from prompts as p join branches as b on p.branch_suffix = 
b.suffix diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index ca7a4e4..ffea16d 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -127,6 +127,56 @@ def act(self, _goal: Goal, _toolbox: Toolbox) -> Action: assert len(self._commits()) == 2 # init, prompt assert not self._commit_files("HEAD") + def test_generate_accept_checkout(self) -> None: + self._write("p1", "A") + self._write("p2", "B") + + class CustomBot(Bot): + def act(self, _goal: Goal, toolbox: Toolbox) -> Action: + toolbox.write_file(PurePosixPath("p1"), "C") + toolbox.write_file(PurePosixPath("p3"), "D") + return Action() + + self._drafter.generate_draft( + "hello", CustomBot(), accept=sut.Accept.CHECKOUT, sync=True + ) + assert self._read("p1") == "C" + assert self._read("p2") == "B" + assert self._read("p3") == "D" + + def test_generate_accept_checkout_conflict(self) -> None: + self._write("p1", "A") + + class CustomBot(Bot): + def act(self, _goal: Goal, toolbox: Toolbox) -> Action: + toolbox.write_file(PurePosixPath("p1"), "B") + toolbox.write_file(PurePosixPath("p2"), "C") + return Action() + + with pytest.raises(sut.ConflictError): + self._drafter.generate_draft( + "hello", CustomBot(), accept=sut.Accept.CHECKOUT + ) + assert """<<<<<<< ours\nA""" in (self._read("p1") or "") + assert self._read("p2") == "C" + + def test_generate_accept_finalize(self) -> None: + self._write("p1", "A") + + class CustomBot(Bot): + def act(self, _goal: Goal, toolbox: Toolbox) -> Action: + toolbox.write_file(PurePosixPath("p2"), "B") + return Action() + + self._drafter.generate_draft( + "hello", + CustomBot(), + accept=sut.Accept.FINALIZE, + ) + assert self._read("p1") == "A" + assert self._read("p2") == "B" + assert self._repo.active_branch.name == "main" + def test_delete_unknown_file(self) -> None: class CustomBot(Bot): def act(self, _goal: Goal, toolbox: Toolbox) -> Action: