Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions src/git_draft/bots/openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""OpenAI API-backed bots

They can be used with services other than OpenAPI as long as them implement a
They can be used with services other than OpenAI as long as them implement a
sufficient subset of the API. For example the `completions_bot` only requires
tools support.

Expand All @@ -15,13 +15,12 @@
from collections.abc import Mapping, Sequence
import json
import logging
import os
from pathlib import PurePosixPath
from typing import Any, Self, TypedDict, override

import openai

from ..common import JSONObject, reindent
from ..common import JSONObject, config_string, reindent
from .common import Action, Bot, Goal, Toolbox


Expand All @@ -37,10 +36,7 @@ def completions_bot(
model: str = _DEFAULT_MODEL,
) -> Bot:
"""Compatibility-mode bot, uses completions with function calling"""
if api_key and api_key.startswith("$"):
api_key = os.environ[api_key[1:]]
client = openai.OpenAI(api_key=api_key, base_url=base_url)
return _CompletionsBot(client, model)
return _CompletionsBot(_new_client(api_key, base_url), model)


def threads_bot(
Expand All @@ -49,8 +45,14 @@ def threads_bot(
model: str = _DEFAULT_MODEL,
) -> Bot:
"""Beta bot, uses assistant threads with function calling"""
client = openai.OpenAI(api_key=api_key, base_url=base_url)
return _ThreadsBot.create(client, model)
return _ThreadsBot.create(_new_client(api_key, base_url), model)


def _new_client(api_key: str | None, base_url: str | None) -> openai.OpenAI:
return openai.OpenAI(
api_key=config_string(api_key) if api_key else None,
base_url=base_url,
)


class _ToolsFactory:
Expand Down
6 changes: 6 additions & 0 deletions src/git_draft/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dataclasses
import itertools
import logging
import os
from pathlib import Path
import random
import sqlite3
Expand Down Expand Up @@ -73,6 +74,11 @@ class BotConfig:
pythonpath: str | None = None


def config_string(arg: str) -> str:
"""Dereferences environment value if the input starts with `$`"""
return os.environ[arg[1:]] if arg and arg.startswith("$") else arg


_random = random.Random()
_alphabet = string.ascii_lowercase + string.digits

Expand Down
10 changes: 10 additions & 0 deletions tests/git_draft/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ def test_load_default(self) -> None:
assert config.log_level == logging.INFO


class TestConfigString:
def test_literal(self) -> None:
assert sut.config_string("") == ""
assert sut.config_string("abc") == "abc"

def test_evar(self, monkeypatch) -> None:
monkeypatch.setenv("FOO", "111")
assert sut.config_string("$FOO") == "111"


@pytest.mark.parametrize(
"text,width,want",
[
Expand Down