From 4588f1f0ff55c49d8f37691e3f0539026956f227 Mon Sep 17 00:00:00 2001 From: Matthieu Monsch Date: Wed, 5 Mar 2025 07:50:02 -0800 Subject: [PATCH 1/8] feat: add completions-based OpenAI bot --- README.md | 10 +- src/git_draft/bots/openai.py | 175 +++++++++++++++++++---------------- 2 files changed, 98 insertions(+), 87 deletions(-) diff --git a/README.md b/README.md index b62fa9a..d76fc91 100644 --- a/README.md +++ b/README.md @@ -5,11 +5,7 @@ ## Highlights -* Concurrent editing. Continue editing while the assistant runs, without any +* Concurrent edits. Continue editing while the assistant runs, without any risks of interference. - - -## Ideas - -* Add a compatibility OpenAI bot version which does not use threads, so that it - can be used with tools only. Gemini only supports the latter. +* Customizable prompt templates +* Extensible bot API diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index 64c1696..48de4bc 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -12,13 +12,92 @@ _logger = logging.getLogger(__name__) +_DEFAULT_MODEL = "gpt-4o" + + +def completions_bot( + api_key: str | None = None, + base_url: str | None = None, + model: str | None = _DEFAULT_MODEL, +) -> Bot: + client = openai.OpenAI(api_key=api_key, base_url=base_url) + return _CompletionsBot(client) + + def threads_bot( - api_key: str | None = None, base_url: str | None = None + api_key: str | None = None, + base_url: str | None = None, + model: str | None = _DEFAULT_MODEL, ) -> Bot: client = openai.OpenAI(api_key=api_key, base_url=base_url) return _ThreadsBot.create(client) +class _ToolsFactory: + def __init__(self, strict: bool) -> None: + self._strict = strict + + def _param( + self, + name: str, + description: str, + inputs: Mapping[str, Any] | None = None, + required_inputs: Sequence[str] | None = None, + ) -> openai.types.beta.FunctionToolParam: + param: openai.types.beta.FunctionToolParam = { + "type": "function", + "function": { + "name": name, + "description": textwrap.dedent(description), + "parameters": { + "type": "object", + "additionalProperties": False, + "properties": inputs or {}, + "required": list(inputs.keys()) if inputs else [], + }, + }, + } + if self._strict: + param["function"]["strict"] = True + return param + + def params(self) -> Sequence[openai.types.beta.AssistantToolParam]: + return [ + self._param( + name="list_files", + description="List all available files", + ), + self._param( + name="read_file", + description="Get a file's contents", + inputs={ + "path": { + "type": "string", + "description": "Path of the file to be read", + }, + }, + ), + self._param( + name="write_file", + description="""\ + Set a file's contents + + The file will be created if it does not already exist. + """, + inputs={ + "path": { + "type": "string", + "description": "Path of the file to be updated", + }, + "contents": { + "type": "string", + "description": "New contents of the file", + }, + }, + ), + ] + + # https://aider.chat/docs/more-info.html # https://github.com/Aider-AI/aider/blob/main/aider/prompts.py _INSTRUCTIONS = """\ @@ -30,82 +109,12 @@ def threads_bot( """ -def _function_tool_param( - name: str, - description: str, - inputs: Mapping[str, Any] | None = None, - required_inputs: Sequence[str] | None = None, -) -> openai.types.beta.FunctionToolParam: - return { - "type": "function", - "function": { - "name": name, - "description": textwrap.dedent(description), - "parameters": { - "type": "object", - "additionalProperties": False, - "properties": inputs or {}, - "required": list(inputs.keys()) if inputs else [], - }, - "strict": True, - }, - } - - -_tools = [ - _function_tool_param( - name="list_files", - description="List all available files", - ), - _function_tool_param( - name="read_file", - description="Get a file's contents", - inputs={ - "path": { - "type": "string", - "description": "Path of the file to be read", - }, - }, - ), - _function_tool_param( - name="write_file", - description="""\ - Set a file's contents - - The file will be created if it does not already exist. - """, - inputs={ - "path": { - "type": "string", - "description": "Path of the file to be updated", - }, - "contents": { - "type": "string", - "description": "New contents of the file", - }, - "change_description": { - "type": "string", - "description": """\ - Brief description of the changes performed on this file - """, - }, - }, - ), -] - - -@dataclasses.dataclass(frozen=True) -class _AssistantConfig: - instructions: str - model: str - tools: Sequence[openai.types.beta.AssistantToolParam] +class _CompletionsBot(Bot): + def __init__(self, model: str) -> None: + self._model = model - -_assistant_config = _AssistantConfig( - instructions=_INSTRUCTIONS, - model="gpt-4o", - tools=_tools, -) + def act(self, goal: Goal, toolbox: Toolbox) -> Action: + raise NotImplementedError() # TODO class _ThreadsBot(Bot): @@ -124,18 +133,24 @@ def __init__(self, client: openai.OpenAI, assistant_id: str) -> None: self._assistant_id = assistant_id @classmethod - def create(cls, client: openai.OpenAI) -> Self: + def create(cls, client: openai.OpenAI, model: str) -> Self: + assistant_kwargs = dict( + model=model, + instructions=_INSTRUCTIONS, + tools=_ToolsFactory(strict=True).params(), + ) + path = cls.state_folder_path(ensure_exists=True) / "ASSISTANT_ID" - config = dataclasses.asdict(_assistant_config) try: with open(path) as f: assistant_id = f.read() - client.beta.assistants.update(assistant_id, **config) + client.beta.assistants.update(assistant_id, **kwargs) except (FileNotFoundError, openai.NotFoundError): - assistant = client.beta.assistants.create(**config) + assistant = client.beta.assistants.create(**kwargs) assistant_id = assistant.id with open(path, "w") as f: f.write(assistant_id) + return cls(client, assistant_id) def act(self, goal: Goal, toolbox: Toolbox) -> Action: From 067b1c439a65d752c15c5ca49d50e23bc76fd3db Mon Sep 17 00:00:00 2001 From: Matthieu Monsch Date: Wed, 5 Mar 2025 07:59:01 -0800 Subject: [PATCH 2/8] fixup! feat: add completions-based OpenAI bot --- src/git_draft/bots/openai.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index 48de4bc..c8a6cfe 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -1,4 +1,3 @@ -import dataclasses import json import logging import openai @@ -6,6 +5,7 @@ import textwrap from typing import Any, Mapping, Self, Sequence, override +from ..common import JSONObject from .common import Action, Bot, Goal, Toolbox @@ -18,19 +18,19 @@ def completions_bot( api_key: str | None = None, base_url: str | None = None, - model: str | None = _DEFAULT_MODEL, + model: str = _DEFAULT_MODEL, ) -> Bot: client = openai.OpenAI(api_key=api_key, base_url=base_url) - return _CompletionsBot(client) + return _CompletionsBot(client, model) def threads_bot( api_key: str | None = None, base_url: str | None = None, - model: str | None = _DEFAULT_MODEL, + model: str = _DEFAULT_MODEL, ) -> Bot: client = openai.OpenAI(api_key=api_key, base_url=base_url) - return _ThreadsBot.create(client) + return _ThreadsBot.create(client, model) class _ToolsFactory: @@ -110,7 +110,8 @@ def params(self) -> Sequence[openai.types.beta.AssistantToolParam]: class _CompletionsBot(Bot): - def __init__(self, model: str) -> None: + def __init__(self, client: openai.OpenAI, model: str) -> None: + self._client = client self._model = model def act(self, goal: Goal, toolbox: Toolbox) -> Action: @@ -134,7 +135,7 @@ def __init__(self, client: openai.OpenAI, assistant_id: str) -> None: @classmethod def create(cls, client: openai.OpenAI, model: str) -> Self: - assistant_kwargs = dict( + assistant_kwargs: JSONObject = dict( model=model, instructions=_INSTRUCTIONS, tools=_ToolsFactory(strict=True).params(), @@ -144,9 +145,9 @@ def create(cls, client: openai.OpenAI, model: str) -> Self: try: with open(path) as f: assistant_id = f.read() - client.beta.assistants.update(assistant_id, **kwargs) + client.beta.assistants.update(assistant_id, **assistant_kwargs) except (FileNotFoundError, openai.NotFoundError): - assistant = client.beta.assistants.create(**kwargs) + assistant = client.beta.assistants.create(**assistant_kwargs) assistant_id = assistant.id with open(path, "w") as f: f.write(assistant_id) From 9a8d71b12df914d6e0e6d2002f5b536b5ebd6c1d Mon Sep 17 00:00:00 2001 From: Matthieu Monsch Date: Wed, 5 Mar 2025 08:20:19 -0800 Subject: [PATCH 3/8] fixup! feat: add completions-based OpenAI bot --- src/git_draft/__main__.py | 6 ++-- src/git_draft/bots/__init__.py | 50 ++++++++++++++++++++++++---------- src/git_draft/bots/openai.py | 2 ++ src/git_draft/common.py | 4 +++ 4 files changed, 44 insertions(+), 18 deletions(-) diff --git a/src/git_draft/__main__.py b/src/git_draft/__main__.py index 49f2dba..d943b44 100644 --- a/src/git_draft/__main__.py +++ b/src/git_draft/__main__.py @@ -111,15 +111,13 @@ def main() -> None: ) command = getattr(opts, "command", "generate") if command == "generate": - if not config.bots: - raise ValueError("No bots configured") - + bot_config = None if opts.bot: bot_configs = [c for c in config.bots if c.name == opts.bot] if len(bot_configs) != 1: raise ValueError(f"Found {len(bot_configs)} matching bots") bot_config = bot_configs[0] - else: + elif config.bots: bot_config = config.bots[0] bot = load_bot(bot_config) diff --git a/src/git_draft/bots/__init__.py b/src/git_draft/bots/__init__.py index dea33a4..76644d7 100644 --- a/src/git_draft/bots/__init__.py +++ b/src/git_draft/bots/__init__.py @@ -4,9 +4,11 @@ """ import importlib +import os import sys +import textwrap -from ..common import BotConfig +from ..common import BotConfig, reindent from .common import Action, Bot, Goal, Operation, OperationHook, Toolbox __all__ = [ @@ -19,20 +21,11 @@ ] -def load_bot(config: BotConfig) -> Bot: - """Load and return a Bot instance using the provided configuration. +def load_bot(config: BotConfig | None) -> Bot: + """Load and return a Bot instance using the provided configuration""" + if not config: + return _default_bot() - If a pythonpath is specified in the config and not already present in - sys.path, it is added. The function expects the config.factory in the - format 'module:symbol' or 'symbol'. If only 'symbol' is provided, the - current module is used. - - Args: - config: BotConfig object containing bot configuration details. - - Raises: - NotImplementedError: If the specified factory cannot be found. - """ if config.pythonpath and config.pythonpath not in sys.path: sys.path.insert(0, config.pythonpath) @@ -48,3 +41,32 @@ def load_bot(config: BotConfig) -> Bot: kwargs = config.config or {} return factory(**kwargs) + + +def _default_bot() -> Bot: + if not "OPENAI_API_KEY" in os.environ: + raise RuntimeError( + reindent( + """ + The default bot implementation requires an OpenAI API key. + Please specify one via the `$OPENAI_API_KEY` environment + variable or enable a different bot in your configuration. + """ + ) + ) + + try: + from .openai import threads_bot + + except ImportError: + raise RuntimeError( + reindent( + """ + The default bot implementation requires the `openai` Python + package. Please install it or specify a different bot in + your configuration. + """ + ) + ) + else: + return threads_bot() diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index c8a6cfe..dd16c5c 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -20,6 +20,7 @@ def completions_bot( base_url: str | None = None, model: str = _DEFAULT_MODEL, ) -> Bot: + """Compatibility-mode bot, uses completions with function calling""" client = openai.OpenAI(api_key=api_key, base_url=base_url) return _CompletionsBot(client, model) @@ -29,6 +30,7 @@ def threads_bot( base_url: str | None = None, model: str = _DEFAULT_MODEL, ) -> Bot: + """Beta bot, uses assistant threads with function calling""" client = openai.OpenAI(api_key=api_key, base_url=base_url) return _ThreadsBot.create(client, model) diff --git a/src/git_draft/common.py b/src/git_draft/common.py index 52accf2..0346e4d 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -74,3 +74,7 @@ def random_id(n: int) -> str: class UnreachableError(RuntimeError): pass + + +def reindent(s: str) -> str: + raise NotImplementedError() # TODO From 9500d89fe88eab42ad6bd3715645cc176493fdd6 Mon Sep 17 00:00:00 2001 From: Matthieu Monsch Date: Wed, 5 Mar 2025 18:45:12 -0800 Subject: [PATCH 4/8] fixup! feat: add completions-based OpenAI bot --- src/git_draft/bots/__init__.py | 3 +-- src/git_draft/common.py | 14 ++++++++++++-- tests/git_draft/common_test.py | 24 ++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/git_draft/bots/__init__.py b/src/git_draft/bots/__init__.py index 76644d7..5637031 100644 --- a/src/git_draft/bots/__init__.py +++ b/src/git_draft/bots/__init__.py @@ -6,7 +6,6 @@ import importlib import os import sys -import textwrap from ..common import BotConfig, reindent from .common import Action, Bot, Goal, Operation, OperationHook, Toolbox @@ -44,7 +43,7 @@ def load_bot(config: BotConfig | None) -> Bot: def _default_bot() -> Bot: - if not "OPENAI_API_KEY" in os.environ: + if "OPENAI_API_KEY" not in os.environ: raise RuntimeError( reindent( """ diff --git a/src/git_draft/common.py b/src/git_draft/common.py index 0346e4d..6d4a365 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -3,10 +3,12 @@ from __future__ import annotations import dataclasses +import itertools import logging from pathlib import Path import random import string +import textwrap import tomllib from typing import Any, Mapping, Self, Sequence import xdg_base_dirs @@ -76,5 +78,13 @@ class UnreachableError(RuntimeError): pass -def reindent(s: str) -> str: - raise NotImplementedError() # TODO +def reindent(s: str, width=0) -> str: + paragraphs = ( + " ".join(textwrap.dedent("\n".join(g)).splitlines()) + for b, g in itertools.groupby(s.splitlines(), bool) + if b + ) + return "\n\n".join( + textwrap.fill(p, width=width) if width else p + for p in paragraphs + ) diff --git a/tests/git_draft/common_test.py b/tests/git_draft/common_test.py index c5e73d3..042dc7c 100644 --- a/tests/git_draft/common_test.py +++ b/tests/git_draft/common_test.py @@ -59,3 +59,27 @@ def test_load_ok(self) -> None: def test_load_default(self) -> None: config = sut.Config.load() assert config.log_level == logging.INFO + + +@pytest.mark.parametrize( + "text,width,want", + [ + ("", 10, ""), + ("abc", 5, "abc"), + ("ab", 0, "ab"), + ("\nabc def", 4, "abc\ndef"), + (" abc\n def ", 10, "abc def"), + ( + """ + This is a fun paragraph + which continues. + + And another. + """, + 60, + "This is a fun paragraph which continues.\n\nAnd another.", + ), + ], +) +def test_reindent(text, width, want): + assert sut.reindent(text, width) == want From aa012947c5d06c9c8061692b10f47c6859fd7f1e Mon Sep 17 00:00:00 2001 From: Matthieu Monsch Date: Wed, 5 Mar 2025 18:50:12 -0800 Subject: [PATCH 5/8] fixup! feat: add completions-based OpenAI bot --- src/git_draft/bots/openai.py | 38 +++++++++++++++++++----------------- src/git_draft/common.py | 3 +-- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index dd16c5c..ed77234 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -1,11 +1,24 @@ +"""OpenAI API-backed bots + +They can be used with services other than OpenAPI as long as them implement a +sufficient subset of the API. For example the `completions_bot` only requires +tools support. + +See the following links for more resources: + +* https://platform.openai.com/docs/assistants/tools/function-calling +* https://platform.openai.com/docs/assistants/deep-dive#runs-and-run-steps +* https://platform.openai.com/docs/api-reference/assistants-streaming/events +* https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py +""" + import json import logging import openai from pathlib import PurePosixPath -import textwrap from typing import Any, Mapping, Self, Sequence, override -from ..common import JSONObject +from ..common import JSONObject, reindent from .common import Action, Bot, Goal, Toolbox @@ -50,7 +63,7 @@ def _param( "type": "function", "function": { "name": name, - "description": textwrap.dedent(description), + "description": reindent(description), "parameters": { "type": "object", "additionalProperties": False, @@ -81,7 +94,7 @@ def params(self) -> Sequence[openai.types.beta.AssistantToolParam]: ), self._param( name="write_file", - description="""\ + description=""" Set a file's contents The file will be created if it does not already exist. @@ -102,12 +115,10 @@ def params(self) -> Sequence[openai.types.beta.AssistantToolParam]: # https://aider.chat/docs/more-info.html # https://github.com/Aider-AI/aider/blob/main/aider/prompts.py -_INSTRUCTIONS = """\ +_INSTRUCTIONS = """ You are an expert software engineer, who writes correct and concise code. - Use the provided functions to find the filesyou need to answer the query, + Use the provided functions to find the files you need to answer the query, read the content of the relevant ones, and save the changes you suggest. - When writing a file, include a summary description of the changes you have - made. """ @@ -121,15 +132,6 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action: class _ThreadsBot(Bot): - """An OpenAI-backed bot - - See the following links for resources: - - * https://platform.openai.com/docs/assistants/tools/function-calling - * https://platform.openai.com/docs/assistants/deep-dive#runs-and-run-steps - * https://platform.openai.com/docs/api-reference/assistants-streaming/events - * https://github.com/openai/openai-python/blob/main/src/openai/resources/beta/threads/runs/runs.py - """ def __init__(self, client: openai.OpenAI, assistant_id: str) -> None: self._client = client @@ -139,7 +141,7 @@ def __init__(self, client: openai.OpenAI, assistant_id: str) -> None: def create(cls, client: openai.OpenAI, model: str) -> Self: assistant_kwargs: JSONObject = dict( model=model, - instructions=_INSTRUCTIONS, + instructions=reindent(_INSTRUCTIONS), tools=_ToolsFactory(strict=True).params(), ) diff --git a/src/git_draft/common.py b/src/git_draft/common.py index 6d4a365..4c67444 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -85,6 +85,5 @@ def reindent(s: str, width=0) -> str: if b ) return "\n\n".join( - textwrap.fill(p, width=width) if width else p - for p in paragraphs + textwrap.fill(p, width=width) if width else p for p in paragraphs ) From 006407481e173f6b10700d00128ca5e5aa073feb Mon Sep 17 00:00:00 2001 From: Matthieu Monsch Date: Wed, 5 Mar 2025 19:42:12 -0800 Subject: [PATCH 6/8] fixup! feat: add completions-based OpenAI bot --- src/git_draft/bots/__init__.py | 2 +- src/git_draft/bots/openai.py | 119 ++++++++++++++++++++++++++++----- src/git_draft/common.py | 4 +- 3 files changed, 105 insertions(+), 20 deletions(-) diff --git a/src/git_draft/bots/__init__.py b/src/git_draft/bots/__init__.py index 5637031..a614a57 100644 --- a/src/git_draft/bots/__init__.py +++ b/src/git_draft/bots/__init__.py @@ -36,7 +36,7 @@ def load_bot(config: BotConfig | None) -> Bot: factory = getattr(module, symbol, None) if not factory: - raise NotImplementedError(f"Unknown bot factory: {factory}") + raise NotImplementedError(f"Unknown bot factory: {config.factory}") kwargs = config.config or {} return factory(**kwargs) diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index ed77234..3c11cf3 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -15,8 +15,9 @@ import json import logging import openai +import os from pathlib import PurePosixPath -from typing import Any, Mapping, Self, Sequence, override +from typing import Any, Mapping, Self, Sequence, TypedDict, override from ..common import JSONObject, reindent from .common import Action, Bot, Goal, Toolbox @@ -34,6 +35,8 @@ def completions_bot( model: str = _DEFAULT_MODEL, ) -> Bot: """Compatibility-mode bot, uses completions with function calling""" + if api_key and api_key.startswith("$"): + api_key = os.environ[api_key[1:]] client = openai.OpenAI(api_key=api_key, base_url=base_url) return _CompletionsBot(client, model) @@ -76,7 +79,7 @@ def _param( param["function"]["strict"] = True return param - def params(self) -> Sequence[openai.types.beta.AssistantToolParam]: + def params(self) -> Sequence[openai.types.chat.ChatCompletionToolParam]: return [ self._param( name="list_files", @@ -122,13 +125,83 @@ def params(self) -> Sequence[openai.types.beta.AssistantToolParam]: """ +class _ToolHandler[V]: + def __init__(self, toolbox: Toolbox) -> None: + self._toolbox = toolbox + + def _on_read_file(self, path: PurePosixPath, contents: str) -> V: + raise NotImplementedError() + + def _on_write_file(self, path: PurePosixPath) -> V: + raise NotImplementedError() + + def _on_list_files(self, paths: Sequence[PurePosixPath]) -> V: + raise NotImplementedError() + + def handle_function(self, function: Any) -> V: + name = function.name + inputs = json.loads(function.arguments) + _logger.info("Requested function: %s", function) + if name == "read_file": + path = PurePosixPath(inputs["path"]) + return self._on_read_file(path, self._toolbox.read_file(path)) + elif name == "write_file": + path = PurePosixPath(inputs["path"]) + contents = inputs["contents"] + self._toolbox.write_file(path, contents) + return self._on_write_file(path) + else: + assert name == "list_files" and not inputs + paths = self._toolbox.list_files() + return self._on_list_files(paths) + + class _CompletionsBot(Bot): def __init__(self, client: openai.OpenAI, model: str) -> None: self._client = client self._model = model def act(self, goal: Goal, toolbox: Toolbox) -> Action: - raise NotImplementedError() # TODO + tools = _ToolsFactory(strict=False).params() + tool_handler = _CompletionsToolHandler(toolbox) + + messages: list[openai.types.chat.ChatCompletionMessageParam] = [ + {"role": "system", "content": reindent(_INSTRUCTIONS)}, + {"role": "user", "content": goal.prompt}, + ] + + while True: + response = self._client.chat.completions.create( + model=self._model, + messages=messages, + tools=tools, + ) + calls = response.choices[0].message.tool_calls + if not calls: + break + for call in calls: + messages.append( + { + "role": "user", + "content": tool_handler.handle_function(call.function), + } + ) + + return Action() + + +class _CompletionsToolHandler(_ToolHandler[str]): + def _on_read_file(self, path: PurePosixPath, contents: str) -> str: + return ( + f"Here are the contents of `{path}`:\n\n```\n{contents}\n```\n" "" + ) + + def _on_write_file(self, path: PurePosixPath) -> str: + return f"The file at {path} has been updated to match." + + def _on_list_files(self, paths: Sequence[PurePosixPath]) -> str: + joined = "\n".join(f"* {p}" for p in paths) + return f"Here are the available files: {joined}" class _ThreadsBot(Bot): @@ -198,21 +271,8 @@ def on_event(self, event: Any) -> None: def _handle_action(self, run_id: str, data: Any) -> None: tool_outputs = list[Any]() for tool in data.required_action.submit_tool_outputs.tool_calls: - name = tool.function.name - inputs = json.loads(tool.function.arguments) - _logger.info("Requested tool: %s", tool) - if name == "read_file": - path = PurePosixPath(inputs["path"]) - output = self._toolbox.read_file(path) - elif name == "write_file": - path = PurePosixPath(inputs["path"]) - contents = inputs["contents"] - self._toolbox.write_file(path, contents) - output = "OK" - elif name == "list_files": - assert not inputs - output = "\n".join(str(p) for p in self._toolbox.list_files()) - tool_outputs.append({"tool_call_id": tool.id, "output": output}) + handler = _ThreadToolHandler(self._toolbox, tool.id) + tool_outputs.append(handler.handle_function(tool.function)) run = self.current_run assert run, "No ongoing run" @@ -223,3 +283,26 @@ def _handle_action(self, run_id: str, data: Any) -> None: event_handler=self.clone(), ) as stream: stream.until_done() + + +class _ToolOutput(TypedDict): + tool_call_id: str + output: str + + +class _ThreadToolHandler(_ToolHandler[_ToolOutput]): + def __init__(self, toolbox: Toolbox, call_id: str) -> None: + super().__init__(toolbox) + self._call_id = call_id + + def _wrap(self, output: str) -> _ToolOutput: + return _ToolOutput(tool_call_id=self._call_id, output=output) + + def _on_read_file(self, path: PurePosixPath, contents: str) -> _ToolOutput: + return self._wrap(contents) + + def _on_write_file(self, path: PurePosixPath) -> _ToolOutput: + return self._wrap("OK") + + def _on_list_files(self, paths: Sequence[PurePosixPath]) -> _ToolOutput: + return self._wrap("\n".join((str(p) for p in paths))) diff --git a/src/git_draft/common.py b/src/git_draft/common.py index 4c67444..33df3ce 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -71,14 +71,16 @@ class BotConfig: def random_id(n: int) -> str: + """Generates a random length n string of lowercase letters and digits""" return "".join(_random.choices(_alphabet, k=n)) class UnreachableError(RuntimeError): - pass + """Indicates unreachable code was unexpectedly executed""" def reindent(s: str, width=0) -> str: + """Reindents text by dedenting and optionally wrapping paragraphs""" paragraphs = ( " ".join(textwrap.dedent("\n".join(g)).splitlines()) for b, g in itertools.groupby(s.splitlines(), bool) From c1cdac40963c43e1b3a000204a59dbcd3befad7f Mon Sep 17 00:00:00 2001 From: Matthieu Monsch Date: Wed, 5 Mar 2025 20:00:31 -0800 Subject: [PATCH 7/8] fixup! feat: add completions-based OpenAI bot --- src/git_draft/bots/openai.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index 3c11cf3..7de800a 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -122,6 +122,9 @@ def params(self) -> Sequence[openai.types.chat.ChatCompletionToolParam]: You are an expert software engineer, who writes correct and concise code. Use the provided functions to find the files you need to answer the query, read the content of the relevant ones, and save the changes you suggest. + + You should stop when and ONLY WHEN all the files you need to change have + been updated. """ @@ -175,29 +178,30 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action: model=self._model, messages=messages, tools=tools, + tool_choice="required", ) + assert len(response.choices) == 1 + done = True calls = response.choices[0].message.tool_calls - if not calls: + for call in calls or []: + output = tool_handler.handle_function(call.function) + if output is not None: + done = False + messages.append({"role": "user", "content": output}) + if done: break - for call in calls: - messages.append( - { - "role": "user", - "content": tool_handler.handle_function(call.function), - } - ) return Action() -class _CompletionsToolHandler(_ToolHandler[str]): +class _CompletionsToolHandler(_ToolHandler[str | None]): def _on_read_file(self, path: PurePosixPath, contents: str) -> str: return ( f"Here are the contents of `{path}`:\n\n```\n{contents}\n```\n" "" ) - def _on_write_file(self, path: PurePosixPath) -> str: - return f"The file at {path} has been updated to match." + def _on_write_file(self, path: PurePosixPath) -> None: + return None def _on_list_files(self, paths: Sequence[PurePosixPath]) -> str: joined = "\n".join(f"* {p}" for p in paths) From 537600fc739fa0d508f55e01c8d02f25b202358e Mon Sep 17 00:00:00 2001 From: Matthieu Monsch Date: Wed, 5 Mar 2025 20:11:15 -0800 Subject: [PATCH 8/8] fixup! feat: add completions-based OpenAI bot --- src/git_draft/bots/__init__.py | 2 +- src/git_draft/bots/openai.py | 1 + tests/git_draft/bots/__init___test.py | 13 +++++++++---- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/git_draft/bots/__init__.py b/src/git_draft/bots/__init__.py index a614a57..86a2cd6 100644 --- a/src/git_draft/bots/__init__.py +++ b/src/git_draft/bots/__init__.py @@ -43,7 +43,7 @@ def load_bot(config: BotConfig | None) -> Bot: def _default_bot() -> Bot: - if "OPENAI_API_KEY" not in os.environ: + if not os.environ.get("OPENAI_API_KEY"): raise RuntimeError( reindent( """ diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index 7de800a..31494b7 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -181,6 +181,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action: tool_choice="required", ) assert len(response.choices) == 1 + done = True calls = response.choices[0].message.tool_calls for call in calls or []: diff --git a/tests/git_draft/bots/__init___test.py b/tests/git_draft/bots/__init___test.py index 6473beb..f9749eb 100644 --- a/tests/git_draft/bots/__init___test.py +++ b/tests/git_draft/bots/__init___test.py @@ -2,11 +2,11 @@ import sys import pytest -from git_draft.bots import Bot, load_bot +import git_draft.bots as sut from git_draft.common import BotConfig -class FakeBot(Bot): +class FakeBot(sut.Bot): pass @@ -19,10 +19,15 @@ def import_module(name): monkeypatch.setattr(importlib, "import_module", import_module) config = BotConfig(factory="fake_module:FakeBot") - bot = load_bot(config) + bot = sut.load_bot(config) assert isinstance(bot, FakeBot) def test_non_existing_factory(self) -> None: config = BotConfig("git_draft:unknown_factory") with pytest.raises(NotImplementedError): - load_bot(config) + sut.load_bot(config) + + def test_default_no_key(self, monkeypatch) -> None: + monkeypatch.setenv("OPENAI_API_KEY", "") + with pytest.raises(RuntimeError): + sut.load_bot(None)