diff --git a/src/smolagents/models.py b/src/smolagents/models.py
index df9ebd2de..e9895622a 100644
--- a/src/smolagents/models.py
+++ b/src/smolagents/models.py
@@ -35,8 +35,10 @@
 logger = logging.getLogger(__name__)
 
-RETRY_WAIT = 120
+RETRY_WAIT = 60
 RETRY_MAX_ATTEMPTS = 3
+RETRY_EXPONENTIAL_BASE = 2
+RETRY_JITTER = True
 STRUCTURED_GENERATION_PROVIDERS = ["cerebras", "fireworks-ai"]
 CODEAGENT_RESPONSE_FORMAT = {
     "type": "json_schema",
@@ -1102,6 +1104,8 @@ def __init__(
         self.retryer = Retrying(
             max_attempts=RETRY_MAX_ATTEMPTS if retry else 1,
             wait_seconds=RETRY_WAIT,
+            exponential_base=RETRY_EXPONENTIAL_BASE,
+            jitter=RETRY_JITTER,
             retry_predicate=is_rate_limit_error,
             reraise=True,
             before_sleep_logger=(logger, logging.INFO),
diff --git a/src/smolagents/utils.py b/src/smolagents/utils.py
index 9c0e51c8e..8a475757c 100644
--- a/src/smolagents/utils.py
+++ b/src/smolagents/utils.py
@@ -21,6 +21,7 @@
 import json
 import keyword
 import os
+import random
 import re
 import time
 from functools import lru_cache
@@ -516,6 +517,8 @@ def __init__(
         self,
         max_attempts: int = 1,
         wait_seconds: float = 0.0,
+        exponential_base: float = 2.0,
+        jitter: bool = True,
         retry_predicate: Callable[[BaseException], bool] | None = None,
         reraise: bool = False,
         before_sleep_logger: tuple[Logger, int] | None = None,
@@ -523,6 +526,8 @@
     ):
         self.max_attempts = max_attempts
         self.wait_seconds = wait_seconds
+        self.exponential_base = exponential_base
+        self.jitter = jitter
         self.retry_predicate = retry_predicate
         self.reraise = reraise
         self.before_sleep_logger = before_sleep_logger
@@ -530,6 +535,7 @@ def __init__(
 
     def __call__(self, fn, *args: Any, **kwargs: Any) -> Any:
         start_time = time.time()
+        delay = self.wait_seconds
 
         for attempt_number in range(1, self.max_attempts + 1):
             try:
@@ -542,7 +548,7 @@ def __call__(self, fn, *args: Any, **kwargs: Any) -> Any:
                     fn_name = getattr(fn, "__name__", repr(fn))
                     logger.log(
                         log_level,
-                        f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}.",
+                        f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}/{self.max_attempts}.",
                     )
                 return result
@@ -564,18 +570,22 @@ def __call__(self, fn, *args: Any, **kwargs: Any) -> Any:
                     fn_name = getattr(fn, "__name__", repr(fn))
                     logger.log(
                         log_level,
-                        f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}.",
+                        f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}/{self.max_attempts}.",
                     )
 
+                # Exponential backoff with jitter
+                # https://cookbook.openai.com/examples/how_to_handle_rate_limits#example-3-manual-backoff-implementation
+                delay *= self.exponential_base * (1 + self.jitter * random.random())
+
                 # Log before sleeping
                 if self.before_sleep_logger:
                     logger, log_level = self.before_sleep_logger
                     fn_name = getattr(fn, "__name__", repr(fn))
                     logger.log(
                         log_level,
-                        f"Retrying {fn_name} in {self.wait_seconds} seconds as it raised {e.__class__.__name__}: {e}.",
+                        f"Retrying {fn_name} in {delay} seconds as it raised {e.__class__.__name__}: {e}.",
                     )
 
                 # Sleep before next attempt
-                if self.wait_seconds > 0:
-                    time.sleep(self.wait_seconds)
+                if delay > 0:
+                    time.sleep(delay)
diff --git a/tests/test_models.py b/tests/test_models.py
index 456299e3a..6859f4b3f 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -419,7 +419,8 @@ def test_retry_on_rate_limit_error(self):
         mock_litellm = MagicMock()
 
         with (
-            patch("smolagents.models.RETRY_WAIT", 1),
+            patch("smolagents.models.RETRY_WAIT", 0.1),
+            patch("smolagents.utils.random.random", side_effect=[0.1, 0.1]),
            patch("smolagents.models.LiteLLMModel.create_client", return_value=mock_litellm),
         ):
             model = LiteLLMModel(model_id="test-model")
@@ -438,22 +439,25 @@
             # Create a 429 rate limit error
             rate_limit_error = Exception("Error code: 429 - Rate limit exceeded")
 
-            # Mock the litellm client to raise error first, then succeed
-            model.client.completion.side_effect = [rate_limit_error, mock_success_response]
+            # Mock the litellm client to raise an error twice, and then succeed
+            model.client.completion.side_effect = [rate_limit_error, rate_limit_error, mock_success_response]
 
             # Measure time to verify retry wait time
             start_time = time.time()
             result = model.generate(messages)
             elapsed_time = time.time() - start_time
 
-            # Verify that completion was called twice (once failed, once succeeded)
-            assert model.client.completion.call_count == 2
+            # Verify that completion was called thrice (twice failed, once succeeded)
+            assert model.client.completion.call_count == 3
             assert result.content == "Success response"
             assert result.token_usage.input_tokens == 10
             assert result.token_usage.output_tokens == 20
 
-            # Verify that the wait time was around 1s (allow some tolerance)
-            assert 0.9 <= elapsed_time <= 1.2
+            # Verify that the wait time was around
+            # 0.22s (1st retry) [0.1 * 2.0 * (1 + 1 * 0.1)]
+            # + 0.48s (2nd retry) [0.22 * 2.0 * (1 + 1 * 0.1)]
+            # = 0.704s (allow some tolerance)
+            assert 0.67 <= elapsed_time <= 0.73
 
     def test_passing_flatten_messages(self):
         model = LiteLLMModel(model_id="groq/llama-3.3-70b", flatten_messages_as_text=False)
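
A quick sanity check of the backoff arithmetic, mirroring the updated test expectations (a standalone sketch, not part of the patch; the 0.1 values stand in for the patched RETRY_WAIT and the mocked random.random()):

    # Reproduce the delay schedule computed in Retrying.__call__.
    delay = 0.1                 # wait_seconds, as patched in the test
    base, jitter = 2.0, True    # RETRY_EXPONENTIAL_BASE, RETRY_JITTER
    total = 0.0
    for attempt in (1, 2):      # two failed attempts before the success
        delay *= base * (1 + jitter * 0.1)   # random.random() mocked to 0.1
        total += delay
        print(f"retry {attempt}: sleep {delay:.3f}s")
    print(f"total: {total:.3f}s")
    # retry 1: sleep 0.220s
    # retry 2: sleep 0.484s
    # total: 0.704s -> inside the asserted 0.67 <= elapsed_time <= 0.73 window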