11"""Wrapper class for accessing OpenAI API."""
22
3+ import json
34import logging
45import os
6+ from pathlib import Path
57from textwrap import dedent
6- from typing import Any , ClassVar , Literal , TypeVar
8+ from typing import Any , ClassVar , Literal , TypedDict , TypeVar
79
810import openai
911from dotenv import load_dotenv
2224"""Type variable for Pydantic models used in structured output generation."""
2325
2426
27+ class GeneratorDumpData (TypedDict ):
28+ use_cache : bool
29+ model_name : str
30+ base_url : str | None
31+ generation_params : dict [str , Any ]
32+
33+
class RetriesExceededError(RuntimeError):
    """Exception raised when LLM call fails after all retry attempts.

    Attributes:
        max_retries: Maximum number of retry attempts that were made.
        messages: Messages that were sent to the LLM.
    """

    def __init__(self, max_retries: int, messages: list[Message]) -> None:
        """Initialize the error with retry count and messages.

        Args:
            max_retries: Maximum number of retry attempts that were made
            messages: Messages that were sent to the LLM
        """
        # One initial attempt plus ``max_retries`` retries were performed.
        msg = f"LLM call failed after {max_retries + 1} attempts. Messages: {messages} "
        super().__init__(msg)
        # Keep the raw values so handlers can inspect them without parsing the text.
        self.max_retries = max_retries
        self.messages = messages

    def __reduce__(self) -> tuple[Any, ...]:
        """Describe how to rebuild this exception for ``pickle`` and ``copy``.

        The default exception protocol re-invokes the class with ``self.args``,
        which holds only the formatted message and therefore does not match the
        two-argument constructor. Returning the original arguments keeps the
        exception transferable across process boundaries.

        Returns:
            The class and the constructor arguments needed to recreate the error.
        """
        return (type(self), (self.max_retries, self.messages))
46+
47+
2548class Generator :
2649 """Wrapper class for accessing OpenAI API.
2750
@@ -32,6 +55,8 @@ class Generator:
3255 **generation_params: kwargs that will be sent with a request to the endpoint.
3356 """
3457
58+ _dump_data_filename = "init_params.json"
59+
3560 _default_generation_params : ClassVar [dict [str , Any ]] = {
3661 "max_tokens" : 150 ,
3762 "n" : 1 ,
@@ -57,17 +82,23 @@ def __init__(
5782 """
5883 base_url = base_url or os .getenv ("OPENAI_BASE_URL" )
5984 model_name = model_name or os .getenv ("OPENAI_MODEL_NAME" )
85+
6086 if model_name is None :
6187 msg = "Specify model_name arg or OPENAI_MODEL_NAME environment variable"
6288 raise ValueError (msg )
89+
6390 self .model_name = model_name
91+ self .base_url = base_url
92+ self .use_cache = use_cache
93+
6494 self .client = openai .OpenAI (base_url = base_url )
6595 self .async_client = openai .AsyncOpenAI (base_url = base_url )
96+ self .cache = StructuredOutputCache (use_cache = use_cache )
97+
6698 self .generation_params = {
6799 ** self ._default_generation_params ,
68100 ** generation_params ,
69101 } # https://stackoverflow.com/a/65539348
70- self .cache = StructuredOutputCache (use_cache = use_cache )
71102
72103 def get_chat_completion (self , messages : list [Message ]) -> str :
73104 """Prompt LLM and return its answer.
@@ -221,12 +252,8 @@ async def get_structured_output_async(
221252 current_messages .extend (self ._create_retry_messages (error , raw ))
222253
223254 if res is None :
224- msg = (
225- f"Failed to generate valid structured output after { max_retries + 1 } attempts.\n "
226- f"Messages: { current_messages } "
227- )
228255 logger .exception (msg )
229- raise RuntimeError ( msg )
256+ raise RetriesExceededError ( max_retries = max_retries , messages = current_messages )
230257
231258 # Cache the successful result
232259 self .cache .set (messages , output_model , backend , self .generation_params , res )
@@ -338,14 +365,32 @@ def get_structured_output_sync(
338365 current_messages .extend (self ._create_retry_messages (error , raw ))
339366
340367 if res is None :
341- msg = (
342- f"Failed to generate valid structured output after { max_retries + 1 } attempts.\n "
343- f"Messages: { current_messages } "
344- )
345368 logger .exception (msg )
346- raise RuntimeError ( msg )
369+ raise RetriesExceededError ( max_retries = max_retries , messages = current_messages )
347370
348371 # Cache the successful result
349372 self .cache .set (messages , output_model , backend , self .generation_params , res )
350373
351374 return res
375+
def dump(self, path: Path, exist_ok: bool = True) -> None:
    """Write this generator's settings to a JSON file inside ``path``.

    The directory is created together with any missing parents, and the
    file inside it is named by ``_dump_data_filename``.

    Args:
        path: Directory that receives the settings file.
        exist_ok: Passed through to ``Path.mkdir``; when False, an already
            existing directory makes ``mkdir`` raise ``FileExistsError``.
    """
    # Snapshot the attributes first so a missing one fails before any disk access.
    payload: GeneratorDumpData = {
        "base_url": self.base_url,
        "generation_params": self.generation_params,
        "model_name": self.model_name,
        "use_cache": self.use_cache,
    }

    path.mkdir(parents=True, exist_ok=exist_ok)
    target = path / self._dump_data_filename

    with target.open("w", encoding="utf-8") as sink:
        json.dump(payload, sink, ensure_ascii=False, indent=4)
388+
@classmethod
def load(cls, path: Path) -> "Generator":
    """Build an instance from the settings file stored inside ``path``.

    The file named by ``_dump_data_filename`` is parsed as JSON. Its
    ``generation_params`` entry is unpacked as extra keyword arguments and
    every remaining entry is passed to the constructor under its own key.

    Args:
        path: Directory containing the settings file.

    Returns:
        A new instance created from the stored settings.
    """
    settings_file = path / cls._dump_data_filename
    stored: GeneratorDumpData = json.loads(settings_file.read_text(encoding="utf-8"))

    # Pull the nested params out so they can be spread as individual kwargs.
    extra_params = stored.pop("generation_params")  # type: ignore[misc]

    return cls(**stored, **extra_params)
0 commit comments