diff --git a/README.md b/README.md index 8ed6b77..7d604ed 100644 --- a/README.md +++ b/README.md @@ -22,5 +22,3 @@ pipx install git-draft[openai] ## Next steps * MCP bot. -* Mechanism for reporting feedback from a bot, and possibly allowing user to - interactively respond. diff --git a/docs/git-draft.adoc b/docs/git-draft.adoc index f2126b3..91f598a 100644 --- a/docs/git-draft.adoc +++ b/docs/git-draft.adoc @@ -34,53 +34,65 @@ git draft [options] --templates [--json | [--edit] TEMPLATE] -a:: --accept:: --no-accept:: - Merge generated changes automatically. - Can be repeated. +Merge generated changes automatically, updating the working directory. +This option can be repeated up to 3 times, with the following behavior at each level: ++ +* `-a`: Merge changes conservatively, flagging any conflicts for manual user resolution. +* `-aa`: Merge changes aggressively, resolving all conflicts in favor of the generated change. +* `-aaa`: Merge changes aggressively (same as `-aa`) then run `--quit` automatically. ++ +By default, changes are not merged - keeping the working directory untouched. +A different default can be set in the configuration file. +When doing so, the `--no-accept` flag can be used to disable merging at CLI invocation time. -b BOT:: --bot=BOT:: - Bot name. +Bot name. +Defaults to the first bot defined in the configuration. -e:: --edit:: - Enable interactive editing of draft prompts and templates. - See `--generate` and `--show-templates` for details. +Enable interactive editing of draft prompts and templates. +See `--new` and `--templates` for details. -h:: --help:: - Show help message and exit. +Show help message and exit. -j:: --json:: Use JSON output. ---log:: +--log-path:: Show log path and exit. -N:: --new:: - Create an AI-generated draft. - If the `--edit` option is set, an interactive editor will be open with the rendered prompt to allow modification before it is forwarded to the bot. +Create an AI-generated draft. +If the `--edit` option is set, an interactive editor will be opened with the rendered prompt to allow modification before it is forwarded to the bot. -Q:: --quit:: - Go back to the draft's origin branch with the current working directory. +Go back to the draft's origin branch, keeping the working directory's current state. +This will delete the draft branch and its upstream. +Generated commits remain available via `ref/drafts`. --root:: - Repository search root. +Repository search root. +Defaults to the current working directory. -T:: --templates:: - With no argument, lists available templates. - With an template name argument, displays the corresponding template's contents or, if the `--edit` option is set, opens an interactive editor. +With no argument, lists available templates. +With an template name argument, displays the corresponding template's contents or, if the `--edit` option is set, opens an interactive editor. --version:: - Show version and exit. +Show version and exit. == Examples -The workhorse command is `git draft --generate` which leverages AI to edit our code. +The workhorse command is `git draft --new` which leverages AI to edit code. A prompt can be specified as standard input, for example `echo "Add a test for compute_offset in chart.py" | git draft --generate`. If no prompt is specified and stdin is a TTY, `$EDITOR` will be opened to enter the prompt. @@ -93,14 +105,6 @@ This will check out the branch used when creating the draft, adding the final st Note that you can come back to an existing draft anytime (by checking its branch out), but you will not be able to apply it if its origin branch has moved since the draft was created. -* Generate: create a new draft to the current folio, or create a new folio if none exists. -* Finalize - ** Apply: include changes into origin branch. - ** Discard: abandon folio. - ** Save: return to original branch. Q: how to load after? -* List templates - - o Foo (main) o Sync (drafts/123/pub) diff --git a/poetry.lock b/poetry.lock index 55dc47a..5f25dea 100644 --- a/poetry.lock +++ b/poetry.lock @@ -151,6 +151,18 @@ files = [ {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, ] +[[package]] +name = "docopt-ng" +version = "0.9.0" +description = "Jazzband-maintained fork of docopt, the humane command line arguments parser." +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "docopt_ng-0.9.0-py3-none-any.whl", hash = "sha256:bfe4c8b03f9fca424c24ee0b4ffa84bf7391cb18c29ce0f6a8227a3b01b81ff9"}, + {file = "docopt_ng-0.9.0.tar.gz", hash = "sha256:91c6da10b5bb6f2e9e25345829fb8278c78af019f6fc40887ad49b060483b1d7"}, +] + [[package]] name = "h11" version = "0.14.0" @@ -922,4 +934,4 @@ openai = ["openai"] [metadata] lock-version = "2.1" python-versions = ">=3.12,<4" -content-hash = "0786ebdc5d020335f23530a7897e3b293c67be2b5d7708b2dd0ec7f6e9bc4cc7" +content-hash = "3e30f690d81bad2a4f36bfc042d3b45bcd83eeed65167c2d6b232083b30370f1" diff --git a/pyproject.toml b/pyproject.toml index 8d80e1e..dd04d43 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ readme = "README.md" dynamic = ["version"] requires-python = ">=3.12" dependencies = [ + "docopt-ng (>=0.9,<0.10)", "jinja2 (>=3.1.5,<4)", "prettytable (>=3.15.1,<4)", "xdg-base-dirs (>=6.0.2,<7)", diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 1deb01b..01abbc9 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -20,7 +20,12 @@ from .drafter import Drafter, DraftMergeStrategy from .editor import open_editor from .git import Repo -from .prompt import Template, TemplatedPrompt, find_template, templates_table +from .prompt import ( + PromptMetadata, + TemplatedPrompt, + find_prompt_metadata, + templates_table, +) from .store import Store @@ -178,43 +183,45 @@ def main() -> None: # noqa: PLR0912 PLR0915 bot = load_bot(bot_config) prompt: str | TemplatedPrompt - editable = opts.edit if args: - prompt = TemplatedPrompt.parse(args[0], *args[1:]) - elif opts.edit: - editable = False + if args[0] == "-": + prompt = sys.stdin.read() + else: + prompt = TemplatedPrompt.public(args[0], args[1:]) + editable = opts.edit + else: prompt = edit( text=drafter.latest_draft_prompt() or _PROMPT_PLACEHOLDER ).strip() - if not prompt or prompt == _PROMPT_PLACEHOLDER: - raise ValueError("Aborting: empty or placeholder prompt") - else: - if sys.stdin.isatty(): - print("Reading prompt from stdin... (press C-D when done)") - prompt = sys.stdin.read() + if prompt.strip() == _PROMPT_PLACEHOLDER: + prompt = "" # Enable consistent error message + editable = False # We already edited the prompt accept = Accept(opts.accept or 0) - _ = drafter.generate_draft( + drafter.generate_draft( prompt, bot, prompt_transform=open_editor if editable else None, merge_strategy=accept.merge_strategy(), ) + if accept == Accept.MERGE_THEN_QUIT: + # TODO: Refuse to quit on pending question? + drafter.quit_folio() case "quit": drafter.quit_folio() case "templates": if args: name = args[0] - tpl = find_template(name) + meta = find_prompt_metadata(name) if opts.edit: - if tpl: - edit(path=tpl.local_path(), text=tpl.source) + if meta: + edit(path=meta.local_path(), text=meta.source()) else: - edit(path=Template.local_path_for(name)) + edit(path=PromptMetadata.local_path_for(name)) else: - if not tpl: + if not meta: raise ValueError(f"No template named {name!r}") - print(tpl.source) + print(meta.source()) else: table = templates_table() print(table.to_json() if opts.json else table) diff --git a/src/git_draft/bots/common.py b/src/git_draft/bots/common.py index d9a9df8..f9c0bb2 100644 --- a/src/git_draft/bots/common.py +++ b/src/git_draft/bots/common.py @@ -28,6 +28,7 @@ class Action: title: str | None = None request_count: int | None = None token_count: int | None = None + question: str | None = None def increment_request_count(self, n: int = 1, init: bool = False) -> None: self._increment("request_count", n, init) diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index 1b9921d..8701460 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -20,7 +20,7 @@ import openai -from ..common import JSONObject, config_string, reindent +from ..common import JSONObject, UnreachableError, config_string, reindent from .common import Action, Bot, Goal, Toolbox @@ -45,7 +45,7 @@ def threads_bot( model: str = _DEFAULT_MODEL, ) -> Bot: """Beta bot, uses assistant threads with function calling""" - return _ThreadsBot.create(_new_client(api_key, base_url), model) + return _ThreadsBot(_new_client(api_key, base_url), model) def _new_client(api_key: str | None, base_url: str | None) -> openai.OpenAI: @@ -85,6 +85,21 @@ def _param( def params(self) -> Sequence[openai.types.chat.ChatCompletionToolParam]: return [ + self._param( + name="ask_user", + description=""" + Request more information from the user + + Call this function if and only if you are unable to achieve + your task with the information you already have. + """, + inputs={ + "question": { + "type": "string", + "description": "Question to be answered by the user", + }, + }, + ), self._param( name="list_files", description="List all available files", @@ -152,17 +167,18 @@ def params(self) -> Sequence[openai.types.chat.ChatCompletionToolParam]: read the content of the relevant ones, and save the changes you suggest. You should stop when and ONLY WHEN all the files you need to change have - been updated. If you stop for any reason before completing your task, - explain why by updating a REASON file before stopping. For example if you - are missing some information or noticed something inconsistent with the - instructions, say so there. DO NOT STOP without updating at least this - file. + been updated. If you do not have enough information to complete your task, + use the provided tool to request it from the user, then stop. """ class _ToolHandler[V]: def __init__(self, toolbox: Toolbox) -> None: self._toolbox = toolbox + self.question: str | None = None + + def _on_ask_user(self) -> V: + raise NotImplementedError() def _on_read_file(self, path: PurePosixPath, contents: str | None) -> V: raise NotImplementedError() @@ -185,6 +201,10 @@ def handle_function(self, function: Any) -> V: inputs = json.loads(function.arguments) _logger.info("Requested function: %s", function) match function.name: + case "ask_user": + assert not self.question + self.question = inputs["question"] + return self._on_ask_user() case "read_file": path = PurePosixPath(inputs["path"]) return self._on_read_file(path, self._toolbox.read_file(path)) @@ -202,10 +222,11 @@ def handle_function(self, function: Any) -> V: dst_path = PurePosixPath(inputs["dst_path"]) self._toolbox.rename_file(src_path, dst_path) return self._on_rename_file(src_path, dst_path) - case _ as name: - assert name == "list_files" and not inputs + case "list_files": paths = self._toolbox.list_files() return self._on_list_files(paths) + case _ as name: + raise UnreachableError(f"Unexpected function: {name}") class _CompletionsBot(Bot): @@ -243,10 +264,16 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action: if done: break - return Action(request_count=request_count) + return Action( + request_count=request_count, + question=tool_handler.question, + ) class _CompletionsToolHandler(_ToolHandler[str | None]): + def _on_ask_user(self) -> None: + return None + def _on_read_file(self, path: PurePosixPath, contents: str | None) -> str: if contents is None: return f"`{path}` does not exist." @@ -269,32 +296,31 @@ def _on_list_files(self, paths: Sequence[PurePosixPath]) -> str: class _ThreadsBot(Bot): - def __init__(self, client: openai.OpenAI, assistant_id: str) -> None: + def __init__(self, client: openai.OpenAI, model: str) -> None: self._client = client - self._assistant_id = assistant_id + self._model = model - @classmethod - def create(cls, client: openai.OpenAI, model: str) -> Self: - assistant_kwargs: JSONObject = dict( - model=model, + def _load_assistant_id(self) -> str: + kwargs: JSONObject = dict( + model=self._model, instructions=reindent(_INSTRUCTIONS), tools=_ToolsFactory(strict=True).params(), ) - - path = cls.state_folder_path(ensure_exists=True) / "ASSISTANT_ID" + path = self.state_folder_path(ensure_exists=True) / "ASSISTANT_ID" try: with open(path) as f: assistant_id = f.read() - client.beta.assistants.update(assistant_id, **assistant_kwargs) + self._client.beta.assistants.update(assistant_id, **kwargs) except (FileNotFoundError, openai.NotFoundError): - assistant = client.beta.assistants.create(**assistant_kwargs) + assistant = self._client.beta.assistants.create(**kwargs) assistant_id = assistant.id with open(path, "w") as f: f.write(assistant_id) - - return cls(client, assistant_id) + return assistant_id def act(self, goal: Goal, toolbox: Toolbox) -> Action: + assistant_id = self._load_assistant_id() + thread = self._client.beta.threads.create() self._client.beta.threads.messages.create( thread_id=thread.id, @@ -307,7 +333,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action: action = Action(request_count=0, token_count=0) with self._client.beta.threads.runs.stream( thread_id=thread.id, - assistant_id=self._assistant_id, + assistant_id=assistant_id, event_handler=_EventHandler(self._client, toolbox, action), ) as stream: stream.until_done() @@ -353,6 +379,9 @@ def _handle_action(self, _run_id: str, data: Any) -> None: for tool in data.required_action.submit_tool_outputs.tool_calls: handler = _ThreadToolHandler(self._toolbox, tool.id) tool_outputs.append(handler.handle_function(tool.function)) + if handler.question: + assert not self._action.question + self._action.question = handler.question run = self.current_run assert run, "No ongoing run" @@ -378,6 +407,9 @@ def __init__(self, toolbox: Toolbox, call_id: str) -> None: def _wrap(self, output: str) -> _ToolOutput: return _ToolOutput(tool_call_id=self._call_id, output=output) + def _on_ask_user(self) -> _ToolOutput: + return self._wrap("OK") + def _on_read_file( self, _path: PurePosixPath, contents: str | None ) -> _ToolOutput: diff --git a/src/git_draft/common.py b/src/git_draft/common.py index 1436dd7..7b0564a 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -85,16 +85,24 @@ class UnreachableError(RuntimeError): """Indicates unreachable code was unexpectedly executed""" -def reindent(s: str, width: int = 0) -> str: +def reindent(s: str, prefix: str = "", width: int = 0) -> str: """Reindents text by dedenting and optionally wrapping paragraphs""" paragraphs = ( " ".join(textwrap.dedent("\n".join(g)).splitlines()) for b, g in itertools.groupby(s.splitlines(), bool) if b ) - return "\n\n".join( + if width and prefix: + width -= len(prefix) + 1 + assert width > 0 + wrapped = "\n\n".join( textwrap.fill(p, width=width) if width else p for p in paragraphs ) + if not prefix: + return wrapped + return "\n".join( + f"{prefix} {t}" if t else prefix for t in wrapped.splitlines() + ) def qualified_class_name(cls: type) -> str: @@ -163,6 +171,10 @@ def static() -> Feedback: class FeedbackSpinner: """Operation feedback tracker""" + @contextlib.contextmanager + def hidden(self) -> Iterator[None]: + yield None + def update(self, text: str, **tags) -> None: # pragma: no cover raise NotImplementedError() @@ -198,6 +210,11 @@ class _DynamicFeedbackSpinner(FeedbackSpinner): def __init__(self, yaspin: yaspin.core.Yaspin) -> None: self.yaspin = yaspin + @contextlib.contextmanager + def hidden(self) -> Iterator[None]: + with self.yaspin.hidden(): + yield + def update(self, text: str, **tags) -> None: self.yaspin.text = _tagged(text, **tags) diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index 13eeade..b713e5f 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -14,9 +14,9 @@ from typing import Literal from .bots import Action, Bot, Goal -from .common import Feedback, JSONObject, qualified_class_name +from .common import Feedback, JSONObject, qualified_class_name, reindent from .git import SHA, Repo -from .prompt import PromptRenderer, TemplatedPrompt +from .prompt import TemplatedPrompt from .store import Store, sql from .toolbox import RepoToolbox, ToolVisitor @@ -31,6 +31,7 @@ class Draft: folio: Folio seqno: int is_noop: bool + has_question: bool walltime: timedelta token_count: int | None @@ -110,13 +111,14 @@ def generate_draft( # Handle prompt templating and editing. We do this first in case # this fails, to avoid creating unnecessary branches. toolbox, dirty = RepoToolbox.for_working_dir(self._repo) - prompt_contents = self._prepare_prompt( - prompt, prompt_transform, toolbox - ) + with spinner.hidden(): + prompt_contents = self._prepare_prompt( + prompt, + prompt_transform, + toolbox, + ) template_name = ( - prompt.template - if isinstance(prompt, TemplatedPrompt) - else None + prompt.name if isinstance(prompt, TemplatedPrompt) else None ) spinner.update( "Prepared prompt.", @@ -148,6 +150,8 @@ def generate_draft( [operation_recorder], ), ) + if change.action.question: + self._feedback.report("Requested feedback.") spinner.update( "Completed bot run.", runtime=round(change.walltime.total_seconds(), 1), @@ -159,6 +163,7 @@ def generate_draft( folio=folio, seqno=seqno, is_noop=change.is_noop, + has_question=change.action.question is not None, walltime=change.walltime, token_count=change.action.token_count, ) @@ -183,19 +188,19 @@ def generate_draft( cursor.execute( sql("add-action"), { - "commit_sha": commit_sha, "prompt_id": prompt_id, "bot_class": qualified_class_name(bot.__class__), "walltime_seconds": change.walltime.total_seconds(), "request_count": change.action.request_count, "token_count": change.action.token_count, + "question": change.action.question, }, ) cursor.executemany( sql("add-operation"), [ { - "commit_sha": commit_sha, + "prompt_id": prompt_id, "tool": o.tool, "reason": o.reason, "details": json.dumps(o.details), @@ -328,14 +333,13 @@ def _prepare_prompt( toolbox: RepoToolbox, ) -> str: if isinstance(prompt, TemplatedPrompt): - renderer = PromptRenderer.for_toolbox(toolbox) - contents = renderer.render(prompt) + contents = prompt.render(toolbox) else: contents = prompt if prompt_transform: contents = prompt_transform(contents) if not contents.strip(): - raise ValueError("Empty prompt") + raise ValueError("Missing or empty prompt") return contents def _generate_change( @@ -406,7 +410,12 @@ def latest_draft_prompt(self) -> str | None: "folio_id": folio.id, }, ).fetchone() - return result[0] if result else None + if not result: + return None + prompt, question = result + if question: + prompt = "\n\n".join([prompt, reindent(question, prefix="> ")]) + return prompt @dataclasses.dataclass(frozen=True) diff --git a/src/git_draft/prompt.py b/src/git_draft/prompt.py index 3e4c1fc..036ba8e 100644 --- a/src/git_draft/prompt.py +++ b/src/git_draft/prompt.py @@ -2,129 +2,107 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Mapping, Sequence import dataclasses import enum -import itertools +import functools import os from pathlib import Path -from typing import Self +import re +from typing import Self, TypedDict, cast +import docopt import jinja2 -import jinja2.meta -from .bots import Toolbox from .common import Config, Table, package_root +from .toolbox import NoopToolbox, Toolbox _extension = "jinja" +type PromptName = str + + @dataclasses.dataclass(frozen=True) class TemplatedPrompt: - """A parametrized prompt""" + """A parameterized prompt""" - template: str - context: Mapping[str, str] + name: PromptName + args: tuple[str, ...] = () @classmethod - def parse(cls, name: str, *args: str) -> Self: - """Parse arguments into a TemplatedPrompt - - Args: - name: The name of the template. - *args: Additional arguments for context, expected in 'key=value' - format. - - """ - return cls(name, dict(e.split("=", 1) for e in args)) + def public(cls, name: PromptName, args: Sequence[str]) -> Self: + _check_public_template_name(name) + return cls(name, tuple(args)) + def render(self, toolbox: Toolbox) -> str: + prompt = _load_prompt(_jinja_environment(), self.name, toolbox) + return prompt.render(self.args) -class _GlobalVariable(enum.StrEnum): - REPO = enum.auto() +_public_template_name_pattern = re.compile(r"\.?[a-z-]+") -class PromptRenderer: - """Renderer for prompt templates using Jinja2""" - def __init__(self, env: jinja2.Environment) -> None: - self._environment = env +def _check_public_template_name(name: str) -> None: + if not _public_template_name_pattern.fullmatch(name): + raise ValueError(f"Invalid template name: {name}") - @classmethod - def for_toolbox(cls, toolbox: Toolbox) -> Self: - env = _jinja_environment() - env.globals[_GlobalVariable.REPO] = { - "file_paths": [str(p) for p in toolbox.list_files()], - } - return cls(env) - - def render(self, prompt: TemplatedPrompt) -> str: - tpl = self._environment.get_template(f"{prompt.template}.{_extension}") - try: - return tpl.render(prompt.context) - except jinja2.UndefinedError as err: - raise ValueError(f"Unable to render template: {err}") - -def templates_table() -> Table: - env = _jinja_environment() - table = Table.empty() - table.data.field_names = ["name", "local", "preamble"] - for rel_path in env.list_templates(extensions=[_extension]): - if any(p.startswith(".") for p in rel_path.split(os.sep)): - continue - tpl = _load_template(rel_path, env) - local = "y" if tpl.is_local() else "n" - table.data.add_row([tpl.name, local, tpl.preamble or "-"]) - return table +def _jinja_environment(*, include_local: bool = True) -> jinja2.Environment: + folders = [_PromptFolder.BUILTIN] + if include_local: + folders.append(_PromptFolder.LOCAL) + return jinja2.Environment( + auto_reload=False, + autoescape=False, + keep_trailing_newline=True, + loader=jinja2.FileSystemLoader([f.path for f in folders]), + trim_blocks=True, + lstrip_blocks=True, + undefined=jinja2.StrictUndefined, + ) class _PromptFolder(enum.Enum): - BUILTIN = package_root / "prompts" - LOCAL = Config.folder_path() / "prompts" + BUILTIN = package_root + LOCAL = Config.folder_path() @property def path(self) -> Path: - return self.value - + return self.value / "prompts" -def _extract_preamble(source: str, env: jinja2.Environment) -> str | None: - """Returns the template's leading comment's contents, if preset""" - tokens = list(itertools.islice(env.lex(source), 3)) - if len(tokens) == 3 and tokens[1][1] == "comment": - return tokens[1][2].strip() - return None +@functools.cache +def _load_layouts() -> Mapping[str, str]: + root = _PromptFolder.BUILTIN.path + return { + p.stem: str(p.relative_to(root)) + for p in (root / ".LAYOUTS").glob(f"*.{_extension}") + } -def _load_template(rel_path: str, env: jinja2.Environment) -> Template: - assert env.loader, "No loader in environment" - source, abs_path, _uptodate = env.loader.get_source(env, rel_path) - assert abs_path, "Missing template path" - preamble = _extract_preamble(source, env) - return Template(Path(rel_path), Path(abs_path), source, preamble) - -def find_template(name: str) -> Template | None: - env = _jinja_environment() - try: - return _load_template(f"{name}.{_extension}", env) - except jinja2.TemplateNotFound: - return None +class _Context(TypedDict): + prompt: Mapping[str, str] + toolbox: Toolbox @dataclasses.dataclass(frozen=True) -class Template: - """An available template""" +class PromptMetadata: + """Metadata about an available template""" rel_path: Path abs_path: Path - source: str - preamble: str | None + description: str | None = None @property def name(self) -> str: return str(self.rel_path.parent / self.rel_path.stem) + def source(self) -> str: + with self.abs_path.open() as reader: + return reader.read() + def is_local(self) -> bool: return not self.abs_path.is_relative_to(_PromptFolder.BUILTIN.path) @@ -133,26 +111,112 @@ def local_path(self) -> Path: return self.abs_path return _PromptFolder.LOCAL.path / self.rel_path - def extract_variables(self, env: jinja2.Environment) -> frozenset[str]: - """Returns the names of variables directly used in the template - - The returned set does not include transitive variables (used in - included templates) or variables populated automatically (e.g. `repo`). - """ - # https://stackoverflow.com/a/48685520 - ast = env.parse(self.source) - return frozenset(jinja2.meta.find_undeclared_variables(ast)) - @staticmethod def local_path_for(name: str) -> Path: + _check_public_template_name(name) return _PromptFolder.LOCAL.path / Path(f"{name}.{_extension}") -def _jinja_environment() -> jinja2.Environment: - return jinja2.Environment( - auto_reload=False, - autoescape=False, - keep_trailing_newline=True, - loader=jinja2.FileSystemLoader([p.path for p in _PromptFolder]), - undefined=jinja2.StrictUndefined, - ) +class _Prompt: + """Instantiated dynamic prompt""" + + def __init__(self, metadata: PromptMetadata) -> None: + self.metadata = metadata + + def render(self, args: Sequence[str]) -> str: + raise NotImplementedError() + + +class _StandalonePrompt(_Prompt): + """Prompt without a layout""" + + def __init__(self, metadata: PromptMetadata, rendered: str) -> None: + super().__init__(metadata) + self._rendered = rendered + + def render(self, args: Sequence[str]) -> str: + if args: + raise RuntimeError("Arguments not supported for this template") + return self._rendered + + +class _DocoptPrompt(_Prompt): + """Prompt which supports options via docopt""" + + def __init__( + self, + template: jinja2.Template, + doc: str, + rel_path: Path, + context: _Context, + ) -> None: + # We could validate the doc string here, but don't since docopt doesn't + # make it easy with inline print and sys.exit calls. + super().__init__( + PromptMetadata( + rel_path, + _template_path(template), + doc.partition("\n")[0], + ) + ) + self._template = template + self._doc = doc + self._context = context + + def render(self, args: Sequence[str]) -> str: + try: + opts = docopt.docopt(self._doc, list(args)) + except docopt.DocoptExit as exc: + raise ValueError(f"Invalid template arguments: {args}") from exc + return self._template.render({**self._context, "opts": opts}) + + +def _template_path(template: jinja2.Template) -> Path: + """Returns the template's absolute path""" + assert template.filename + path = Path(template.filename) + assert path.is_absolute() + return path + + +def _load_prompt( + env: jinja2.Environment, name: PromptName, toolbox: Toolbox +) -> _Prompt: + rel_path = Path(f"{name}.{_extension}") + assert env.loader, "No loader in environment" + template = env.loader.load(env, str(rel_path)) + context: _Context = dict(prompt=_load_layouts(), toolbox=toolbox) + try: + module = template.make_module(vars=cast(dict, context)) + except jinja2.TemplateError as exc: + raise ValueError(f"Template {name} is invalid: {exc}") from exc + match getattr(module, "layout", None): + case "docopt": + return _DocoptPrompt(template, str(module), rel_path, context) + case _: + metadata = PromptMetadata(rel_path, _template_path(template)) + return _StandalonePrompt(metadata, str(module)) + + +def find_prompt_metadata(name: PromptName) -> PromptMetadata | None: + try: + prompt = _load_prompt(_jinja_environment(), name, NoopToolbox()) + except jinja2.TemplateNotFound: + return None + return prompt.metadata + + +def templates_table(*, include_local: bool = True) -> Table: + env = _jinja_environment(include_local=include_local) + toolbox = NoopToolbox() + table = Table.empty() + table.data.field_names = ["name", "local", "description"] + for rel_path in env.list_templates(extensions=[_extension]): + if any(p.startswith(".") for p in rel_path.split(os.sep)): + continue + name, _ext = os.path.splitext(rel_path) + prompt = _load_prompt(env, name, toolbox) + metadata = prompt.metadata + local = "y" if metadata.is_local() else "n" + table.data.add_row([name, local, metadata.description or ""]) + return table diff --git a/src/git_draft/prompts/.LAYOUTS/docopt.jinja b/src/git_draft/prompts/.LAYOUTS/docopt.jinja new file mode 100644 index 0000000..21ee7d1 --- /dev/null +++ b/src/git_draft/prompts/.LAYOUTS/docopt.jinja @@ -0,0 +1,7 @@ +{% set layout = "docopt" %} +{% import ".MACROS.jinja" as macros with context %} +{% if opts is defined %} +{% block body required %}{% endblock %} +{% else %} +{% block doc required %}{% endblock %} +{% endif %} diff --git a/src/git_draft/prompts/.MACROS.jinja b/src/git_draft/prompts/.MACROS.jinja new file mode 100644 index 0000000..b67517c --- /dev/null +++ b/src/git_draft/prompts/.MACROS.jinja @@ -0,0 +1,13 @@ +{% macro file_list() %} +{% set paths = toolbox.list_files() %} +{% if paths %} +For reference, here is the list of all currently available files in the +repository: + +{% for path in paths %} + * {{ path }} +{% endfor %} +{% else %} +This repository does not contain any files yet. +{% endif %} +{% endmacro %} diff --git a/src/git_draft/prompts/add-docstrings.jinja b/src/git_draft/prompts/add-docstrings.jinja index 47b0497..b102755 100644 --- a/src/git_draft/prompts/add-docstrings.jinja +++ b/src/git_draft/prompts/add-docstrings.jinja @@ -1,9 +1,23 @@ -{# symbol path? #} - -{% if symbol is defined and path is defined %} -Add a docstring to {{ symbol }} defined in {{ path }}. -{% elif path is defined %} -Add docstrings to all public functions and classes in {{ path }}. +{% extends prompt.docopt %} + +{% block doc %} +Add docstrings to one or more public symbols + +Usage: + add-docstrings [--path PATH] [--symbol SYMBOL] + +Options: + --path PATH Path to (Python) file where docstrings should be added. If + omitted, all source files in this repository will be used. + --symbol SYMBOL Name of symbol to be documented. If omitted, all symbols + will be considered. +{% endblock %} + +{% block body %} +{% if opts.symbol and opts.path %} +Add a docstring to {{ opts.symbol }} defined in {{ opts.path }}. +{% elif opts.path %} +Add docstrings to all public functions and classes in {{ opts.path }}. {% else %} Add docstrings to all public functions and classes in this repository. {% endif %} @@ -38,4 +52,5 @@ Additionally, the first paragraph of each docstring should fit in a single line and not include a period at the end. It should be a brief summary of the symbol's functionality. -{% include "includes/.file-list.jinja" %} +{{ macros.file_list() }} +{% endblock %} diff --git a/src/git_draft/prompts/add-test.jinja b/src/git_draft/prompts/add-test.jinja index bcacbe6..cea0371 100644 --- a/src/git_draft/prompts/add-test.jinja +++ b/src/git_draft/prompts/add-test.jinja @@ -1,13 +1,31 @@ -{# symbol src_path? test_path? #} +{% extends prompt.docopt %} -Add tests for {{ symbol }}. Follow existing conventions when implementing the +{% block doc %} +Add tests for a symbol + +Usage: + add-test --symbol NAME [--src-path PATH] [--test-path PATH] + +Options: + --symbol NAME Name of the symbol to be tested. + --src-path PATH Path to source file where the symbol is defined. + --test-path PATH Path to file where the test should be added. +{% endblock %} + +{% block body %} +Add tests for {{ opts.symbol }}. Follow existing conventions when implementing the tests. For example, if the surrounding code uses fixtures, do so as well. +{% if opts.src_path %} -{% if src_path is defined %}The symbol to be tested is defined in {{ src_path }}.{% endif %} +The symbol to be tested is defined in {{ opts.src_path }}. +{% endif %} +{% if opts.test_path %} -{% if test_path is defined %}The tests should be added to {{ test_path }}.{% endif %} +The tests should be added to {{ opts.test_path }}. +{% endif %} Do not stop until you have added at least one test. You should add separate tests to cover the normal execution path, and to cover any exceptional cases. -{% include "includes/.file-list.jinja" %} +{{ macros.file_list() }} +{% endblock %} diff --git a/src/git_draft/prompts/includes/.file-list.jinja b/src/git_draft/prompts/includes/.file-list.jinja deleted file mode 100644 index 5caf30f..0000000 --- a/src/git_draft/prompts/includes/.file-list.jinja +++ /dev/null @@ -1,6 +0,0 @@ -For reference, here is the list of all currently available files in the -repository: - -{% for path in repo.file_paths -%} - * {{ path }} -{% endfor %} diff --git a/src/git_draft/queries/add-action.sql b/src/git_draft/queries/add-action.sql index c438cd2..e054afd 100644 --- a/src/git_draft/queries/add-action.sql +++ b/src/git_draft/queries/add-action.sql @@ -1,14 +1,14 @@ insert into actions ( - commit_sha, prompt_id, bot_class, walltime_seconds, request_count, - token_count) + token_count, + question) values ( - :commit_sha, :prompt_id, :bot_class, :walltime_seconds, :request_count, - :token_count); + :token_count, + :question); diff --git a/src/git_draft/queries/add-operation.sql b/src/git_draft/queries/add-operation.sql index 627ffbc..7d68cc3 100644 --- a/src/git_draft/queries/add-operation.sql +++ b/src/git_draft/queries/add-operation.sql @@ -1,2 +1,2 @@ -insert into operations (action_commit_sha, tool, reason, details, started_at) - values (:commit_sha, :tool, :reason, :details, :started_at) +insert into operations (prompt_id, tool, reason, details, started_at) + values (:prompt_id, :tool, :reason, :details, :started_at) diff --git a/src/git_draft/queries/create-tables.sql b/src/git_draft/queries/create-tables.sql index 2e1e60c..622746e 100644 --- a/src/git_draft/queries/create-tables.sql +++ b/src/git_draft/queries/create-tables.sql @@ -20,22 +20,22 @@ create table if not exists prompts ( create unique index if not exists prompts_by_folio_seqno on prompts (folio_id, seqno); create table if not exists actions ( - commit_sha text primary key, + prompt_id integer primary key, created_at timestamp default current_timestamp, - prompt_id integer not null, bot_class text not null, walltime_seconds real not null, request_count int, token_count int, + question text, foreign key (prompt_id) references prompts (id) on delete cascade ) without rowid; create table if not exists operations ( id integer primary key, - action_commit_sha text not null, + prompt_id integer not null, tool text not null, reason text, details text not null, started_at timestamp not null, - foreign key (action_commit_sha) references actions (commit_sha) on delete cascade + foreign key (prompt_id) references actions (prompt_id) on delete cascade ); diff --git a/src/git_draft/queries/get-latest-folio-prompt.sql b/src/git_draft/queries/get-latest-folio-prompt.sql index a1e5a49..82149b1 100644 --- a/src/git_draft/queries/get-latest-folio-prompt.sql +++ b/src/git_draft/queries/get-latest-folio-prompt.sql @@ -1,6 +1,7 @@ -select p.contents +select p.contents, a.question from prompts as p join folios as f on p.folio_id = f.id + left join actions as a on p.id = a.prompt_id where f.id = :folio_id order by p.id desc limit 1; diff --git a/src/git_draft/queries/list-folio-prompts.sql b/src/git_draft/queries/list-folio-prompts.sql index 52d8910..1a65510 100644 --- a/src/git_draft/queries/list-folio-prompts.sql +++ b/src/git_draft/queries/list-folio-prompts.sql @@ -7,7 +7,7 @@ select from prompts as p join folios as f on p.folio_id = f.id left join actions as a on p.id = a.prompt_id - left join operations as o on a.commit_sha = o.action_commit_sha + left join operations as o on a.prompt_id = o.prompt_id where f.id = :folio_id group by p.id order by created desc; diff --git a/src/git_draft/store.py b/src/git_draft/store.py index e2de64e..26f8785 100644 --- a/src/git_draft/store.py +++ b/src/git_draft/store.py @@ -19,7 +19,7 @@ class Store: """Lightweight sqlite wrapper""" - _name = "v2.sqlite3" + _name = "v3.sqlite3" def __init__(self, conn: sqlite3.Connection) -> None: self._connection = conn diff --git a/src/git_draft/toolbox.py b/src/git_draft/toolbox.py index 632ecb0..04cc7e8 100644 --- a/src/git_draft/toolbox.py +++ b/src/git_draft/toolbox.py @@ -131,6 +131,26 @@ def on_rename_file( ) -> None: ... # pragma: no cover +class NoopToolbox(Toolbox): + """No-op read-only toolbox""" + + @override + def _list(self) -> Sequence[PurePosixPath]: + return [] + + @override + def _read(self, _path: PurePosixPath) -> str: + raise RuntimeError() + + @override + def _write(self, _path: PurePosixPath, _contents: str) -> None: + raise RuntimeError() + + @override + def _delete(self, _path: PurePosixPath) -> None: + raise RuntimeError() + + class RepoToolbox(Toolbox): """Git-repo backed toolbox implementation diff --git a/tests/git_draft/common_test.py b/tests/git_draft/common_test.py index 0dec8ef..f29c6f1 100644 --- a/tests/git_draft/common_test.py +++ b/tests/git_draft/common_test.py @@ -62,13 +62,13 @@ def test_evar(self, monkeypatch) -> None: @pytest.mark.parametrize( - "text,width,want", + "text,width,prefix,want", [ - ("", 10, ""), - ("abc", 5, "abc"), - ("ab", 0, "ab"), - ("\nabc def", 4, "abc\ndef"), - (" abc\n def ", 10, "abc def"), + ("", 10, "", ""), + ("abc", 5, "", "abc"), + ("ab", 0, "", "ab"), + ("\nabc def", 4, "", "abc\ndef"), + (" abc\n def ", 10, "", "abc def"), ( """ This is a fun paragraph @@ -77,9 +77,21 @@ def test_evar(self, monkeypatch) -> None: And another. """, 60, + "", "This is a fun paragraph which continues.\n\nAnd another.", ), + ( + """ + A quoted + something. + + And very long follow up. + """, + 24, + ">", + "> A quoted something.\n>\n> And very long follow\n> up.", + ), ], ) -def test_reindent(text, width, want): - assert sut.reindent(text, width) == want +def test_reindent(text, width, prefix, want): + assert sut.reindent(text, prefix=prefix, width=width) == want diff --git a/tests/git_draft/prompt_test.py b/tests/git_draft/prompt_test.py index 90685ac..048a7bd 100644 --- a/tests/git_draft/prompt_test.py +++ b/tests/git_draft/prompt_test.py @@ -4,56 +4,42 @@ from git_draft.toolbox import RepoToolbox -class TestPromptRenderer: +class TestCheckPublicTemplateName: + @pytest.mark.parametrize("name", ["ok", ".hidden", "composite-name"]) + def test_ok(self, name: str) -> None: + sut._check_public_template_name(name) + + @pytest.mark.parametrize("name", ["", "ABC", ".PROMPT", ".with.ext"]) + def test_raises(self, name: str) -> None: + with pytest.raises(ValueError): + sut._check_public_template_name(name) + + +class TestTemplatedPrompt: @pytest.fixture(autouse=True) def setup(self, repo) -> None: - toolbox = RepoToolbox(repo, "HEAD") - self._renderer = sut.PromptRenderer.for_toolbox(toolbox) + self._toolbox = RepoToolbox(repo, "HEAD") def test_ok(self) -> None: - prompt = sut.TemplatedPrompt.parse("add-test", "symbol=foo") - rendered = self._renderer.render(prompt) + prompt = sut.TemplatedPrompt("add-test", ("--symbol=foo",)) + rendered = prompt.render(self._toolbox) assert "foo" in rendered def test_missing_variable(self) -> None: - prompt = sut.TemplatedPrompt.parse("add-test") + prompt = sut.TemplatedPrompt("add-test") with pytest.raises(ValueError): - self._renderer.render(prompt) - + prompt.render(self._toolbox) -class TestTemplate: - @pytest.fixture(autouse=True) - def setup(self) -> None: - self._env = sut._jinja_environment() - def test_fields(self): - tpl = sut._load_template("includes/.file-list.jinja", self._env) - assert not tpl.is_local() - assert tpl.name == "includes/.file-list" - assert tpl.local_path() != tpl.abs_path - - def test_preamble_ok(self): - tpl = sut._load_template("add-test.jinja", self._env) - assert "symbol" in tpl.preamble - - def test_preamble_missing(self): - tpl = sut._load_template("includes/.file-list.jinja", self._env) - assert tpl.preamble is None - - def test_extract_variables(self): - tpl = sut._load_template("add-test.jinja", self._env) - variables = tpl.extract_variables(self._env) - assert "symbol" in variables - assert "repo" not in variables - - def test_find_ok(self) -> None: - tpl = sut.find_template("add-test") - assert tpl - assert "symbol" in tpl.source +class TestFindPromptMetadata: + def test_ok(self) -> None: + metadata = sut.find_prompt_metadata("add-test") + assert metadata + assert "symbol" in (metadata.description or "") - def test_find_missing(self) -> None: - assert sut.find_template("foo") is None + def test_missing(self) -> None: + assert sut.find_prompt_metadata("foo") is None def test_templates_table() -> None: - assert sut.templates_table() + assert sut.templates_table(include_local=False)