Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,5 @@

## Ideas

* Change prompt CLI inputs to `[PROMPT] [--] [ARGS]`. If `PROMPT` does not
contain any spaces or `ARGS` (or `--`) is present, it will be interpreted as a
template name. Otherwise an inline prompt.
* Add `--generate` timeout option.
* Add a compatibility OpenAI bot version which does not use threads, so that it
can be used with tools only. Gemini only supports the latter.
23 changes: 9 additions & 14 deletions src/git_draft/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,6 @@ def callback(_option, _opt, _value, parser) -> None:
help="delete draft after finalizing or discarding",
action="store_true",
)
parser.add_option(
"-p",
"--prompt",
dest="prompt",
help="inline prompt",
)
parser.add_option(
"-r",
"--reset",
Expand All @@ -88,9 +82,9 @@ def callback(_option, _opt, _value, parser) -> None:
)
parser.add_option(
"-t",
"--template",
dest="template",
help="prompt template",
"--timeout",
dest="timeout",
help="bot generation timeout",
)

return parser
Expand Down Expand Up @@ -119,6 +113,7 @@ def main() -> None:
if command == "generate":
if not config.bots:
raise ValueError("No bots configured")

if opts.bot:
bot_configs = [c for c in config.bots if c.name == opts.bot]
if len(bot_configs) != 1:
Expand All @@ -128,11 +123,11 @@ def main() -> None:
bot_config = config.bots[0]
bot = load_bot(bot_config)

prompt = opts.prompt
if not prompt:
if opts.template:
prompt = TemplatedPrompt.parse(opts.template, *args)
elif sys.stdin.isatty():
prompt: str | TemplatedPrompt
if args:
prompt = TemplatedPrompt.parse(args[0], *args[1:])
else:
if sys.stdin.isatty():
prompt = open_editor("Enter your prompt here...")
else:
prompt = sys.stdin.read()
Expand Down
3 changes: 2 additions & 1 deletion src/git_draft/bots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
import sys

from ..common import BotConfig
from .common import Action, Bot, Operation, OperationHook, Toolbox
from .common import Action, Bot, Goal, Operation, OperationHook, Toolbox

__all__ = [
"Action",
"Bot",
"Goal",
"Operation",
"OperationHook",
"Toolbox",
Expand Down
8 changes: 7 additions & 1 deletion src/git_draft/bots/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ class Operation:
type OperationHook = Callable[[Operation], None]


@dataclasses.dataclass(frozen=True)
class Goal:
prompt: str
timeout: float | None


@dataclasses.dataclass(frozen=True)
class Action:
title: str | None = None
Expand All @@ -107,5 +113,5 @@ def state_folder_path(cls, ensure_exists=False) -> Path:
path.mkdir(parents=True, exist_ok=True)
return path

def act(self, prompt: str, toolbox: Toolbox) -> Action:
def act(self, goal: Goal, toolbox: Toolbox) -> Action:
raise NotImplementedError()
12 changes: 6 additions & 6 deletions src/git_draft/bots/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import textwrap
from typing import Any, Mapping, Self, Sequence, override

from .common import Action, Bot, Toolbox
from .common import Action, Bot, Goal, Toolbox


_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -130,22 +130,22 @@ def create(cls, client: openai.OpenAI) -> Self:
try:
with open(path) as f:
assistant_id = f.read()
except FileNotFoundError:
client.beta.assistants.update(assistant_id, **config)
except (FileNotFoundError, openai.NotFoundError):
assistant = client.beta.assistants.create(**config)
assistant_id = assistant.id
with open(path, "w") as f:
f.write(assistant_id)
else:
client.beta.assistants.update(assistant_id, **config)
return cls(client, assistant_id)

def act(self, prompt: str, toolbox: Toolbox) -> Action:
def act(self, goal: Goal, toolbox: Toolbox) -> Action:
# TODO: Use timeout.
thread = self._client.beta.threads.create()

self._client.beta.threads.messages.create(
thread_id=thread.id,
role="user",
content=prompt,
content=goal.prompt,
)

with self._client.beta.threads.runs.stream(
Expand Down
16 changes: 9 additions & 7 deletions src/git_draft/drafter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import time
from typing import Match, Sequence, override

from .bots import Bot, OperationHook, Toolbox
from .bots import Bot, Goal, OperationHook, Toolbox
from .common import random_id
from .prompt import PromptRenderer, TemplatedPrompt
from .store import Store, sql
Expand Down Expand Up @@ -85,7 +85,7 @@ def _write(self, path: PurePosixPath, contents: str) -> None:
f"{mode},{sha},{path}", add=True, cacheinfo=True
)

def update_index(self) -> None:
def trim_index(self) -> None:
diff = self._repo.git.diff(name_only=True, cached=True)
untouched = [
path
Expand Down Expand Up @@ -125,9 +125,10 @@ def generate_draft(
self,
prompt: str | TemplatedPrompt,
bot: Bot,
checkout=False,
reset=False,
sync=False,
checkout: bool = False,
reset: bool = False,
sync: bool = False,
timeout: float | None = None,
) -> None:
if isinstance(prompt, str) and not prompt.strip():
raise ValueError("Empty prompt")
Expand Down Expand Up @@ -159,11 +160,12 @@ def generate_draft(
)

start_time = time.perf_counter()
goal = Goal(prompt_contents, timeout)
toolbox = _Toolbox(self._repo, self._operation_hook)
action = bot.act(prompt_contents, toolbox)
action = bot.act(goal, toolbox)
end_time = time.perf_counter()

toolbox.update_index()
toolbox.trim_index()
title = action.title
if not title:
title = _default_title(prompt_contents)
Expand Down
4 changes: 3 additions & 1 deletion src/git_draft/prompts/add-docstrings.jinja
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
{% if path is defined %}
{% if symbol is defined and path is defined %}
Add a docstring to {{ symbol }} defined in {{ path }}.
{% elif path is defined %}
Add docstrings to all public functions and classes in {{ path }}.
{% else %}
Add docstrings to all public functions and classes in this repository.
Expand Down
8 changes: 4 additions & 4 deletions tests/git_draft/drafter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from typing import Sequence

from git_draft.bots import Action, Bot, Toolbox
from git_draft.bots import Action, Bot, Goal, Toolbox
import git_draft.drafter as sut
from git_draft.prompt import TemplatedPrompt
from git_draft.store import Store
Expand Down Expand Up @@ -46,8 +46,8 @@ def test_write_file(self, repo: git.Repo) -> None:


class FakeBot(Bot):
def act(self, prompt: str, toolbox: Toolbox) -> Action:
toolbox.write_file(PurePosixPath("PROMPT"), prompt)
def act(self, goal: Goal, toolbox: Toolbox) -> Action:
toolbox.write_file(PurePosixPath("PROMPT"), goal.prompt)
return Action()


Expand Down Expand Up @@ -86,7 +86,7 @@ def test_generate_stages_then_resets_worktree(self) -> None:
self._write("p2", "b")

class CustomBot(Bot):
def act(self, prompt: str, toolbox: Toolbox) -> Action:
def act(self, _goal: Goal, toolbox: Toolbox) -> Action:
assert toolbox.read_file(PurePosixPath("p1")) == "a"
toolbox.write_file(PurePosixPath("p2"), "B")
toolbox.write_file(PurePosixPath("p3"), "C")
Expand Down