Commit be1d71f

Proposition: retry on rate limit errors (#1801)
1 parent 1904ddd commit be1d71f

File tree: 8 files changed, +198 -47 lines

README.md

Lines changed: 12 additions & 12 deletions

@@ -95,7 +95,7 @@ model = InferenceClientModel(
 from smolagents import LiteLLMModel

 model = LiteLLMModel(
-    model_id="anthropic/claude-3-5-sonnet-latest",
+    model_id="anthropic/claude-4-sonnet-latest",
     temperature=0.2,
     api_key=os.environ["ANTHROPIC_API_KEY"]
 )
@@ -106,9 +106,9 @@ model = LiteLLMModel(

 ```py
 import os
-from smolagents import OpenAIServerModel
+from smolagents import OpenAIModel

-model = OpenAIServerModel(
+model = OpenAIModel(
     model_id="deepseek-ai/DeepSeek-R1",
     api_base="https://api.together.xyz/v1/", # Leave this blank to query OpenAI servers.
     api_key=os.environ["TOGETHER_API_KEY"], # Switch to the API key for the server you're targeting.
@@ -120,9 +120,9 @@ model = OpenAIServerModel(

 ```py
 import os
-from smolagents import OpenAIServerModel
+from smolagents import OpenAIModel

-model = OpenAIServerModel(
+model = OpenAIModel(
     model_id="openai/gpt-4o",
     api_base="https://openrouter.ai/api/v1", # Leave this blank to query OpenAI servers.
     api_key=os.environ["OPENROUTER_API_KEY"], # Switch to the API key for the server you're targeting.
@@ -137,7 +137,7 @@ model = OpenAIServerModel(
 from smolagents import TransformersModel

 model = TransformersModel(
-    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
+    model_id="Qwen/Qwen3-4B-Instruct-2507",
     max_new_tokens=4096,
     device_map="auto"
 )
@@ -148,9 +148,9 @@ model = TransformersModel(

 ```py
 import os
-from smolagents import AzureOpenAIServerModel
+from smolagents import AzureOpenAIModel

-model = AzureOpenAIServerModel(
+model = AzureOpenAIModel(
     model_id = os.environ.get("AZURE_OPENAI_MODEL"),
     azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
     api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
@@ -163,9 +163,9 @@ model = AzureOpenAIServerModel(

 ```py
 import os
-from smolagents import AmazonBedrockServerModel
+from smolagents import AmazonBedrockModel

-model = AmazonBedrockServerModel(
+model = AmazonBedrockModel(
     model_id = os.environ.get("AMAZON_BEDROCK_MODEL_ID")
 )
 ```
@@ -178,14 +178,14 @@ You can run agents from CLI using two commands: `smolagent` and `webagent`.
 `smolagent` is a generalist command to run a multi-step `CodeAgent` that can be equipped with various tools.

 ```bash
-smolagent "Plan a trip to Tokyo, Kyoto and Osaka between Mar 28 and Apr 7." --model-type "InferenceClientModel" --model-id "Qwen/Qwen2.5-Coder-32B-Instruct" --imports "pandas numpy" --tools "web_search"
+smolagent "Plan a trip to Tokyo, Kyoto and Osaka between Mar 28 and Apr 7." --model-type "InferenceClientModel" --model-id "Qwen/Qwen3-Next-80B-A3B-Instruct" --imports pandas numpy --tools web_search
 ```

 Meanwhile `webagent` is a specific web-browsing agent using [helium](https://github.com/mherrmann/helium) (read more [here](https://github.com/huggingface/smolagents/blob/main/src/smolagents/vision_web_browser.py)).

 For instance:
 ```bash
-webagent "go to xyz.com/men, get to sale section, click the first clothing item you see. Get the product details, and the price, return them. note that I'm shopping from France" --model-type "LiteLLMModel" --model-id "gpt-4o"
+webagent "go to xyz.com/men, get to sale section, click the first clothing item you see. Get the product details, and the price, return them. note that I'm shopping from France" --model-type "LiteLLMModel" --model-id "gpt-5"
 ```

 ## How do Code agents work?
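
The renames visible in this README (`OpenAIServerModel` to `OpenAIModel`, `AzureOpenAIServerModel` to `AzureOpenAIModel`, `AmazonBedrockServerModel` to `AmazonBedrockModel`) are backward compatible: the old names are kept as aliases in src/smolagents/models.py further down in this diff. A minimal sketch of what that means for existing code:

```py
# Old imports keep working after this commit: each *ServerModel name is now
# an alias of the renamed class (see the models.py hunks below).
from smolagents import OpenAIModel, OpenAIServerModel

assert OpenAIServerModel is OpenAIModel
```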

examples/multiple_tools.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 # model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")

 # For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
-# model = LiteLLMModel(model_id="gpt-4o")
+# model = LiteLLMModel(model_id="gpt-5")


 @tool

examples/rag.py

Lines changed: 1 addition & 1 deletion

@@ -58,7 +58,7 @@ def forward(self, query: str) -> str:
 retriever_tool = RetrieverTool(docs_processed)
 agent = CodeAgent(
     tools=[retriever_tool],
-    model=InferenceClientModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct"),
+    model=InferenceClientModel(model_id="Qwen/Qwen3-Next-80B-A3B-Instruct"),
     max_steps=4,
     verbosity_level=2,
     stream_outputs=True,

examples/rag_using_chromadb.py

Lines changed: 4 additions & 4 deletions

@@ -98,14 +98,14 @@ def forward(self, query: str) -> str:
 # Choose which LLM engine to use!

 # from smolagents import InferenceClientModel
-# model = InferenceClientModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
+# model = InferenceClientModel(model_id="Qwen/Qwen3-Next-80B-A3B-Instruct")

 # from smolagents import TransformersModel
-# model = TransformersModel(model_id="meta-llama/Llama-3.2-2B-Instruct")
+# model = TransformersModel(model_id="Qwen/Qwen3-4B-Instruct-2507")

-# For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620' and also change 'os.environ.get("ANTHROPIC_API_KEY")'
+# For anthropic: change model_id below to 'anthropic/claude-4-sonnet-latest' and also change 'os.environ.get("ANTHROPIC_API_KEY")'
 model = LiteLLMModel(
-    model_id="groq/llama-3.3-70b-versatile",
+    model_id="groq/openai/gpt-oss-120b",
     api_key=os.environ.get("GROQ_API_KEY"),
 )

pyproject.toml

Lines changed: 3 additions & 2 deletions

@@ -16,8 +16,9 @@ dependencies = [
     "requests>=2.32.3",
     "rich>=13.9.4",
     "jinja2>=3.1.4",
-    "pillow>=10.0.1", # Security fix for CVE-2023-4863: https://pillow.readthedocs.io/en/stable/releasenotes/10.0.1.html
-    "python-dotenv"
+    "pillow>=10.0.1",
+    # Security fix for CVE-2023-4863: https://pillow.readthedocs.io/en/stable/releasenotes/10.0.1.html
+    "python-dotenv",
 ]

 [project.optional-dependencies]

src/smolagents/models.py

Lines changed: 59 additions & 26 deletions

@@ -26,7 +26,7 @@

 from .monitoring import TokenUsage
 from .tools import Tool
-from .utils import RateLimiter, _is_package_available, encode_image_base64, make_image_url, parse_json_blob
+from .utils import RateLimiter, Retrying, _is_package_available, encode_image_base64, make_image_url, parse_json_blob


 if TYPE_CHECKING:
@@ -35,6 +35,8 @@

 logger = logging.getLogger(__name__)

+RETRY_WAIT = 120
+RETRY_MAX_ATTEMPTS = 3
 STRUCTURED_GENERATION_PROVIDERS = ["cerebras", "fireworks-ai"]
 CODEAGENT_RESPONSE_FORMAT = {
     "type": "json_schema",
@@ -1078,6 +1080,8 @@ class ApiModel(Model):
             Pre-configured API client instance. If not provided, a default client will be created. Defaults to None.
         requests_per_minute (`float`, **optional**):
             Rate limit in requests per minute.
+        retry (`bool`, **optional**):
+            Whether to retry on rate limit errors, up to RETRY_MAX_ATTEMPTS times. Defaults to True.
         **kwargs:
             Additional keyword arguments to forward to the underlying model completion call.
     """
@@ -1088,12 +1092,21 @@ def __init__(
         custom_role_conversions: dict[str, str] | None = None,
         client: Any | None = None,
         requests_per_minute: float | None = None,
+        retry: bool = True,
         **kwargs,
     ):
         super().__init__(model_id=model_id, **kwargs)
         self.custom_role_conversions = custom_role_conversions or {}
         self.client = client or self.create_client()
         self.rate_limiter = RateLimiter(requests_per_minute)
+        self.retryer = Retrying(
+            max_attempts=RETRY_MAX_ATTEMPTS if retry else 1,
+            wait_seconds=RETRY_WAIT,
+            retry_predicate=is_rate_limit_error,
+            reraise=True,
+            before_sleep_logger=(logger, logging.INFO),
+            after_logger=(logger, logging.INFO),
+        )

     def create_client(self):
         """Create the API client for the specific service."""
@@ -1104,6 +1117,17 @@ def _apply_rate_limit(self):
         self.rate_limiter.throttle()


+def is_rate_limit_error(exception: BaseException) -> bool:
+    """Check if the exception is a rate limit error."""
+    error_str = str(exception).lower()
+    return (
+        "429" in error_str
+        or "rate limit" in error_str
+        or "too many requests" in error_str
+        or "rate_limit" in error_str
+    )
+
+
 class LiteLLMModel(ApiModel):
     """Model to use [LiteLLM Python SDK](https://docs.litellm.ai/docs/#litellm-python-sdk) to access hundreds of LLMs.

11861210
**kwargs,
11871211
)
11881212
self._apply_rate_limit()
1189-
response = self.client.completion(**completion_kwargs)
1213+
response = self.retryer(self.client.completion, **completion_kwargs)
1214+
11901215
if not response.choices:
11911216
raise RuntimeError(
11921217
f"Unexpected API response: model '{self.model_id}' returned no choices. "
@@ -1228,7 +1253,9 @@ def generate_stream(
12281253
**kwargs,
12291254
)
12301255
self._apply_rate_limit()
1231-
for event in self.client.completion(**completion_kwargs, stream=True, stream_options={"include_usage": True}):
1256+
for event in self.retryer(
1257+
self.client.completion, **completion_kwargs, stream=True, stream_options={"include_usage": True}
1258+
):
12321259
if getattr(event, "usage", None):
12331260
yield ChatMessageStreamDelta(
12341261
content="",
@@ -1398,8 +1425,8 @@ class InferenceClientModel(ApiModel):
     Example:
     ```python
     >>> engine = InferenceClientModel(
-    ...     model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
-    ...     provider="nebius",
+    ...     model_id="Qwen/Qwen3-Next-80B-A3B-Thinking",
+    ...     provider="hyperbolic",
     ...     token="your_hf_token_here",
     ...     max_tokens=5000,
     ... )
@@ -1412,7 +1439,7 @@ class InferenceClientModel(ApiModel):

     def __init__(
         self,
-        model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct",
+        model_id: str = "Qwen/Qwen3-Next-80B-A3B-Instruct",
         provider: str | None = None,
         token: str | None = None,
         timeout: int = 120,
@@ -1472,7 +1499,7 @@ def generate(
             **kwargs,
         )
         self._apply_rate_limit()
-        response = self.client.chat_completion(**completion_kwargs)
+        response = self.retryer(self.client.chat_completion, **completion_kwargs)
         content = response.choices[0].message.content
         if stop_sequences is not None and not self.supports_stop_parameter:
             content = remove_content_after_stop_sequences(content, stop_sequences)
@@ -1506,8 +1533,11 @@ def generate_stream(
             **kwargs,
         )
         self._apply_rate_limit()
-        for event in self.client.chat.completions.create(
-            **completion_kwargs, stream=True, stream_options={"include_usage": True}
+        for event in self.retryer(
+            self.client.chat.completions.create,
+            **completion_kwargs,
+            stream=True,
+            stream_options={"include_usage": True},
         ):
             if getattr(event, "usage", None):
                 yield ChatMessageStreamDelta(
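Note that in the streaming paths the retryer wraps only the call that opens the stream; once iteration has begun, an error raised mid-stream propagates without a retry.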
@@ -1539,12 +1569,12 @@ def generate_stream(
                 raise ValueError(f"No content or tool calls in event: {event}")


-class OpenAIServerModel(ApiModel):
+class OpenAIModel(ApiModel):
     """This model connects to an OpenAI-compatible API server.

     Parameters:
         model_id (`str`):
-            The model identifier to use on the server (e.g. "gpt-3.5-turbo").
+            The model identifier to use on the server (e.g. "gpt-5").
         api_base (`str`, *optional*):
             The base URL of the OpenAI-compatible API server.
         api_key (`str`, *optional*):
@@ -1595,7 +1625,7 @@ def create_client(self):
             import openai
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError(
-                "Please install 'openai' extra to use OpenAIServerModel: `pip install 'smolagents[openai]'`"
+                "Please install 'openai' extra to use OpenAIModel: `pip install 'smolagents[openai]'`"
             ) from e

         return openai.OpenAI(**self.client_kwargs)
@@ -1619,8 +1649,11 @@ def generate_stream(
             **kwargs,
         )
         self._apply_rate_limit()
-        for event in self.client.chat.completions.create(
-            **completion_kwargs, stream=True, stream_options={"include_usage": True}
+        for event in self.retryer(
+            self.client.chat.completions.create,
+            **completion_kwargs,
+            stream=True,
+            stream_options={"include_usage": True},
         ):
             if event.usage:
                 yield ChatMessageStreamDelta(
@@ -1670,7 +1703,7 @@ def generate(
             **kwargs,
         )
         self._apply_rate_limit()
-        response = self.client.chat.completions.create(**completion_kwargs)
+        response = self.retryer(self.client.chat.completions.create, **completion_kwargs)
         content = response.choices[0].message.content
         if stop_sequences is not None and not self.supports_stop_parameter:
             content = remove_content_after_stop_sequences(content, stop_sequences)
@@ -1686,10 +1719,10 @@ def generate(
         )


-OpenAIModel = OpenAIServerModel
+OpenAIServerModel = OpenAIModel


-class AzureOpenAIServerModel(OpenAIServerModel):
+class AzureOpenAIModel(OpenAIModel):
     """This model connects to an Azure OpenAI deployment.

     Parameters:
@@ -1740,16 +1773,16 @@ def create_client(self):
             import openai
         except ModuleNotFoundError as e:
             raise ModuleNotFoundError(
-                "Please install 'openai' extra to use AzureOpenAIServerModel: `pip install 'smolagents[openai]'`"
+                "Please install 'openai' extra to use AzureOpenAIModel: `pip install 'smolagents[openai]'`"
             ) from e

         return openai.AzureOpenAI(**self.client_kwargs)


-AzureOpenAIModel = AzureOpenAIServerModel
+AzureOpenAIServerModel = AzureOpenAIModel


-class AmazonBedrockServerModel(ApiModel):
+class AmazonBedrockModel(ApiModel):
     """
     A model class for interacting with Amazon Bedrock Server models through the Bedrock API.

@@ -1789,7 +1822,7 @@ class AmazonBedrockServerModel(ApiModel):
     Examples:
         Creating a model instance with default settings:
         ```python
-        >>> bedrock_model = AmazonBedrockServerModel(
+        >>> bedrock_model = AmazonBedrockModel(
         ...     model_id='us.amazon.nova-pro-v1:0'
         ... )
         ```
@@ -1798,15 +1831,15 @@ class AmazonBedrockServerModel(ApiModel):
         ```python
         >>> import boto3
         >>> client = boto3.client('bedrock-runtime', region_name='us-west-2')
-        >>> bedrock_model = AmazonBedrockServerModel(
+        >>> bedrock_model = AmazonBedrockModel(
         ...     model_id='us.amazon.nova-pro-v1:0',
         ...     client=client
         ... )
         ```

         Creating a model instance with client_kwargs for internal client creation:
         ```python
-        >>> bedrock_model = AmazonBedrockServerModel(
+        >>> bedrock_model = AmazonBedrockModel(
         ...     model_id='us.amazon.nova-pro-v1:0',
         ...     client_kwargs={'region_name': 'us-west-2', 'endpoint_url': 'https://custom-endpoint.com'}
         ... )
@@ -1823,7 +1856,7 @@ class AmazonBedrockServerModel(ApiModel):
         ...     "guardrailVersion": 'v1'
         ...     },
         ... }
-        >>> bedrock_model = AmazonBedrockServerModel(
+        >>> bedrock_model = AmazonBedrockModel(
         ...     model_id='anthropic.claude-3-haiku-20240307-v1:0',
         ...     **additional_api_config
         ... )
@@ -1929,7 +1962,7 @@ def generate(
         )
         self._apply_rate_limit()
         # self.client is created in ApiModel class
-        response = self.client.converse(**completion_kwargs)
+        response = self.retryer(self.client.converse, **completion_kwargs)

         # Get content blocks with "text" key: in case thinking blocks are present, discard them
         message_content_blocks_with_text = [
@@ -1953,7 +1986,7 @@ def generate(
         )


-AmazonBedrockModel = AmazonBedrockServerModel
+AmazonBedrockServerModel = AmazonBedrockModel

 __all__ = [
     "REMOVE_PARAMETER",
