11"""Wrapper class for accessing OpenAI API."""
22
33import os
4+ from typing import Any , ClassVar
45
56import openai
67from dotenv import load_dotenv
78
89from .schemas import Message
910
11+ load_dotenv ()
12+
1013
class Generator:
    """Wrapper class for accessing OpenAI API."""

    # AutoIntent's default sampling parameters; any of them can be overridden
    # per-instance via **generation_params in __init__.
    _default_generation_params: ClassVar[dict[str, Any]] = {
        "max_tokens": 150,
        "n": 1,
        "stop": None,
        "temperature": 0.7,
    }

    def __init__(self, base_url: str | None = None, model_name: str | None = None, **generation_params: Any) -> None:  # noqa: ANN401
        """
        Initialize the wrapper for LLM.

        :param base_url: HTTP-endpoint for sending API requests to OpenAI API compatible server.
            Omit this to infer OPENAI_BASE_URL from environment.
        :param model_name: Name of LLM. Omit this to infer OPENAI_MODEL_NAME from environment.
        :param generation_params: kwargs that will be sent with a request to the endpoint.
            Omit this to use AutoIntent's default parameters.
        """
        if not base_url:
            base_url = os.environ["OPENAI_BASE_URL"]
        if not model_name:
            model_name = os.environ["OPENAI_MODEL_NAME"]
        self.model_name = model_name
        # api_key is intentionally not passed: the OpenAI SDK reads
        # OPENAI_API_KEY from the environment (populated by load_dotenv()).
        self.client = openai.OpenAI(base_url=base_url)
        self.async_client = openai.AsyncOpenAI(base_url=base_url)
        # Later keys win in a dict merge, so user-supplied params override the
        # class defaults.  https://stackoverflow.com/a/65539348
        self.generation_params = {
            **self._default_generation_params,
            **generation_params,
        }

    def get_chat_completion(self, messages: list[Message]) -> str:
        """Prompt LLM and return its answer synchronously."""
        response = self.client.chat.completions.create(
            messages=messages,  # type: ignore[arg-type]
            model=self.model_name,
            **self.generation_params,
        )
        return response.choices[0].message.content  # type: ignore[return-value]

    async def get_chat_completion_async(self, messages: list[Message]) -> str:
        """Prompt LLM and return its answer asynchronously."""
        response = await self.async_client.chat.completions.create(
            messages=messages,  # type: ignore[arg-type]
            model=self.model_name,
            **self.generation_params,
        )
        return response.choices[0].message.content  # type: ignore[return-value]