diff --git a/README.md b/README.md index b62fa9a..d76fc91 100644 --- a/README.md +++ b/README.md @@ -5,11 +5,7 @@ ## Highlights -* Concurrent editing. Continue editing while the assistant runs, without any +* Concurrent edits. Continue editing while the assistant runs, without any risks of interference. - - -## Ideas - -* Add a compatibility OpenAI bot version which does not use threads, so that it - can be used with tools only. Gemini only supports the latter. +* Customizable prompt templates +* Extensible bot API diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 49f2dba..d943b44 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -111,15 +111,13 @@ def main() -> None: ) command = getattr(opts, "command", "generate") if command == "generate": - if not config.bots: - raise ValueError("No bots configured") - + bot_config = None if opts.bot: bot_configs = [c for c in config.bots if c.name == opts.bot] if len(bot_configs) != 1: raise ValueError(f"Found {len(bot_configs)} matching bots") bot_config = bot_configs[0] - else: + elif config.bots: bot_config = config.bots[0] bot = load_bot(bot_config) diff --git a/src/git_draft/bots/__init__.py b/src/git_draft/bots/__init__.py index dea33a4..86a2cd6 100644 --- a/src/git_draft/bots/__init__.py +++ b/src/git_draft/bots/__init__.py @@ -4,9 +4,10 @@ """ import importlib +import os import sys -from ..common import BotConfig +from ..common import BotConfig, reindent from .common import Action, Bot, Goal, Operation, OperationHook, Toolbox __all__ = [ @@ -19,20 +20,11 @@ ] -def load_bot(config: BotConfig) -> Bot: - """Load and return a Bot instance using the provided configuration. +def load_bot(config: BotConfig | None) -> Bot: + """Load and return a Bot instance using the provided configuration""" + if not config: + return _default_bot() - If a pythonpath is specified in the config and not already present in - sys.path, it is added. 
The function expects the config.factory in the - format 'module:symbol' or 'symbol'. If only 'symbol' is provided, the - current module is used. - - Args: - config: BotConfig object containing bot configuration details. - - Raises: - NotImplementedError: If the specified factory cannot be found. - """ if config.pythonpath and config.pythonpath not in sys.path: sys.path.insert(0, config.pythonpath) @@ -44,7 +36,36 @@ def load_bot(config: BotConfig) -> Bot: factory = getattr(module, symbol, None) if not factory: - raise NotImplementedError(f"Unknown bot factory: {factory}") + raise NotImplementedError(f"Unknown bot factory: {config.factory}") kwargs = config.config or {} return factory(**kwargs) + + +def _default_bot() -> Bot: + if not os.environ.get("OPENAI_API_KEY"): + raise RuntimeError( + reindent( + """ + The default bot implementation requires an OpenAI API key. + Please specify one via the `$OPENAI_API_KEY` environment + variable or enable a different bot in your configuration. + """ + ) + ) + + try: + from .openai import threads_bot + + except ImportError: + raise RuntimeError( + reindent( + """ + The default bot implementation requires the `openai` Python + package. Please install it or specify a different bot in + your configuration. + """ + ) + ) + else: + return threads_bot() diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index 64c1696..31494b7 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -1,141 +1,239 @@ -import dataclasses +"""OpenAI API-backed bots + +They can be used with services other than OpenAI as long as they implement a +sufficient subset of the API. For example the `completions_bot` only requires +tools support. 
+ +See the following links for more resources: + +* https://platform.openai.com/docs/assistants/tools/function-calling +* https://platform.openai.com/docs/assistants/deep-dive#runs-and-run-steps +* https://platform.openai.com/docs/api-reference/assistants-streaming/events +* https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py +""" + import json import logging import openai +import os from pathlib import PurePosixPath -import textwrap -from typing import Any, Mapping, Self, Sequence, override +from typing import Any, Mapping, Self, Sequence, TypedDict, override +from ..common import JSONObject, reindent from .common import Action, Bot, Goal, Toolbox _logger = logging.getLogger(__name__) +_DEFAULT_MODEL = "gpt-4o" + + +def completions_bot( + api_key: str | None = None, + base_url: str | None = None, + model: str = _DEFAULT_MODEL, +) -> Bot: + """Compatibility-mode bot, uses completions with function calling""" + if api_key and api_key.startswith("$"): + api_key = os.environ[api_key[1:]] + client = openai.OpenAI(api_key=api_key, base_url=base_url) + return _CompletionsBot(client, model) + + def threads_bot( - api_key: str | None = None, base_url: str | None = None + api_key: str | None = None, + base_url: str | None = None, + model: str = _DEFAULT_MODEL, ) -> Bot: + """Beta bot, uses assistant threads with function calling""" client = openai.OpenAI(api_key=api_key, base_url=base_url) - return _ThreadsBot.create(client) + return _ThreadsBot.create(client, model) + + +class _ToolsFactory: + def __init__(self, strict: bool) -> None: + self._strict = strict + + def _param( + self, + name: str, + description: str, + inputs: Mapping[str, Any] | None = None, + required_inputs: Sequence[str] | None = None, + ) -> openai.types.beta.FunctionToolParam: + param: openai.types.beta.FunctionToolParam = { + "type": "function", + "function": { + "name": name, + "description": reindent(description), + "parameters": { + "type": "object", + 
"additionalProperties": False, + "properties": inputs or {}, + "required": list(inputs.keys()) if inputs else [], + }, + }, + } + if self._strict: + param["function"]["strict"] = True + return param + + def params(self) -> Sequence[openai.types.chat.ChatCompletionToolParam]: + return [ + self._param( + name="list_files", + description="List all available files", + ), + self._param( + name="read_file", + description="Get a file's contents", + inputs={ + "path": { + "type": "string", + "description": "Path of the file to be read", + }, + }, + ), + self._param( + name="write_file", + description=""" + Set a file's contents + + The file will be created if it does not already exist. + """, + inputs={ + "path": { + "type": "string", + "description": "Path of the file to be updated", + }, + "contents": { + "type": "string", + "description": "New contents of the file", + }, + }, + ), + ] # https://aider.chat/docs/more-info.html # https://github.com/Aider-AI/aider/blob/main/aider/prompts.py -_INSTRUCTIONS = """\ +_INSTRUCTIONS = """ You are an expert software engineer, who writes correct and concise code. - Use the provided functions to find the filesyou need to answer the query, + Use the provided functions to find the files you need to answer the query, read the content of the relevant ones, and save the changes you suggest. - When writing a file, include a summary description of the changes you have - made. + + You should stop when and ONLY WHEN all the files you need to change have + been updated. 
""" -def _function_tool_param( - name: str, - description: str, - inputs: Mapping[str, Any] | None = None, - required_inputs: Sequence[str] | None = None, -) -> openai.types.beta.FunctionToolParam: - return { - "type": "function", - "function": { - "name": name, - "description": textwrap.dedent(description), - "parameters": { - "type": "object", - "additionalProperties": False, - "properties": inputs or {}, - "required": list(inputs.keys()) if inputs else [], - }, - "strict": True, - }, - } - - -_tools = [ - _function_tool_param( - name="list_files", - description="List all available files", - ), - _function_tool_param( - name="read_file", - description="Get a file's contents", - inputs={ - "path": { - "type": "string", - "description": "Path of the file to be read", - }, - }, - ), - _function_tool_param( - name="write_file", - description="""\ - Set a file's contents - - The file will be created if it does not already exist. - """, - inputs={ - "path": { - "type": "string", - "description": "Path of the file to be updated", - }, - "contents": { - "type": "string", - "description": "New contents of the file", - }, - "change_description": { - "type": "string", - "description": """\ - Brief description of the changes performed on this file - """, - }, - }, - ), -] +class _ToolHandler[V]: + def __init__(self, toolbox: Toolbox) -> None: + self._toolbox = toolbox + def _on_read_file(self, path: PurePosixPath, contents: str) -> V: + raise NotImplementedError() + + def _on_write_file(self, path: PurePosixPath) -> V: + raise NotImplementedError() + + def _on_list_files(self, paths: Sequence[PurePosixPath]) -> V: + raise NotImplementedError() + + def handle_function(self, function: Any) -> V: + name = function.name + inputs = json.loads(function.arguments) + _logger.info("Requested function: %s", function) + if name == "read_file": + path = PurePosixPath(inputs["path"]) + return self._on_read_file(path, self._toolbox.read_file(path)) + elif name == "write_file": + path = 
PurePosixPath(inputs["path"]) + contents = inputs["contents"] + self._toolbox.write_file(path, contents) + return self._on_write_file(path) + else: + assert name == "list_files" and not inputs + paths = self._toolbox.list_files() + return self._on_list_files(paths) + + +class _CompletionsBot(Bot): + def __init__(self, client: openai.OpenAI, model: str) -> None: + self._client = client + self._model = model -@dataclasses.dataclass(frozen=True) -class _AssistantConfig: - instructions: str - model: str - tools: Sequence[openai.types.beta.AssistantToolParam] + def act(self, goal: Goal, toolbox: Toolbox) -> Action: + tools = _ToolsFactory(strict=False).params() + tool_handler = _CompletionsToolHandler(toolbox) + + messages: list[openai.types.chat.ChatCompletionMessageParam] = [ + {"role": "system", "content": reindent(_INSTRUCTIONS)}, + {"role": "user", "content": goal.prompt}, + ] + + while True: + response = self._client.chat.completions.create( + model=self._model, + messages=messages, + tools=tools, + tool_choice="required", + ) + assert len(response.choices) == 1 + + done = True + calls = response.choices[0].message.tool_calls + for call in calls or []: + output = tool_handler.handle_function(call.function) + if output is not None: + done = False + messages.append({"role": "user", "content": output}) + if done: + break + + return Action() -_assistant_config = _AssistantConfig( - instructions=_INSTRUCTIONS, - model="gpt-4o", - tools=_tools, -) +class _CompletionsToolHandler(_ToolHandler[str | None]): + def _on_read_file(self, path: PurePosixPath, contents: str) -> str: + return ( + f"Here are the contents of `{path}`:\n\n```\n{contents}\n```\n" "" + ) + def _on_write_file(self, path: PurePosixPath) -> None: + return None -class _ThreadsBot(Bot): - """An OpenAI-backed bot + def _on_list_files(self, paths: Sequence[PurePosixPath]) -> str: + joined = "\n".join(f"* {p}" for p in paths) + return f"Here are the available files: {joined}" - See the following links for 
resources: - * https://platform.openai.com/docs/assistants/tools/function-calling - * https://platform.openai.com/docs/assistants/deep-dive#runs-and-run-steps - * https://platform.openai.com/docs/api-reference/assistants-streaming/events - * https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py - """ +class _ThreadsBot(Bot): def __init__(self, client: openai.OpenAI, assistant_id: str) -> None: self._client = client self._assistant_id = assistant_id @classmethod - def create(cls, client: openai.OpenAI) -> Self: + def create(cls, client: openai.OpenAI, model: str) -> Self: + assistant_kwargs: JSONObject = dict( + model=model, + instructions=reindent(_INSTRUCTIONS), + tools=_ToolsFactory(strict=True).params(), + ) + path = cls.state_folder_path(ensure_exists=True) / "ASSISTANT_ID" - config = dataclasses.asdict(_assistant_config) try: with open(path) as f: assistant_id = f.read() - client.beta.assistants.update(assistant_id, **config) + client.beta.assistants.update(assistant_id, **assistant_kwargs) except (FileNotFoundError, openai.NotFoundError): - assistant = client.beta.assistants.create(**config) + assistant = client.beta.assistants.create(**assistant_kwargs) assistant_id = assistant.id with open(path, "w") as f: f.write(assistant_id) + return cls(client, assistant_id) def act(self, goal: Goal, toolbox: Toolbox) -> Action: @@ -178,21 +276,8 @@ def on_event(self, event: Any) -> None: def _handle_action(self, run_id: str, data: Any) -> None: tool_outputs = list[Any]() for tool in data.required_action.submit_tool_outputs.tool_calls: - name = tool.function.name - inputs = json.loads(tool.function.arguments) - _logger.info("Requested tool: %s", tool) - if name == "read_file": - path = PurePosixPath(inputs["path"]) - output = self._toolbox.read_file(path) - elif name == "write_file": - path = PurePosixPath(inputs["path"]) - contents = inputs["contents"] - self._toolbox.write_file(path, contents) - output = "OK" - elif name == 
"list_files": - assert not inputs - output = "\n".join(str(p) for p in self._toolbox.list_files()) - tool_outputs.append({"tool_call_id": tool.id, "output": output}) + handler = _ThreadToolHandler(self._toolbox, tool.id) + tool_outputs.append(handler.handle_function(tool.function)) run = self.current_run assert run, "No ongoing run" @@ -203,3 +288,26 @@ def _handle_action(self, run_id: str, data: Any) -> None: event_handler=self.clone(), ) as stream: stream.until_done() + + +class _ToolOutput(TypedDict): + tool_call_id: str + output: str + + +class _ThreadToolHandler(_ToolHandler[_ToolOutput]): + def __init__(self, toolbox: Toolbox, call_id: str) -> None: + super().__init__(toolbox) + self._call_id = call_id + + def _wrap(self, output: str) -> _ToolOutput: + return _ToolOutput(tool_call_id=self._call_id, output=output) + + def _on_read_file(self, path: PurePosixPath, contents: str) -> _ToolOutput: + return self._wrap(contents) + + def _on_write_file(self, path: PurePosixPath) -> _ToolOutput: + return self._wrap("OK") + + def _on_list_files(self, paths: Sequence[PurePosixPath]) -> _ToolOutput: + return self._wrap("\n".join((str(p) for p in paths))) diff --git a/src/git_draft/common.py b/src/git_draft/common.py index 52accf2..33df3ce 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -3,10 +3,12 @@ from __future__ import annotations import dataclasses +import itertools import logging from pathlib import Path import random import string +import textwrap import tomllib from typing import Any, Mapping, Self, Sequence import xdg_base_dirs @@ -69,8 +71,21 @@ class BotConfig: def random_id(n: int) -> str: + """Generates a random length n string of lowercase letters and digits""" return "".join(_random.choices(_alphabet, k=n)) class UnreachableError(RuntimeError): - pass + """Indicates unreachable code was unexpectedly executed""" + + +def reindent(s: str, width=0) -> str: + """Reindents text by dedenting and optionally wrapping paragraphs""" + paragraphs 
= ( + " ".join(textwrap.dedent("\n".join(g)).splitlines()) + for b, g in itertools.groupby(s.splitlines(), bool) + if b + ) + return "\n\n".join( + textwrap.fill(p, width=width) if width else p for p in paragraphs + ) diff --git a/tests/git_draft/bots/__init___test.py b/tests/git_draft/bots/__init___test.py index 6473beb..f9749eb 100644 --- a/tests/git_draft/bots/__init___test.py +++ b/tests/git_draft/bots/__init___test.py @@ -2,11 +2,11 @@ import sys import pytest -from git_draft.bots import Bot, load_bot +import git_draft.bots as sut from git_draft.common import BotConfig -class FakeBot(Bot): +class FakeBot(sut.Bot): pass @@ -19,10 +19,15 @@ def import_module(name): monkeypatch.setattr(importlib, "import_module", import_module) config = BotConfig(factory="fake_module:FakeBot") - bot = load_bot(config) + bot = sut.load_bot(config) assert isinstance(bot, FakeBot) def test_non_existing_factory(self) -> None: config = BotConfig("git_draft:unknown_factory") with pytest.raises(NotImplementedError): - load_bot(config) + sut.load_bot(config) + + def test_default_no_key(self, monkeypatch) -> None: + monkeypatch.setenv("OPENAI_API_KEY", "") + with pytest.raises(RuntimeError): + sut.load_bot(None) diff --git a/tests/git_draft/common_test.py b/tests/git_draft/common_test.py index c5e73d3..042dc7c 100644 --- a/tests/git_draft/common_test.py +++ b/tests/git_draft/common_test.py @@ -59,3 +59,27 @@ def test_load_ok(self) -> None: def test_load_default(self) -> None: config = sut.Config.load() assert config.log_level == logging.INFO + + +@pytest.mark.parametrize( + "text,width,want", + [ + ("", 10, ""), + ("abc", 5, "abc"), + ("ab", 0, "ab"), + ("\nabc def", 4, "abc\ndef"), + (" abc\n def ", 10, "abc def"), + ( + """ + This is a fun paragraph + which continues. + + And another. + """, + 60, + "This is a fun paragraph which continues.\n\nAnd another.", + ), + ], +) +def test_reindent(text, width, want): + assert sut.reindent(text, width) == want