Skip to content

Commit c5db582

Browse files
authored
feat: support threads bot environment API key (#63)
1 parent 9e95e7a commit c5db582

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

src/git_draft/bots/openai.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""OpenAI API-backed bots
22
3-
They can be used with services other than OpenAPI as long as them implement a
3+
They can be used with services other than OpenAI as long as them implement a
44
sufficient subset of the API. For example the `completions_bot` only requires
55
tools support.
66
@@ -15,13 +15,12 @@
1515
from collections.abc import Mapping, Sequence
1616
import json
1717
import logging
18-
import os
1918
from pathlib import PurePosixPath
2019
from typing import Any, Self, TypedDict, override
2120

2221
import openai
2322

24-
from ..common import JSONObject, reindent
23+
from ..common import JSONObject, config_string, reindent
2524
from .common import Action, Bot, Goal, Toolbox
2625

2726

@@ -37,10 +36,7 @@ def completions_bot(
3736
model: str = _DEFAULT_MODEL,
3837
) -> Bot:
3938
"""Compatibility-mode bot, uses completions with function calling"""
40-
if api_key and api_key.startswith("$"):
41-
api_key = os.environ[api_key[1:]]
42-
client = openai.OpenAI(api_key=api_key, base_url=base_url)
43-
return _CompletionsBot(client, model)
39+
return _CompletionsBot(_new_client(api_key, base_url), model)
4440

4541

4642
def threads_bot(
@@ -49,8 +45,14 @@ def threads_bot(
4945
model: str = _DEFAULT_MODEL,
5046
) -> Bot:
5147
"""Beta bot, uses assistant threads with function calling"""
52-
client = openai.OpenAI(api_key=api_key, base_url=base_url)
53-
return _ThreadsBot.create(client, model)
48+
return _ThreadsBot.create(_new_client(api_key, base_url), model)
49+
50+
51+
def _new_client(api_key: str | None, base_url: str | None) -> openai.OpenAI:
52+
return openai.OpenAI(
53+
api_key=config_string(api_key) if api_key else None,
54+
base_url=base_url,
55+
)
5456

5557

5658
class _ToolsFactory:

src/git_draft/common.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import dataclasses
77
import itertools
88
import logging
9+
import os
910
from pathlib import Path
1011
import random
1112
import sqlite3
@@ -73,6 +74,11 @@ class BotConfig:
7374
pythonpath: str | None = None
7475

7576

77+
def config_string(arg: str) -> str:
78+
"""Dereferences environment value if the input starts with `$`"""
79+
return os.environ[arg[1:]] if arg and arg.startswith("$") else arg
80+
81+
7682
_random = random.Random()
7783
_alphabet = string.ascii_lowercase + string.digits
7884

tests/git_draft/common_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,16 @@ def test_load_default(self) -> None:
6262
assert config.log_level == logging.INFO
6363

6464

65+
class TestConfigString:
66+
def test_literal(self) -> None:
67+
assert sut.config_string("") == ""
68+
assert sut.config_string("abc") == "abc"
69+
70+
def test_evar(self, monkeypatch) -> None:
71+
monkeypatch.setenv("FOO", "111")
72+
assert sut.config_string("$FOO") == "111"
73+
74+
6575
@pytest.mark.parametrize(
6676
"text,width,want",
6777
[

0 commit comments

Comments
 (0)