Skip to content

Commit 7ffc2f0

Browse files
authored
fix: add retry logic for OpenAI and Azure OpenAI (#315)
1 parent de8a1f0 commit 7ffc2f0

File tree

2 files changed

+84
-4
lines changed

2 files changed

+84
-4
lines changed

src/ragas/llms/openai.py

Lines changed: 84 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,104 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
import logging
35
import os
46
import typing as t
57
from abc import abstractmethod
68
from dataclasses import dataclass, field
79

10+
import openai
811
from langchain.adapters.openai import convert_message_to_dict
12+
from langchain.callbacks.manager import (
13+
AsyncCallbackManagerForLLMRun,
14+
CallbackManagerForLLMRun,
15+
)
916
from langchain.schema import Generation, LLMResult
1017
from openai import AsyncAzureOpenAI, AsyncClient, AsyncOpenAI
18+
from tenacity import (
19+
RetryCallState,
20+
before_sleep_log,
21+
retry,
22+
retry_base,
23+
retry_if_exception_type,
24+
stop_after_attempt,
25+
wait_exponential,
26+
)
1127

1228
from ragas.async_utils import run_async_tasks
1329
from ragas.exceptions import AzureOpenAIKeyNotFound, OpenAIKeyNotFound
1430
from ragas.llms.base import RagasLLM
1531
from ragas.llms.langchain import _compute_token_usage_langchain
16-
from ragas.utils import NO_KEY
32+
from ragas.utils import NO_KEY, get_debug_mode
1733

1834
if t.TYPE_CHECKING:
1935
from langchain.callbacks.base import Callbacks
2036
from langchain.prompts import ChatPromptTemplate
2137

38+
logger = logging.getLogger(__name__)
39+
40+
errors = [
41+
openai.APITimeoutError,
42+
openai.APIConnectionError,
43+
openai.RateLimitError,
44+
openai.APIConnectionError,
45+
openai.InternalServerError,
46+
]
47+
48+
49+
def create_base_retry_decorator(
50+
error_types: t.List[t.Type[BaseException]],
51+
max_retries: int = 1,
52+
run_manager: t.Optional[
53+
t.Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
54+
] = None,
55+
) -> t.Callable[[t.Any], t.Any]:
56+
"""Create a retry decorator for a given LLM and provided list of error types."""
57+
58+
log_level = logging.WARNING if get_debug_mode() else logging.DEBUG
59+
_logging = before_sleep_log(logger, log_level)
60+
61+
def _before_sleep(retry_state: RetryCallState) -> None:
62+
_logging(retry_state)
63+
if run_manager:
64+
if isinstance(run_manager, AsyncCallbackManagerForLLMRun):
65+
coro = run_manager.on_retry(retry_state)
66+
try:
67+
loop = asyncio.get_event_loop()
68+
if loop.is_running():
69+
loop.create_task(coro)
70+
else:
71+
asyncio.run(coro)
72+
except Exception as e:
73+
logger.error(f"Error in on_retry: {e}")
74+
else:
75+
run_manager.on_retry(retry_state)
76+
return None
77+
78+
min_seconds = 4
79+
max_seconds = 10
80+
# Wait 2^x * 1 second between each retry starting with
81+
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
82+
retry_instance: "retry_base" = retry_if_exception_type(error_types[0])
83+
for error in error_types[1:]:
84+
retry_instance = retry_instance | retry_if_exception_type(error)
85+
return retry(
86+
reraise=True,
87+
stop=stop_after_attempt(max_retries),
88+
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
89+
retry=retry_instance,
90+
before_sleep=_before_sleep,
91+
)
92+
93+
94+
retry_decorator = create_base_retry_decorator(errors, max_retries=4)
95+
2296

2397
class OpenAIBase(RagasLLM):
24-
def __init__(self, model: str, _api_key_env_var: str) -> None:
98+
def __init__(self, model: str, _api_key_env_var: str, timeout: int = 60) -> None:
2599
self.model = model
26100
self._api_key_env_var = _api_key_env_var
101+
self.timeout = timeout
27102

28103
# api key
29104
key_from_env = os.getenv(self._api_key_env_var, NO_KEY)
@@ -83,6 +158,7 @@ def generate(
83158
llm_output = _compute_token_usage_langchain(llm_results)
84159
return LLMResult(generations=generations, llm_output=llm_output)
85160

161+
@retry_decorator
86162
async def agenerate(
87163
self,
88164
prompt: ChatPromptTemplate,
@@ -112,9 +188,13 @@ def __post_init__(self):
112188
self._client_init()
113189

114190
def _client_init(self):
115-
self._client = AsyncOpenAI(api_key=self.api_key)
191+
self._client = AsyncOpenAI(api_key=self.api_key, timeout=self.timeout)
116192

117193
def validate_api_key(self):
194+
# before validating, check if the api key is already set
195+
api_key = os.getenv(self._api_key_env_var, NO_KEY)
196+
if api_key != NO_KEY:
197+
self._client.api_key = api_key
118198
if self.llm.api_key == NO_KEY:
119199
raise OpenAIKeyNotFound
120200

@@ -136,6 +216,7 @@ def _client_init(self):
136216
api_version=self.api_version,
137217
azure_endpoint=self.azure_endpoint,
138218
api_key=self.api_key,
219+
timeout=self.timeout,
139220
)
140221

141222
def validate_api_key(self):

src/ragas/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
@lru_cache(maxsize=1)
1313
def get_debug_mode() -> bool:
1414
if os.environ.get(DEBUG_ENV_VAR, str(False)).lower() == "true":
15-
logging.basicConfig(level=logging.DEBUG)
1615
return True
1716
else:
1817
return False

0 commit comments

Comments
 (0)