diff --git a/src/git_draft/bots/openai.py b/src/git_draft/bots/openai.py index b0c9919..7e994a5 100644 --- a/src/git_draft/bots/openai.py +++ b/src/git_draft/bots/openai.py @@ -1,6 +1,6 @@ """OpenAI API-backed bots -They can be used with services other than OpenAPI as long as them implement a +They can be used with services other than OpenAI as long as them implement a sufficient subset of the API. For example the `completions_bot` only requires tools support. @@ -15,13 +15,12 @@ from collections.abc import Mapping, Sequence import json import logging -import os from pathlib import PurePosixPath from typing import Any, Self, TypedDict, override import openai -from ..common import JSONObject, reindent +from ..common import JSONObject, config_string, reindent from .common import Action, Bot, Goal, Toolbox @@ -37,10 +36,7 @@ def completions_bot( model: str = _DEFAULT_MODEL, ) -> Bot: """Compatibility-mode bot, uses completions with function calling""" - if api_key and api_key.startswith("$"): - api_key = os.environ[api_key[1:]] - client = openai.OpenAI(api_key=api_key, base_url=base_url) - return _CompletionsBot(client, model) + return _CompletionsBot(_new_client(api_key, base_url), model) def threads_bot( @@ -49,8 +45,14 @@ def threads_bot( model: str = _DEFAULT_MODEL, ) -> Bot: """Beta bot, uses assistant threads with function calling""" - client = openai.OpenAI(api_key=api_key, base_url=base_url) - return _ThreadsBot.create(client, model) + return _ThreadsBot.create(_new_client(api_key, base_url), model) + + +def _new_client(api_key: str | None, base_url: str | None) -> openai.OpenAI: + return openai.OpenAI( + api_key=config_string(api_key) if api_key else None, + base_url=base_url, + ) class _ToolsFactory: diff --git a/src/git_draft/common.py b/src/git_draft/common.py index c8f9ed0..d21b533 100644 --- a/src/git_draft/common.py +++ b/src/git_draft/common.py @@ -6,6 +6,7 @@ import dataclasses import itertools import logging +import os from pathlib import Path import random import sqlite3 @@ -73,6 +74,11 @@ class BotConfig: pythonpath: str | None = None +def config_string(arg: str) -> str: + """Dereferences environment value if the input starts with `$`""" + return os.environ[arg[1:]] if arg and arg.startswith("$") else arg + + _random = random.Random() _alphabet = string.ascii_lowercase + string.digits diff --git a/tests/git_draft/common_test.py b/tests/git_draft/common_test.py index 7cd44f6..cc158a9 100644 --- a/tests/git_draft/common_test.py +++ b/tests/git_draft/common_test.py @@ -62,6 +62,16 @@ def test_load_default(self) -> None: assert config.log_level == logging.INFO +class TestConfigString: + def test_literal(self) -> None: + assert sut.config_string("") == "" + assert sut.config_string("abc") == "abc" + + def test_evar(self, monkeypatch) -> None: + monkeypatch.setenv("FOO", "111") + assert sut.config_string("$FOO") == "111" + + @pytest.mark.parametrize( "text,width,want", [