Skip to content

Commit 933441c

Browse files
authored
Add retry to OpenAI llm (#849)
Add the ability to retry when certain exceptions are raised by `openai.Completions.create`. Test plan: ran all OpenAI integration tests.
1 parent 4a8f5cd commit 933441c

File tree

3 files changed

+59
-3
lines changed

3 files changed

+59
-3
lines changed

langchain/llms/openai.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44
from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
55

66
from pydantic import BaseModel, Extra, Field, root_validator
7+
from tenacity import (
8+
after_log,
9+
retry,
10+
retry_if_exception_type,
11+
stop_after_attempt,
12+
wait_exponential,
13+
)
714

815
from langchain.llms.base import BaseLLM
916
from langchain.schema import Generation, LLMResult
@@ -56,6 +63,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
5663
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
5764
logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
5865
"""Adjust the probability of specific tokens being generated."""
66+
max_retries: int = 6
67+
"""Maximum number of retries to make when generating."""
5968

6069
class Config:
6170
"""Configuration for this pydantic object."""
@@ -115,6 +124,32 @@ def _default_params(self) -> Dict[str, Any]:
115124
}
116125
return {**normal_params, **self.model_kwargs}
117126

127+
def completion_with_retry(self, **kwargs: Any) -> Any:
    """Call ``self.client.create``, retrying transient OpenAI failures.

    Retries up to ``self.max_retries`` attempts on timeout, API, connection,
    and rate-limit errors, waiting exponentially between attempts (2^x
    seconds, clamped to the 4-10 second range). Other exception types
    propagate immediately; after the final attempt the last error is
    re-raised (``reraise=True``).
    """
    import openai

    # Only these transient error types are worth retrying; anything else
    # (e.g. an invalid-request error) fails fast.
    retryable = (
        retry_if_exception_type(openai.error.Timeout)
        | retry_if_exception_type(openai.error.APIError)
        | retry_if_exception_type(openai.error.APIConnectionError)
        | retry_if_exception_type(openai.error.RateLimitError)
    )
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards.
    backoff = wait_exponential(multiplier=1, min=4, max=10)
    retrying = retry(
        reraise=True,
        stop=stop_after_attempt(self.max_retries),
        wait=backoff,
        retry=retryable,
        after=after_log(logger, logging.DEBUG),
    )

    def _create(**inner_kwargs: Any) -> Any:
        # The actual completion call that gets wrapped by the retry policy.
        return self.client.create(**inner_kwargs)

    return retrying(_create)(**kwargs)
152+
118153
def _generate(
119154
self, prompts: List[str], stop: Optional[List[str]] = None
120155
) -> LLMResult:
@@ -155,7 +190,7 @@ def _generate(
155190
# Includes prompt, completion, and total tokens used.
156191
_keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
157192
for _prompts in sub_prompts:
158-
response = self.client.create(prompt=_prompts, **params)
193+
response = self.completion_with_retry(prompt=_prompts, **params)
159194
choices.extend(response["choices"])
160195
_keys_to_use = _keys.intersection(response["usage"])
161196
for _key in _keys_to_use:

poetry.lock

Lines changed: 22 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ wolframalpha = {version = "5.0.0", optional = true}
3636
qdrant-client = {version = "^0.11.7", optional = true}
3737
dataclasses-json = "^0.5.7"
3838
tensorflow-text = {version = "^2.11.0", optional = true, python = "^3.10, <3.12"}
39+
tenacity = "^8.1.0"
3940

4041
[tool.poetry.group.docs.dependencies]
4142
autodoc_pydantic = "^1.8.0"

0 commit comments

Comments
 (0)