Skip to content

Commit ff06a0c

Browse files
authored
Proposition: Add exponential backoff with jitter for retries (#1829)
1 parent 317b573 commit ff06a0c

File tree

3 files changed

+31
-13
lines changed

3 files changed

+31
-13
lines changed

src/smolagents/models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,10 @@
3535

3636
logger = logging.getLogger(__name__)
3737

38-
RETRY_WAIT = 120
38+
RETRY_WAIT = 60
3939
RETRY_MAX_ATTEMPTS = 3
40+
RETRY_EXPONENTIAL_BASE = 2
41+
RETRY_JITTER = True
4042
STRUCTURED_GENERATION_PROVIDERS = ["cerebras", "fireworks-ai"]
4143
CODEAGENT_RESPONSE_FORMAT = {
4244
"type": "json_schema",
@@ -1109,6 +1111,8 @@ def __init__(
11091111
self.retryer = Retrying(
11101112
max_attempts=RETRY_MAX_ATTEMPTS if retry else 1,
11111113
wait_seconds=RETRY_WAIT,
1114+
exponential_base=RETRY_EXPONENTIAL_BASE,
1115+
jitter=RETRY_JITTER,
11121116
retry_predicate=is_rate_limit_error,
11131117
reraise=True,
11141118
before_sleep_logger=(logger, logging.INFO),

src/smolagents/utils.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import json
2222
import keyword
2323
import os
24+
import random
2425
import re
2526
import time
2627
from functools import lru_cache
@@ -516,20 +517,25 @@ def __init__(
516517
self,
517518
max_attempts: int = 1,
518519
wait_seconds: float = 0.0,
520+
exponential_base: float = 2.0,
521+
jitter: bool = True,
519522
retry_predicate: Callable[[BaseException], bool] | None = None,
520523
reraise: bool = False,
521524
before_sleep_logger: tuple[Logger, int] | None = None,
522525
after_logger: tuple[Logger, int] | None = None,
523526
):
524527
self.max_attempts = max_attempts
525528
self.wait_seconds = wait_seconds
529+
self.exponential_base = exponential_base
530+
self.jitter = jitter
526531
self.retry_predicate = retry_predicate
527532
self.reraise = reraise
528533
self.before_sleep_logger = before_sleep_logger
529534
self.after_logger = after_logger
530535

531536
def __call__(self, fn, *args: Any, **kwargs: Any) -> Any:
532537
start_time = time.time()
538+
delay = self.wait_seconds
533539

534540
for attempt_number in range(1, self.max_attempts + 1):
535541
try:
@@ -542,7 +548,7 @@ def __call__(self, fn, *args: Any, **kwargs: Any) -> Any:
542548
fn_name = getattr(fn, "__name__", repr(fn))
543549
logger.log(
544550
log_level,
545-
f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}.",
551+
f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}/{self.max_attempts}.",
546552
)
547553

548554
return result
@@ -564,18 +570,22 @@ def __call__(self, fn, *args: Any, **kwargs: Any) -> Any:
564570
fn_name = getattr(fn, "__name__", repr(fn))
565571
logger.log(
566572
log_level,
567-
f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}.",
573+
f"Finished call to '{fn_name}' after {seconds:.3f}(s), this was attempt n°{attempt_number}/{self.max_attempts}.",
568574
)
569575

576+
# Exponential backoff with jitter
577+
# https://cookbook.openai.com/examples/how_to_handle_rate_limits#example-3-manual-backoff-implementation
578+
delay *= self.exponential_base * (1 + self.jitter * random.random())
579+
570580
# Log before sleeping
571581
if self.before_sleep_logger:
572582
logger, log_level = self.before_sleep_logger
573583
fn_name = getattr(fn, "__name__", repr(fn))
574584
logger.log(
575585
log_level,
576-
f"Retrying {fn_name} in {self.wait_seconds} seconds as it raised {e.__class__.__name__}: {e}.",
586+
f"Retrying {fn_name} in {delay} seconds as it raised {e.__class__.__name__}: {e}.",
577587
)
578588

579589
# Sleep before next attempt
580-
if self.wait_seconds > 0:
581-
time.sleep(self.wait_seconds)
590+
if delay > 0:
591+
time.sleep(delay)

tests/test_models.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,8 @@ def test_retry_on_rate_limit_error(self):
419419
mock_litellm = MagicMock()
420420

421421
with (
422-
patch("smolagents.models.RETRY_WAIT", 1),
422+
patch("smolagents.models.RETRY_WAIT", 0.1),
423+
patch("smolagents.utils.random.random", side_effect=[0.1, 0.1]),
423424
patch("smolagents.models.LiteLLMModel.create_client", return_value=mock_litellm),
424425
):
425426
model = LiteLLMModel(model_id="test-model")
@@ -438,22 +439,25 @@ def test_retry_on_rate_limit_error(self):
438439
# Create a 429 rate limit error
439440
rate_limit_error = Exception("Error code: 429 - Rate limit exceeded")
440441

441-
# Mock the litellm client to raise error first, then succeed
442-
model.client.completion.side_effect = [rate_limit_error, mock_success_response]
442+
# Mock the litellm client to raise an error twice, and then succeed
443+
model.client.completion.side_effect = [rate_limit_error, rate_limit_error, mock_success_response]
443444

444445
# Measure time to verify retry wait time
445446
start_time = time.time()
446447
result = model.generate(messages)
447448
elapsed_time = time.time() - start_time
448449

449-
# Verify that completion was called twice (once failed, once succeeded)
450-
assert model.client.completion.call_count == 2
450+
# Verify that completion was called thrice (twice failed, once succeeded)
451+
assert model.client.completion.call_count == 3
451452
assert result.content == "Success response"
452453
assert result.token_usage.input_tokens == 10
453454
assert result.token_usage.output_tokens == 20
454455

455-
# Verify that the wait time was around 1s (allow some tolerance)
456-
assert 0.9 <= elapsed_time <= 1.2
456+
# Verify that the total wait time across both retries was approximately:
457+
# 0.22s (1st retry) [0.1 * 2.0 * (1 + 1 * 0.1)]
458+
# + 0.484s (2nd retry) [0.22 * 2.0 * (1 + 1 * 0.1)]
459+
# = 0.704s (allow some tolerance)
460+
assert 0.67 <= elapsed_time <= 0.73
457461

458462
def test_passing_flatten_messages(self):
459463
model = LiteLLMModel(model_id="groq/llama-3.3-70b", flatten_messages_as_text=False)

0 commit comments

Comments
 (0)