src/smolagents/models.py (5 additions & 1 deletion)

@@ -35,8 +35,10 @@

 logger = logging.getLogger(__name__)
 
-RETRY_WAIT = 120
+RETRY_WAIT = 60
 RETRY_MAX_ATTEMPTS = 3
+RETRY_EXPONENTIAL_BASE = 2
+RETRY_JITTER = True
 STRUCTURED_GENERATION_PROVIDERS = ["cerebras", "fireworks-ai"]
 CODEAGENT_RESPONSE_FORMAT = {
     "type": "json_schema",
@@ -1102,6 +1104,8 @@ def __init__(
         self.retryer = Retrying(
             max_attempts=RETRY_MAX_ATTEMPTS if retry else 1,
             wait_seconds=RETRY_WAIT,
+            exponential_base=RETRY_EXPONENTIAL_BASE,
+            jitter=RETRY_JITTER,
             retry_predicate=is_rate_limit_error,
             reraise=True,
             before_sleep_logger=(logger, logging.INFO),
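As a quick sanity check on the new defaults (an illustration, not part of the diff), the delay rule added in utils.py below grows each sleep geometrically and adds 0-100% random jitter, so the retry waits fall in these ranges:

# Sketch of the delay bounds implied by the new defaults, assuming the update
# rule from utils.py below: delay *= base * (1 + jitter * random.random()).
RETRY_WAIT = 60
RETRY_EXPONENTIAL_BASE = 2

low = high = RETRY_WAIT
for retry in (1, 2):  # RETRY_MAX_ATTEMPTS = 3 allows up to two retries
    low *= RETRY_EXPONENTIAL_BASE       # random.random() == 0.0
    high *= RETRY_EXPONENTIAL_BASE * 2  # random.random() approaching 1.0 (exclusive)
    print(f"sleep before retry {retry}: {low}s to just under {high}s")
# sleep before retry 1: 120s to just under 240s
# sleep before retry 2: 240s to just under 960s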
src/smolagents/utils.py (15 additions & 5 deletions)

@@ -21,6 +21,7 @@
 import json
 import keyword
 import os
+import random
 import re
 import time
 from functools import lru_cache
@@ -516,20 +517,25 @@ def __init__(
         self,
         max_attempts: int = 1,
         wait_seconds: float = 0.0,
+        exponential_base: float = 2.0,
+        jitter: bool = True,
         retry_predicate: Callable[[BaseException], bool] | None = None,
         reraise: bool = False,
         before_sleep_logger: tuple[Logger, int] | None = None,
         after_logger: tuple[Logger, int] | None = None,
     ):
         self.max_attempts = max_attempts
         self.wait_seconds = wait_seconds
+        self.exponential_base = exponential_base
+        self.jitter = jitter
         self.retry_predicate = retry_predicate
         self.reraise = reraise
         self.before_sleep_logger = before_sleep_logger
         self.after_logger = after_logger
 
     def __call__(self, fn, *args: Any, **kwargs: Any) -> Any:
         start_time = time.time()
+        delay = self.wait_seconds
 
         for attempt_number in range(1, self.max_attempts + 1):
             try:
@@ -542,7 +548,7 @@ def __call__(self, fn, *args: Any, **kwargs: Any) -> Any:
                     fn_name = getattr(fn, "__name__", repr(fn))
                     logger.log(
                         log_level,
-                        f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}.",
+                        f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}/{self.max_attempts}.",
                     )
 
                 return result
@@ -564,18 +570,22 @@ def __call__(self, fn, *args: Any, **kwargs: Any) -> Any:
                     fn_name = getattr(fn, "__name__", repr(fn))
                     logger.log(
                         log_level,
-                        f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}.",
+                        f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}/{self.max_attempts}.",
                     )
 
+                # Exponential backoff with jitter
+                # https://cookbook.openai.com/examples/how_to_handle_rate_limits#example-3-manual-backoff-implementation
+                delay *= self.exponential_base * (1 + self.jitter * random.random())
+
                 # Log before sleeping
                 if self.before_sleep_logger:
                     logger, log_level = self.before_sleep_logger
                     fn_name = getattr(fn, "__name__", repr(fn))
                     logger.log(
                         log_level,
-                        f"Retrying {fn_name} in {self.wait_seconds} seconds as it raised {e.__class__.__name__}: {e}.",
+                        f"Retrying {fn_name} in {delay} seconds as it raised {e.__class__.__name__}: {e}.",
                     )
 
                 # Sleep before next attempt
-                if self.wait_seconds > 0:
-                    time.sleep(self.wait_seconds)
+                if delay > 0:
+                    time.sleep(delay)
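Read in isolation, the backoff added above boils down to the following standalone sketch (simplified names, logging and retry predicate omitted; not the actual Retrying class):

import random
import time


def call_with_backoff(fn, max_attempts=3, wait_seconds=60.0, exponential_base=2.0, jitter=True):
    """Retry fn with exponential backoff and jitter, mirroring the diff above."""
    delay = wait_seconds
    for attempt_number in range(1, max_attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt_number == max_attempts:
                raise
            # Grow the delay geometrically and add up to 100% random jitter
            # so concurrent clients do not retry in lockstep.
            delay *= exponential_base * (1 + jitter * random.random())
            time.sleep(delay)

The jitter term is also why the test below patches random.random: without a fixed value, the total sleep time would not be deterministic.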
tests/test_models.py (11 additions & 7 deletions)

@@ -419,7 +419,8 @@ def test_retry_on_rate_limit_error(self):
         mock_litellm = MagicMock()
 
         with (
-            patch("smolagents.models.RETRY_WAIT", 1),
+            patch("smolagents.models.RETRY_WAIT", 0.1),
+            patch("smolagents.utils.random.random", side_effect=[0.1, 0.1]),
             patch("smolagents.models.LiteLLMModel.create_client", return_value=mock_litellm),
         ):
             model = LiteLLMModel(model_id="test-model")
@@ -438,22 +439,25 @@
             # Create a 429 rate limit error
             rate_limit_error = Exception("Error code: 429 - Rate limit exceeded")
 
-            # Mock the litellm client to raise error first, then succeed
-            model.client.completion.side_effect = [rate_limit_error, mock_success_response]
+            # Mock the litellm client to raise an error twice, and then succeed
+            model.client.completion.side_effect = [rate_limit_error, rate_limit_error, mock_success_response]
 
             # Measure time to verify retry wait time
             start_time = time.time()
             result = model.generate(messages)
             elapsed_time = time.time() - start_time
 
-            # Verify that completion was called twice (once failed, once succeeded)
-            assert model.client.completion.call_count == 2
+            # Verify that completion was called thrice (twice failed, once succeeded)
+            assert model.client.completion.call_count == 3
             assert result.content == "Success response"
             assert result.token_usage.input_tokens == 10
             assert result.token_usage.output_tokens == 20
 
-            # Verify that the wait time was around 1s (allow some tolerance)
-            assert 0.9 <= elapsed_time <= 1.2
+            # Verify that the wait time was around
+            # 0.22s (1st retry) [0.1 * 2.0 * (1 + 1 * 0.1)]
+            # + 0.48s (2nd retry) [0.22 * 2.0 * (1 + 1 * 0.1)]
+            # = 0.704s (allow some tolerance)
+            assert 0.67 <= elapsed_time <= 0.73
 
     def test_passing_flatten_messages(self):
         model = LiteLLMModel(model_id="groq/llama-3.3-70b", flatten_messages_as_text=False)
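The new timing assertion follows directly from the patched values; a minimal recomputation of the expected total sleep, assuming random.random() is pinned to 0.1 as in the side_effect above:

# wait_seconds comes from the patched RETRY_WAIT, base from RETRY_EXPONENTIAL_BASE.
wait_seconds, base, r = 0.1, 2.0, 0.1
first = wait_seconds * base * (1 + r)   # 0.22 s slept before the 2nd attempt
second = first * base * (1 + r)         # ~0.484 s slept before the 3rd attempt
total = first + second                  # ~0.704 s of total sleep
assert 0.67 <= total <= 0.73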