49 changes: 33 additions & 16 deletions autointent/generation/utterances/generator.py
@@ -1,6 +1,7 @@
"""Wrapper class for accessing OpenAI API."""

import os
from typing import Any, ClassVar

import openai
from dotenv import load_dotenv
@@ -11,24 +12,43 @@
class Generator:
"""Wrapper class for accessing OpenAI API."""

def __init__(self) -> None:
"""Initialize."""
load_dotenv()
self.client = openai.OpenAI(base_url=os.environ["OPENAI_BASE_URL"], api_key=os.environ["OPENAI_API_KEY"])
self.async_client = openai.AsyncOpenAI(
base_url=os.environ["OPENAI_BASE_URL"], api_key=os.environ["OPENAI_API_KEY"]
)
self.model_name = os.environ["OPENAI_MODEL_NAME"]
_default_generation_params: ClassVar[dict[str, Any]] = {
"max_tokens": 150,
"n": 1,
"stop": None,
"temperature": 0.7,
}

def __init__(self, base_url: str | None = None, model_name: str | None = None, **generation_params: Any) -> None: # noqa: ANN401
"""
Initialize the wrapper for LLM.

:param base_url: HTTP-endpoint for sending API requests to OpenAI API compatible server.
Omit this to infer OPENAI_BASE_URL from environment.
:param model_name: Name of LLM. Omit this to infer OPENAI_MODEL_NAME from environment.
:param generation_params: kwargs that will be sent with a request to the endpoint.
Omit this to use AutoIntent's default parameters.
"""
if not base_url:
load_dotenv()
base_url = os.environ["OPENAI_BASE_URL"]
if not model_name:
load_dotenv()
Member: I wouldn't add this.

Collaborator Author: It might not even enter the first if, though.

Member: I'd just remove it entirely, but since it's there anyway, I'd move it out in front of the class.
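A minimal sketch of the refactor the Member seems to be suggesting: call load_dotenv() once at module level, before the class definition, so neither branch of __init__ needs it (assuming import-time loading of .env is acceptable here):

import os

from dotenv import load_dotenv

load_dotenv()  # read .env once, at import time


class Generator:
    """Wrapper class for accessing OpenAI API."""

    def __init__(self, base_url: str | None = None, model_name: str | None = None) -> None:
        # fall back to environment variables when arguments are omitted
        base_url = base_url or os.environ["OPENAI_BASE_URL"]
        self.model_name = model_name or os.environ["OPENAI_MODEL_NAME"]
        ...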

model_name = os.environ["OPENAI_MODEL_NAME"]
self.model_name = model_name
self.client = openai.OpenAI(base_url=base_url)
self.async_client = openai.AsyncOpenAI(base_url=base_url)
self.generation_params = {
**self._default_generation_params,
**generation_params,
} # https://stackoverflow.com/a/65539348

def get_chat_completion(self, messages: list[Message]) -> str:
"""Prompt LLM and return its answer synchronously."""
response = self.client.chat.completions.create(
messages=messages, # type: ignore[arg-type]
model=self.model_name,
max_tokens=150,
n=1,
stop=None,
temperature=0.7,
**self.generation_params,
)
return response.choices[0].message.content # type: ignore[return-value]

@@ -37,9 +57,6 @@ async def get_chat_completion_async(self, messages: list[Message]) -> str:
response = await self.async_client.chat.completions.create(
messages=messages, # type: ignore[arg-type]
model=self.model_name,
max_tokens=150,
n=1,
stop=None,
temperature=0.7,
**self.generation_params,
)
return response.choices[0].message.content # type: ignore[return-value]
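
For context, a usage sketch of the API this diff introduces. The endpoint and model name below are illustrative placeholders, OPENAI_API_KEY is still read from the environment by the openai client, and Message is the package's own message type:

# Infer OPENAI_BASE_URL / OPENAI_MODEL_NAME from the environment,
# overriding one of the default generation parameters:
generator = Generator(temperature=0.2)

# Or configure the endpoint explicitly (values are illustrative):
generator = Generator(
    base_url="http://localhost:8000/v1",
    model_name="my-model",
    max_tokens=64,
)

answer = generator.get_chat_completion(messages)  # messages: list[Message]

# The async variant takes the same arguments:
# answer = await generator.get_chat_completion_async(messages)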