11from __future__ import annotations
22
3+ import asyncio
4+ import logging
35import os
46import typing as t
57from abc import abstractmethod
68from dataclasses import dataclass , field
79
10+ import openai
811from langchain .adapters .openai import convert_message_to_dict
12+ from langchain .callbacks .manager import (
13+ AsyncCallbackManagerForLLMRun ,
14+ CallbackManagerForLLMRun ,
15+ )
916from langchain .schema import Generation , LLMResult
1017from openai import AsyncAzureOpenAI , AsyncClient , AsyncOpenAI
18+ from tenacity import (
19+ RetryCallState ,
20+ before_sleep_log ,
21+ retry ,
22+ retry_base ,
23+ retry_if_exception_type ,
24+ stop_after_attempt ,
25+ wait_exponential ,
26+ )
1127
1228from ragas .async_utils import run_async_tasks
1329from ragas .exceptions import AzureOpenAIKeyNotFound , OpenAIKeyNotFound
1430from ragas .llms .base import RagasLLM
1531from ragas .llms .langchain import _compute_token_usage_langchain
16- from ragas .utils import NO_KEY
32+ from ragas .utils import NO_KEY , get_debug_mode
1733
1834if t .TYPE_CHECKING :
1935 from langchain .callbacks .base import Callbacks
2036 from langchain .prompts import ChatPromptTemplate
2137
38+ logger = logging .getLogger (__name__ )
39+
40+ errors = [
41+ openai .APITimeoutError ,
42+ openai .APIConnectionError ,
43+ openai .RateLimitError ,
44+ openai .APIConnectionError ,
45+ openai .InternalServerError ,
46+ ]
47+
48+
49+ def create_base_retry_decorator (
50+ error_types : t .List [t .Type [BaseException ]],
51+ max_retries : int = 1 ,
52+ run_manager : t .Optional [
53+ t .Union [AsyncCallbackManagerForLLMRun , CallbackManagerForLLMRun ]
54+ ] = None ,
55+ ) -> t .Callable [[t .Any ], t .Any ]:
56+ """Create a retry decorator for a given LLM and provided list of error types."""
57+
58+ log_level = logging .WARNING if get_debug_mode () else logging .DEBUG
59+ _logging = before_sleep_log (logger , log_level )
60+
61+ def _before_sleep (retry_state : RetryCallState ) -> None :
62+ _logging (retry_state )
63+ if run_manager :
64+ if isinstance (run_manager , AsyncCallbackManagerForLLMRun ):
65+ coro = run_manager .on_retry (retry_state )
66+ try :
67+ loop = asyncio .get_event_loop ()
68+ if loop .is_running ():
69+ loop .create_task (coro )
70+ else :
71+ asyncio .run (coro )
72+ except Exception as e :
73+ logger .error (f"Error in on_retry: { e } " )
74+ else :
75+ run_manager .on_retry (retry_state )
76+ return None
77+
78+ min_seconds = 4
79+ max_seconds = 10
80+ # Wait 2^x * 1 second between each retry starting with
81+ # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
82+ retry_instance : "retry_base" = retry_if_exception_type (error_types [0 ])
83+ for error in error_types [1 :]:
84+ retry_instance = retry_instance | retry_if_exception_type (error )
85+ return retry (
86+ reraise = True ,
87+ stop = stop_after_attempt (max_retries ),
88+ wait = wait_exponential (multiplier = 1 , min = min_seconds , max = max_seconds ),
89+ retry = retry_instance ,
90+ before_sleep = _before_sleep ,
91+ )
92+
93+
94+ retry_decorator = create_base_retry_decorator (errors , max_retries = 4 )
95+
2296
2397class OpenAIBase (RagasLLM ):
24- def __init__ (self , model : str , _api_key_env_var : str ) -> None :
98+ def __init__ (self , model : str , _api_key_env_var : str , timeout : int = 60 ) -> None :
2599 self .model = model
26100 self ._api_key_env_var = _api_key_env_var
101+ self .timeout = timeout
27102
28103 # api key
29104 key_from_env = os .getenv (self ._api_key_env_var , NO_KEY )
@@ -83,6 +158,7 @@ def generate(
83158 llm_output = _compute_token_usage_langchain (llm_results )
84159 return LLMResult (generations = generations , llm_output = llm_output )
85160
161+ @retry_decorator
86162 async def agenerate (
87163 self ,
88164 prompt : ChatPromptTemplate ,
@@ -112,9 +188,13 @@ def __post_init__(self):
112188 self ._client_init ()
113189
114190 def _client_init (self ):
115- self ._client = AsyncOpenAI (api_key = self .api_key )
191+ self ._client = AsyncOpenAI (api_key = self .api_key , timeout = self . timeout )
116192
117193 def validate_api_key (self ):
194+ # before validating, check if the api key is already set
195+ api_key = os .getenv (self ._api_key_env_var , NO_KEY )
196+ if api_key != NO_KEY :
197+ self ._client .api_key = api_key
118198 if self .llm .api_key == NO_KEY :
119199 raise OpenAIKeyNotFound
120200
@@ -136,6 +216,7 @@ def _client_init(self):
136216 api_version = self .api_version ,
137217 azure_endpoint = self .azure_endpoint ,
138218 api_key = self .api_key ,
219+ timeout = self .timeout ,
139220 )
140221
141222 def validate_api_key (self ):
0 commit comments