|
4 | 4 | from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union
|
5 | 5 |
|
6 | 6 | from pydantic import BaseModel, Extra, Field, root_validator
|
| 7 | +from tenacity import ( |
| 8 | + after_log, |
| 9 | + retry, |
| 10 | + retry_if_exception_type, |
| 11 | + stop_after_attempt, |
| 12 | + wait_exponential, |
| 13 | +) |
7 | 14 |
|
8 | 15 | from langchain.llms.base import BaseLLM
|
9 | 16 | from langchain.schema import Generation, LLMResult
|
@@ -56,6 +63,8 @@ class BaseOpenAI(BaseLLM, BaseModel):
|
56 | 63 | """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
|
57 | 64 | logit_bias: Optional[Dict[str, float]] = Field(default_factory=dict)
|
58 | 65 | """Adjust the probability of specific tokens being generated."""
|
| 66 | + max_retries: int = 6 |
| 67 | + """Maximum number of retries to make when generating.""" |
59 | 68 |
|
60 | 69 | class Config:
|
61 | 70 | """Configuration for this pydantic object."""
|
@@ -115,6 +124,32 @@ def _default_params(self) -> Dict[str, Any]:
|
115 | 124 | }
|
116 | 125 | return {**normal_params, **self.model_kwargs}
|
117 | 126 |
|
def completion_with_retry(self, **kwargs: Any) -> Any:
    """Call the completion client, retrying on selected OpenAI errors.

    All keyword arguments are forwarded unchanged to ``self.client.create``
    and its result is returned as-is. Retrying is delegated to tenacity:
    the stop condition is ``stop_after_attempt(self.max_retries)`` and,
    because ``reraise=True``, the last underlying exception is re-raised
    once the stop condition is reached.
    """
    import openai

    # Exponential backoff (2^x * 1 second) clamped to [4, 10] seconds:
    # waits start at 4 seconds and are capped at 10 seconds.
    backoff = wait_exponential(multiplier=1, min=4, max=10)

    # Only these OpenAI error types trigger another attempt.
    retryable_errors = (
        openai.error.Timeout,
        openai.error.APIError,
        openai.error.APIConnectionError,
        openai.error.RateLimitError,
    )

    with_retries = retry(
        reraise=True,
        stop=stop_after_attempt(self.max_retries),
        wait=backoff,
        retry=retry_if_exception_type(retryable_errors),
        after=after_log(logger, logging.DEBUG),
    )

    # NOTE(review): the inner function keeps its original name because
    # after_log presumably reports the wrapped callable's name in its
    # DEBUG message - confirm before renaming.
    def _completion_with_retry(**call_kwargs: Any) -> Any:
        return self.client.create(**call_kwargs)

    return with_retries(_completion_with_retry)(**kwargs)
| 152 | + |
118 | 153 | def _generate(
|
119 | 154 | self, prompts: List[str], stop: Optional[List[str]] = None
|
120 | 155 | ) -> LLMResult:
|
@@ -155,7 +190,7 @@ def _generate(
|
155 | 190 | # Includes prompt, completion, and total tokens used.
|
156 | 191 | _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
|
157 | 192 | for _prompts in sub_prompts:
|
158 |
| - response = self.client.create(prompt=_prompts, **params) |
| 193 | + response = self.completion_with_retry(prompt=_prompts, **params) |
159 | 194 | choices.extend(response["choices"])
|
160 | 195 | _keys_to_use = _keys.intersection(response["usage"])
|
161 | 196 | for _key in _keys_to_use:
|
|
0 commit comments