55import os
66from pathlib import Path
77from textwrap import dedent
8- from typing import Any , ClassVar , Literal , TypedDict , TypeVar
8+ from typing import Any , Literal , TypedDict , TypeVar
99
1010import openai
1111from dotenv import load_dotenv
@@ -57,27 +57,21 @@ class Generator:
5757
5858 _dump_data_filename = "init_params.json"
5959
60- _default_generation_params : ClassVar [dict [str , Any ]] = {
61- "max_tokens" : 150 ,
62- "n" : 1 ,
63- "stop" : None ,
64- "temperature" : 0.7 ,
65- }
66- """Default generation parameters for API requests."""
67-
6860 def __init__ (
6961 self ,
7062 base_url : str | None = None ,
7163 model_name : str | None = None ,
7264 use_cache : bool = True ,
73- ** generation_params : Any , # noqa: ANN401
65+ client_params : dict [str , Any ] | None = None ,
66+ ** generation_params : dict [str , Any ],
7467 ) -> None :
7568 """Initialize the Generator with API configuration.
7669
7770 Args:
7871 base_url: OpenAI API compatible server URL.
7972 model_name: Name of the language model to use.
8073 use_cache: Whether to use caching for structured outputs.
74+ client_params: Additional parameters for client.
8175 **generation_params: Additional generation parameters to override defaults passed to OpenAI completions API.
8276 """
8377 base_url = base_url or os .getenv ("OPENAI_BASE_URL" )
@@ -91,27 +85,23 @@ def __init__(
9185 self .base_url = base_url
9286 self .use_cache = use_cache
9387
94- self .client = openai .OpenAI (base_url = base_url )
95- self .async_client = openai .AsyncOpenAI (base_url = base_url )
88+ self .client = openai .OpenAI (base_url = base_url , ** (client_params or {}))
89+ self .async_client = openai .AsyncOpenAI (base_url = base_url , ** (client_params or {}))
90+ self .generation_params = generation_params
9691 self .cache = StructuredOutputCache (use_cache = use_cache )
9792
98- self .generation_params = {
99- ** self ._default_generation_params ,
100- ** generation_params ,
101- } # https://stackoverflow.com/a/65539348
102-
def get_chat_completion(self, messages: list[Message]) -> str:
    """Prompt LLM and return its answer.

    Args:
        messages: List of messages to send to the model.

    Returns:
        The message content of the first choice in the response.

    Raises:
        RuntimeError: If the response is missing or contains no choices.
    """
    response = self.client.chat.completions.create(
        messages=messages,  # type: ignore[call-overload]
        model=self.model_name,
        **self.generation_params,
    )

    # Same guard as get_chat_completion_async: fail with a clear error instead of
    # an IndexError/AttributeError when the server sends back nothing usable.
    if response is None or not response.choices:
        msg = "No response received from the model."
        raise RuntimeError(msg)
    return response.choices[0].message.content  # type: ignore[no-any-return]
115105
116106 async def get_chat_completion_async (self , messages : list [Message ]) -> str :
117107 """Prompt LLM and return its answer asynchronously.
@@ -120,11 +110,15 @@ async def get_chat_completion_async(self, messages: list[Message]) -> str:
120110 messages: List of messages to send to the model.
121111 """
122112 response = await self .async_client .chat .completions .create (
123- messages = messages , # type: ignore[arg-type ]
113+ messages = messages , # type: ignore[call-overload ]
124114 model = self .model_name ,
125115 ** self .generation_params ,
126116 )
127- return response .choices [0 ].message .content # type: ignore[return-value]
117+
118+ if response is None or not response .choices :
119+ msg = "No response received from the model."
120+ raise RuntimeError (msg )
121+ return response .choices [0 ].message .content # type: ignore[no-any-return]
128122
129123 def _create_retry_messages (self , error_message : str , raw : str | None ) -> list [Message ]:
130124 """Create a follow-up message for retry with error details and schema."""
@@ -168,7 +162,7 @@ async def _get_structured_output_openai_async(
168162 model = self .model_name ,
169163 messages = messages , # type: ignore[arg-type]
170164 response_format = output_model ,
171- ** self .generation_params ,
165+ ** self .generation_params , # type: ignore[arg-type]
172166 )
173167 raw = response .choices [0 ].message .content
174168 res = response .choices [0 ].message .parsed
@@ -194,12 +188,12 @@ async def _get_structured_output_vllm_async(
194188 json_schema = output_model .model_json_schema ()
195189 response = await self .async_client .chat .completions .create (
196190 model = self .model_name ,
197- messages = messages , # type: ignore[arg-type ]
191+ messages = messages , # type: ignore[call-overload ]
198192 extra_body = {"guided_json" : json_schema },
199193 ** self .generation_params ,
200194 )
201195 raw = response .choices [0 ].message .content
202- res = output_model .model_validate_json (raw ) # type: ignore[arg-type]
196+ res = output_model .model_validate_json (raw )
203197 except (ValidationError , ValueError ) as e :
204198 msg = f"Failed to obtain structured output for model { self .model_name } and messages { messages } : { e !s} "
205199 logger .warning (msg )
@@ -252,6 +246,10 @@ async def get_structured_output_async(
252246 current_messages .extend (self ._create_retry_messages (error , raw ))
253247
254248 if res is None :
249+ msg = (
250+ f"Failed to generate valid structured output after { max_retries + 1 } attempts.\n "
251+ f"Messages: { current_messages } "
252+ )
255253 logger .exception (msg )
256254 raise RetriesExceededError (max_retries = max_retries , messages = current_messages )
257255
@@ -281,7 +279,7 @@ def _get_structured_output_openai_sync(
281279 model = self .model_name ,
282280 messages = messages , # type: ignore[arg-type]
283281 response_format = output_model ,
284- ** self .generation_params ,
282+ ** self .generation_params , # type: ignore[arg-type]
285283 )
286284 raw = response .choices [0 ].message .content
287285 res = response .choices [0 ].message .parsed
@@ -307,12 +305,12 @@ def _get_structured_output_vllm_sync(
307305 json_schema = output_model .model_json_schema ()
308306 response = self .client .chat .completions .create (
309307 model = self .model_name ,
310- messages = messages , # type: ignore[arg-type ]
308+ messages = messages , # type: ignore[call-overload ]
311309 extra_body = {"guided_json" : json_schema },
312310 ** self .generation_params ,
313311 )
314312 raw = response .choices [0 ].message .content
315- res = output_model .model_validate_json (raw ) # type: ignore[arg-type]
313+ res = output_model .model_validate_json (raw )
316314 except (ValidationError , ValueError ) as e :
317315 msg = f"Failed to obtain structured output for model { self .model_name } and messages { messages } : { e !s} "
318316 logger .warning (msg )
@@ -365,6 +363,7 @@ def get_structured_output_sync(
365363 current_messages .extend (self ._create_retry_messages (error , raw ))
366364
367365 if res is None :
366+ msg = "Structured output returned None but no error was caught."
368367 logger .exception (msg )
369368 raise RetriesExceededError (max_retries = max_retries , messages = current_messages )
370369
0 commit comments