Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/smolagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@

logger = logging.getLogger(__name__)

RETRY_WAIT = 120
RETRY_WAIT = 60
RETRY_MAX_ATTEMPTS = 3
RETRY_EXPONENTIAL_BASE = 2
RETRY_JITTER = True
STRUCTURED_GENERATION_PROVIDERS = ["cerebras", "fireworks-ai"]
CODEAGENT_RESPONSE_FORMAT = {
"type": "json_schema",
Expand Down Expand Up @@ -1102,6 +1104,8 @@ def __init__(
self.retryer = Retrying(
max_attempts=RETRY_MAX_ATTEMPTS if retry else 1,
wait_seconds=RETRY_WAIT,
exponential_base=RETRY_EXPONENTIAL_BASE,
jitter=RETRY_JITTER,
retry_predicate=is_rate_limit_error,
reraise=True,
before_sleep_logger=(logger, logging.INFO),
Expand Down
20 changes: 15 additions & 5 deletions src/smolagents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import re
import time
import random
from functools import lru_cache
from io import BytesIO
from logging import Logger
Expand Down Expand Up @@ -516,20 +517,25 @@ def __init__(
self,
max_attempts: int = 1,
wait_seconds: float = 0.0,
exponential_base: float = 2.0,
jitter: bool = True,
retry_predicate: Callable[[BaseException], bool] | None = None,
reraise: bool = False,
before_sleep_logger: tuple[Logger, int] | None = None,
after_logger: tuple[Logger, int] | None = None,
):
self.max_attempts = max_attempts
self.wait_seconds = wait_seconds
self.exponential_base = exponential_base
self.jitter = jitter
self.retry_predicate = retry_predicate
self.reraise = reraise
self.before_sleep_logger = before_sleep_logger
self.after_logger = after_logger

def __call__(self, fn, *args: Any, **kwargs: Any) -> Any:
start_time = time.time()
delay = self.wait_seconds

for attempt_number in range(1, self.max_attempts + 1):
try:
Expand All @@ -542,7 +548,7 @@ def __call__(self, fn, *args: Any, **kwargs: Any) -> Any:
fn_name = getattr(fn, "__name__", repr(fn))
logger.log(
log_level,
f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}.",
f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}/{self.max_attempts}.",
)

return result
Expand All @@ -564,18 +570,22 @@ def __call__(self, fn, *args: Any, **kwargs: Any) -> Any:
fn_name = getattr(fn, "__name__", repr(fn))
logger.log(
log_level,
f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}.",
f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}/{self.max_attempts}.",
)

# Exponential backoff with jitter
# https://cookbook.openai.com/examples/how_to_handle_rate_limits#example-3-manual-backoff-implementation
delay *= self.exponential_base * (1 + self.jitter * random.random())

# Log before sleeping
if self.before_sleep_logger:
logger, log_level = self.before_sleep_logger
fn_name = getattr(fn, "__name__", repr(fn))
logger.log(
log_level,
f"Retrying {fn_name} in {self.wait_seconds} seconds as it raised {e.__class__.__name__}: {e}.",
f"Retrying {fn_name} in {delay} seconds as it raised {e.__class__.__name__}: {e}.",
)

# Sleep before next attempt
if self.wait_seconds > 0:
time.sleep(self.wait_seconds)
if delay > 0:
time.sleep(delay)
18 changes: 11 additions & 7 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,8 @@ def test_retry_on_rate_limit_error(self):
mock_litellm = MagicMock()

with (
patch("smolagents.models.RETRY_WAIT", 1),
patch("smolagents.models.RETRY_WAIT", 0.1),
patch("smolagents.utils.random.random", side_effect=[0.1, 0.1]),
patch("smolagents.models.LiteLLMModel.create_client", return_value=mock_litellm),
):
model = LiteLLMModel(model_id="test-model")
Expand All @@ -438,22 +439,25 @@ def test_retry_on_rate_limit_error(self):
# Create a 429 rate limit error
rate_limit_error = Exception("Error code: 429 - Rate limit exceeded")

# Mock the litellm client to raise error first, then succeed
model.client.completion.side_effect = [rate_limit_error, mock_success_response]
# Mock the litellm client to raise an error twice, and then succeed
model.client.completion.side_effect = [rate_limit_error, rate_limit_error, mock_success_response]

# Measure time to verify retry wait time
start_time = time.time()
result = model.generate(messages)
elapsed_time = time.time() - start_time

# Verify that completion was called twice (once failed, once succeeded)
assert model.client.completion.call_count == 2
# Verify that completion was called three times (two failures, then one success)
assert model.client.completion.call_count == 3
assert result.content == "Success response"
assert result.token_usage.input_tokens == 10
assert result.token_usage.output_tokens == 20

# Verify that the wait time was around 1s (allow some tolerance)
assert 0.9 <= elapsed_time <= 1.2
# Verify that the total wait time was approximately:
# 0.22s (1st retry) [0.1 * 2.0 * (1 + 1 * 0.1)]
# + 0.48s (2nd retry) [0.22 * 2.0 * (1 + 1 * 0.1)]
# = 0.704s (allow some tolerance)
assert 0.67 <= elapsed_time <= 0.73

def test_passing_flatten_messages(self):
model = LiteLLMModel(model_id="groq/llama-3.3-70b", flatten_messages_as_text=False)
Expand Down
Loading