Merged
autointent/generation/utterances/generator.py (49 changes: 33 additions & 16 deletions)
@@ -1,34 +1,54 @@
"""Wrapper class for accessing OpenAI API."""

import os
from typing import Any, ClassVar

import openai
from dotenv import load_dotenv

from .schemas import Message

load_dotenv()


class Generator:
"""Wrapper class for accessing OpenAI API."""

def __init__(self) -> None:
"""Initialize."""
load_dotenv()
self.client = openai.OpenAI(base_url=os.environ["OPENAI_BASE_URL"], api_key=os.environ["OPENAI_API_KEY"])
self.async_client = openai.AsyncOpenAI(
base_url=os.environ["OPENAI_BASE_URL"], api_key=os.environ["OPENAI_API_KEY"]
)
self.model_name = os.environ["OPENAI_MODEL_NAME"]
_default_generation_params: ClassVar[dict[str, Any]] = {
"max_tokens": 150,
"n": 1,
"stop": None,
"temperature": 0.7,
}

def __init__(self, base_url: str | None = None, model_name: str | None = None, **generation_params: Any) -> None: # noqa: ANN401
"""
Initialize the wrapper for LLM.

:param base_url: HTTP-endpoint for sending API requests to OpenAI API compatible server.
Omit this to infer OPENAI_BASE_URL from environment.
:param model_name: Name of LLM. Omit this to infer OPENAI_MODEL_NAME from environment.
:param generation_params: kwargs that will be sent with a request to the endpoint.
Omit this to use AutoIntent's default parameters.
"""
if not base_url:
base_url = os.environ["OPENAI_BASE_URL"]
if not model_name:
model_name = os.environ["OPENAI_MODEL_NAME"]
self.model_name = model_name
self.client = openai.OpenAI(base_url=base_url)
self.async_client = openai.AsyncOpenAI(base_url=base_url)
self.generation_params = {
**self._default_generation_params,
**generation_params,
} # https://stackoverflow.com/a/65539348

def get_chat_completion(self, messages: list[Message]) -> str:
"""Prompt LLM and return its answer synchronously."""
response = self.client.chat.completions.create(
messages=messages, # type: ignore[arg-type]
model=self.model_name,
max_tokens=150,
n=1,
stop=None,
temperature=0.7,
**self.generation_params,
)
return response.choices[0].message.content # type: ignore[return-value]

@@ -37,9 +57,6 @@ async def get_chat_completion_async(self, messages: list[Message]) -> str:
         response = await self.async_client.chat.completions.create(
             messages=messages,  # type: ignore[arg-type]
             model=self.model_name,
-            max_tokens=150,
-            n=1,
-            stop=None,
-            temperature=0.7,
+            **self.generation_params,
         )
        return response.choices[0].message.content  # type: ignore[return-value]
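
For context, a minimal usage sketch of the reworked constructor (not part of the PR itself). The import path follows the file location in this diff; the endpoint, model name, prompt, and parameter values are illustrative, plain dicts stand in for Message objects, and OPENAI_API_KEY is still expected in the environment or a .env file since the clients are created without an explicit api_key:

from autointent.generation.utterances.generator import Generator

# Explicit arguments take precedence over OPENAI_BASE_URL / OPENAI_MODEL_NAME;
# extra kwargs are merged over _default_generation_params before each request.
generator = Generator(
    base_url="https://api.openai.com/v1",  # any OpenAI-compatible endpoint (illustrative)
    model_name="gpt-4o-mini",              # illustrative model name
    temperature=0.2,                       # overrides the default 0.7
    max_tokens=64,                         # overrides the default 150
)

messages = [{"role": "user", "content": "Suggest three ways to phrase a password-reset request."}]
print(generator.get_chat_completion(messages))  # prints the first choice's message content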