diff --git a/README.md b/README.md index 6db7a00..b62fa9a 100644 --- a/README.md +++ b/README.md @@ -11,9 +11,5 @@ ## Ideas -* Change prompt CLI inputs to `[PROMPT] [--] [ARGS]`. If `PROMPT` does not - contain any spaces or `ARGS` (or `--`) is present, it will be interpreted as a - template name. Otherwise an inline prompt. -* Add `--generate` timeout option. * Add a compatibility OpenAI bot version which does not use threads, so that it can be used with tools only. Gemini only supports the latter. diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 4efafb8..49f2dba 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -68,12 +68,6 @@ def callback(_option, _opt, _value, parser) -> None: help="delete draft after finalizing or discarding", action="store_true", ) - parser.add_option( - "-p", - "--prompt", - dest="prompt", - help="inline prompt", - ) parser.add_option( "-r", "--reset", @@ -88,9 +82,9 @@ def callback(_option, _opt, _value, parser) -> None: ) parser.add_option( "-t", - "--template", - dest="template", - help="prompt template", + "--timeout", + dest="timeout", + help="bot generation timeout", ) return parser @@ -119,6 +113,7 @@ def main() -> None: if command == "generate": if not config.bots: raise ValueError("No bots configured") + if opts.bot: bot_configs = [c for c in config.bots if c.name == opts.bot] if len(bot_configs) != 1: @@ -128,11 +123,11 @@ def main() -> None: bot_config = config.bots[0] bot = load_bot(bot_config) - prompt = opts.prompt - if not prompt: - if opts.template: - prompt = TemplatedPrompt.parse(opts.template, *args) - elif sys.stdin.isatty(): + prompt: str | TemplatedPrompt + if args: + prompt = TemplatedPrompt.parse(args[0], *args[1:]) + else: + if sys.stdin.isatty(): prompt = open_editor("Enter your prompt here...") else: prompt = sys.stdin.read() diff --git a/src/git_draft/bots/__init__.py b/src/git_draft/bots/__init__.py index 09bb66e..dea33a4 100644 --- a/src/git_draft/bots/__init__.py +++ b/src/git_draft/bots/__init__.py @@ -7,11 +7,12 @@ import sys from ..common import BotConfig -from .common import Action, Bot, Operation, OperationHook, Toolbox +from .common import Action, Bot, Goal, Operation, OperationHook, Toolbox __all__ = [ "Action", "Bot", + "Goal", "Operation", "OperationHook", "Toolbox", diff --git a/src/git_draft/bots/common.py b/src/git_draft/bots/common.py index eafd2f5..a6f30f4 100644 --- a/src/git_draft/bots/common.py +++ b/src/git_draft/bots/common.py @@ -91,6 +91,12 @@ class Operation: type OperationHook = Callable[[Operation], None] +@dataclasses.dataclass(frozen=True) +class Goal: + prompt: str + timeout: float | None + + @dataclasses.dataclass(frozen=True) class Action: title: str | None = None @@ -107,5 +113,5 @@ def state_folder_path(cls, ensure_exists=False) -> Path: path.mkdir(parents=True, exist_ok=True) return path - def act(self, prompt: str, toolbox: Toolbox) -> Action: + def act(self, goal: Goal, toolbox: Toolbox) -> Action: raise NotImplementedError() diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index c1eba33..64c1696 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -6,7 +6,7 @@ import textwrap from typing import Any, Mapping, Self, Sequence, override -from .common import Action, Bot, Toolbox +from .common import Action, Bot, Goal, Toolbox _logger = logging.getLogger(__name__) @@ -130,22 +130,22 @@ def create(cls, client: openai.OpenAI) -> Self: try: with open(path) as f: assistant_id = f.read() - except FileNotFoundError: + client.beta.assistants.update(assistant_id, **config) + except (FileNotFoundError, openai.NotFoundError): assistant = client.beta.assistants.create(**config) assistant_id = assistant.id with open(path, "w") as f: f.write(assistant_id) - else: - client.beta.assistants.update(assistant_id, **config) return cls(client, assistant_id) - def act(self, prompt: str, toolbox: Toolbox) -> Action: + def act(self, goal: Goal, toolbox: Toolbox) -> Action: + # TODO: Use timeout. thread = self._client.beta.threads.create() self._client.beta.threads.messages.create( thread_id=thread.id, role="user", - content=prompt, + content=goal.prompt, ) with self._client.beta.threads.runs.stream( diff --git a/src/git_draft/drafter.py b/src/git_draft/drafter.py index ad29375..233e19c 100644 --- a/src/git_draft/drafter.py +++ b/src/git_draft/drafter.py @@ -11,7 +11,7 @@ import time from typing import Match, Sequence, override -from .bots import Bot, OperationHook, Toolbox +from .bots import Bot, Goal, OperationHook, Toolbox from .common import random_id from .prompt import PromptRenderer, TemplatedPrompt from .store import Store, sql @@ -85,7 +85,7 @@ def _write(self, path: PurePosixPath, contents: str) -> None: f"{mode},{sha},{path}", add=True, cacheinfo=True ) - def update_index(self) -> None: + def trim_index(self) -> None: diff = self._repo.git.diff(name_only=True, cached=True) untouched = [ path @@ -125,9 +125,10 @@ def generate_draft( self, prompt: str | TemplatedPrompt, bot: Bot, - checkout=False, - reset=False, - sync=False, + checkout: bool = False, + reset: bool = False, + sync: bool = False, + timeout: float | None = None, ) -> None: if isinstance(prompt, str) and not prompt.strip(): raise ValueError("Empty prompt") @@ -159,11 +160,12 @@ def generate_draft( ) start_time = time.perf_counter() + goal = Goal(prompt_contents, timeout) toolbox = _Toolbox(self._repo, self._operation_hook) - action = bot.act(prompt_contents, toolbox) + action = bot.act(goal, toolbox) end_time = time.perf_counter() - toolbox.update_index() + toolbox.trim_index() title = action.title if not title: title = _default_title(prompt_contents) diff --git a/src/git_draft/prompts/add-docstrings.jinja b/src/git_draft/prompts/add-docstrings.jinja index ff2648d..bc6b8b2 100644 --- a/src/git_draft/prompts/add-docstrings.jinja +++ b/src/git_draft/prompts/add-docstrings.jinja @@ -1,4 +1,6 @@ -{% if path is defined %} +{% if symbol is defined and path is defined %} +Add a docstring to {{ symbol }} defined in {{ path }}. +{% elif path is defined %} Add docstrings to all public functions and classes in {{ path }}. {% else %} Add docstrings to all public functions and classes in this repository. diff --git a/tests/git_draft/drafter_test.py b/tests/git_draft/drafter_test.py index 059b044..3ab2e96 100644 --- a/tests/git_draft/drafter_test.py +++ b/tests/git_draft/drafter_test.py @@ -3,7 +3,7 @@ import pytest from typing import Sequence -from git_draft.bots import Action, Bot, Toolbox +from git_draft.bots import Action, Bot, Goal, Toolbox import git_draft.drafter as sut from git_draft.prompt import TemplatedPrompt from git_draft.store import Store @@ -46,8 +46,8 @@ def test_write_file(self, repo: git.Repo) -> None: class FakeBot(Bot): - def act(self, prompt: str, toolbox: Toolbox) -> Action: - toolbox.write_file(PurePosixPath("PROMPT"), prompt) + def act(self, goal: Goal, toolbox: Toolbox) -> Action: + toolbox.write_file(PurePosixPath("PROMPT"), goal.prompt) return Action() @@ -86,7 +86,7 @@ def test_generate_stages_then_resets_worktree(self) -> None: self._write("p2", "b") class CustomBot(Bot): - def act(self, prompt: str, toolbox: Toolbox) -> Action: + def act(self, _goal: Goal, toolbox: Toolbox) -> Action: assert toolbox.read_file(PurePosixPath("p1")) == "a" toolbox.write_file(PurePosixPath("p2"), "B") toolbox.write_file(PurePosixPath("p3"), "C")