
Commit 97b07aa

Make LLM Generator configurable (#138)
* refactor `Generator`
* bug fix
* respond to samoed

1 parent: b055bbc

1 file changed: 33 additions & 16 deletions
@@ -1,34 +1,54 @@
 """Wrapper class for accessing OpenAI API."""
 
 import os
+from typing import Any, ClassVar
 
 import openai
 from dotenv import load_dotenv
 
 from .schemas import Message
 
+load_dotenv()
+
 
 class Generator:
     """Wrapper class for accessing OpenAI API."""
 
-    def __init__(self) -> None:
-        """Initialize."""
-        load_dotenv()
-        self.client = openai.OpenAI(base_url=os.environ["OPENAI_BASE_URL"], api_key=os.environ["OPENAI_API_KEY"])
-        self.async_client = openai.AsyncOpenAI(
-            base_url=os.environ["OPENAI_BASE_URL"], api_key=os.environ["OPENAI_API_KEY"]
-        )
-        self.model_name = os.environ["OPENAI_MODEL_NAME"]
+    _default_generation_params: ClassVar[dict[str, Any]] = {
+        "max_tokens": 150,
+        "n": 1,
+        "stop": None,
+        "temperature": 0.7,
+    }
+
+    def __init__(self, base_url: str | None = None, model_name: str | None = None, **generation_params: Any) -> None:  # noqa: ANN401
+        """
+        Initialize the wrapper for an LLM.
+
+        :param base_url: HTTP endpoint for sending API requests to an OpenAI-API-compatible server.
+            Omit this to infer OPENAI_BASE_URL from the environment.
+        :param model_name: Name of the LLM. Omit this to infer OPENAI_MODEL_NAME from the environment.
+        :param generation_params: kwargs that will be sent with each request to the endpoint.
+            Omit this to use AutoIntent's default parameters.
+        """
+        if not base_url:
+            base_url = os.environ["OPENAI_BASE_URL"]
+        if not model_name:
+            model_name = os.environ["OPENAI_MODEL_NAME"]
+        self.model_name = model_name
+        self.client = openai.OpenAI(base_url=base_url)
+        self.async_client = openai.AsyncOpenAI(base_url=base_url)
+        self.generation_params = {
+            **self._default_generation_params,
+            **generation_params,
+        }  # https://stackoverflow.com/a/65539348
 
     def get_chat_completion(self, messages: list[Message]) -> str:
         """Prompt LLM and return its answer synchronously."""
         response = self.client.chat.completions.create(
             messages=messages,  # type: ignore[arg-type]
             model=self.model_name,
-            max_tokens=150,
-            n=1,
-            stop=None,
-            temperature=0.7,
+            **self.generation_params,
         )
         return response.choices[0].message.content  # type: ignore[return-value]
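The `{**self._default_generation_params, **generation_params}` merge in `__init__` (see the Stack Overflow link in the diff) gives caller-supplied kwargs precedence over the class-level defaults. A minimal sketch of that behavior, using the defaults from the diff and hypothetical caller kwargs:

    _defaults = {"max_tokens": 150, "n": 1, "stop": None, "temperature": 0.7}
    overrides = {"temperature": 0.2, "max_tokens": 512}  # hypothetical caller kwargs

    # Later entries win: overrides replace matching defaults, untouched defaults remain.
    merged = {**_defaults, **overrides}
    assert merged == {"max_tokens": 512, "n": 1, "stop": None, "temperature": 0.2}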

@@ -37,9 +57,6 @@ async def get_chat_completion_async(self, messages: list[Message]) -> str:
         response = await self.async_client.chat.completions.create(
             messages=messages,  # type: ignore[arg-type]
             model=self.model_name,
-            max_tokens=150,
-            n=1,
-            stop=None,
-            temperature=0.7,
+            **self.generation_params,
         )
         return response.choices[0].message.content  # type: ignore[return-value]
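With this change, `Generator` can be configured per instance instead of only through environment variables. A usage sketch, with the caveats that the import path and model name below are hypothetical and that plain role/content dicts stand in for the `Message` schema:

    # Hypothetical import path; the diff does not show the module's location.
    from autointent.generator import Generator

    # No arguments: base_url and model_name fall back to the OPENAI_BASE_URL and
    # OPENAI_MODEL_NAME environment variables; generation params use the class defaults.
    default_gen = Generator()

    # Explicit configuration: extra kwargs override _default_generation_params.
    gen = Generator(
        base_url="http://localhost:8000/v1",  # any OpenAI-compatible server (hypothetical)
        model_name="my-local-model",          # hypothetical model name
        temperature=0.2,
        max_tokens=512,
    )
    answer = gen.get_chat_completion([{"role": "user", "content": "Hi!"}])

Note that OPENAI_API_KEY must still be available, from the environment or a .env file picked up by `load_dotenv`, since the clients no longer receive `api_key` explicitly and rely on the openai SDK's default credential lookup.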
