Skip to content

Commit 875c564

Browse files
authored
PYTHON-5506 Prototype adaptive token bucket retry (#2501)
Add an adaptive, token-bucket-based retry policy. Each successfully completed command deposits 0.1 token. Each retry attempt consumes 1 token, which is returned to the bucket if the retried command succeeds. A retry is permitted only if a token is available. The token bucket starts full, at its maximum of 1000 tokens.
1 parent 75eee91 commit 875c564

File tree

10 files changed

+373
-35
lines changed

10 files changed

+373
-35
lines changed

pymongo/asynchronous/collection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def __init__(
253253
unicode_decode_error_handler="replace", document_class=dict
254254
)
255255
self._timeout = database.client.options.timeout
256+
self._retry_policy = database.client._retry_policy
256257

257258
if create or kwargs:
258259
if _IS_SYNC:

pymongo/asynchronous/database.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def __init__(
136136
self._name = name
137137
self._client: AsyncMongoClient[_DocumentType] = client
138138
self._timeout = client.options.timeout
139+
self._retry_policy = client._retry_policy
139140

140141
@property
141142
def client(self) -> AsyncMongoClient[_DocumentType]:

pymongo/asynchronous/helpers.py

Lines changed: 94 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,21 @@
2121
import random
2222
import socket
2323
import sys
24-
import time
24+
import time as time # noqa: PLC0414 # needed in sync version
2525
from typing import (
2626
Any,
2727
Callable,
2828
TypeVar,
2929
cast,
3030
)
3131

32+
from pymongo import _csot
3233
from pymongo.errors import (
3334
OperationFailure,
3435
PyMongoError,
3536
)
3637
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
38+
from pymongo.lock import _async_create_lock
3739

3840
_IS_SYNC = False
3941

@@ -78,34 +80,115 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
7880
_MAX_RETRIES = 3
7981
_BACKOFF_INITIAL = 0.05
8082
_BACKOFF_MAX = 10
81-
_TIME = time
83+
# DRIVERS-3240 will determine these defaults.
84+
DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
85+
DEFAULT_RETRY_TOKEN_RETURN = 0.1
8286

8387

84-
async def _backoff(
88+
def _backoff(
8589
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
86-
) -> None:
90+
) -> float:
8791
jitter = random.random() # noqa: S311
88-
backoff = jitter * min(initial_delay * (2**attempt), max_delay)
89-
await asyncio.sleep(backoff)
92+
return jitter * min(initial_delay * (2**attempt), max_delay)
93+
94+
95+
class _TokenBucket:
96+
"""A token bucket implementation for rate limiting."""
97+
98+
def __init__(
99+
self,
100+
capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY,
101+
return_rate: float = DEFAULT_RETRY_TOKEN_RETURN,
102+
):
103+
self.lock = _async_create_lock()
104+
self.capacity = capacity
105+
# DRIVERS-3240 will determine how full the bucket should start.
106+
self.tokens = capacity
107+
self.return_rate = return_rate
108+
109+
async def consume(self) -> bool:
110+
"""Consume a token from the bucket if available."""
111+
async with self.lock:
112+
if self.tokens >= 1:
113+
self.tokens -= 1
114+
return True
115+
return False
116+
117+
async def deposit(self, retry: bool = False) -> None:
118+
"""Deposit return_rate tokens, plus one extra token after a successful retry, capped at capacity."""
119+
retry_token = 1 if retry else 0
120+
async with self.lock:
121+
self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate)
122+
123+
124+
class _RetryPolicy:
125+
"""A retry limiter that performs exponential backoff with jitter.
126+
127+
Retry attempts are limited by a token bucket to prevent overwhelming the server during
128+
a prolonged outage or high load.
129+
"""
130+
131+
def __init__(
132+
self,
133+
token_bucket: _TokenBucket,
134+
attempts: int = _MAX_RETRIES,
135+
backoff_initial: float = _BACKOFF_INITIAL,
136+
backoff_max: float = _BACKOFF_MAX,
137+
):
138+
self.token_bucket = token_bucket
139+
self.attempts = attempts
140+
self.backoff_initial = backoff_initial
141+
self.backoff_max = backoff_max
142+
143+
async def record_success(self, retry: bool) -> None:
144+
"""Record a successful operation."""
145+
await self.token_bucket.deposit(retry)
146+
147+
def backoff(self, attempt: int) -> float:
148+
"""Return the backoff duration for the given attempt."""
149+
return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)
150+
151+
async def should_retry(self, attempt: int, delay: float) -> bool:
152+
"""Return whether a retry is allowed: attempts remain, the delay fits the deadline, and a token is available."""
153+
if attempt > self.attempts:
154+
return False
155+
156+
# If the delay would exceed the deadline, bail early before consuming a token.
157+
if _csot.get_timeout():
158+
if time.monotonic() + delay > _csot.get_deadline():
159+
return False
160+
161+
# Check token bucket last since we only want to consume a token if we actually retry.
162+
if not await self.token_bucket.consume():
163+
# DRIVERS-3246 Improve diagnostics when this case happens.
164+
# We could add info to the exception and log.
165+
return False
166+
return True
90167

91168

92169
def _retry_overload(func: F) -> F:
93170
@functools.wraps(func)
94-
async def inner(*args: Any, **kwargs: Any) -> Any:
171+
async def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
172+
retry_policy = self._retry_policy
95173
attempt = 0
96174
while True:
97175
try:
98-
return await func(*args, **kwargs)
176+
res = await func(self, *args, **kwargs)
177+
await retry_policy.record_success(retry=attempt > 0)
178+
return res
99179
except PyMongoError as exc:
100180
if not exc.has_error_label("Retryable"):
101181
raise
102182
attempt += 1
103-
if attempt > _MAX_RETRIES:
183+
delay = 0
184+
if exc.has_error_label("SystemOverloaded"):
185+
delay = retry_policy.backoff(attempt)
186+
if not await retry_policy.should_retry(attempt, delay):
104187
raise
105188

106189
# Implement exponential backoff on retry.
107-
if exc.has_error_label("SystemOverloaded"):
108-
await _backoff(attempt)
190+
if delay:
191+
await asyncio.sleep(delay)
109192
continue
110193

111194
return cast(F, inner)

pymongo/asynchronous/mongo_client.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import asyncio
3636
import contextlib
3737
import os
38+
import time as time # noqa: PLC0414 # needed in sync version
3839
import warnings
3940
import weakref
4041
from collections import defaultdict
@@ -67,7 +68,11 @@
6768
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
6869
from pymongo.asynchronous.client_session import _EmptyServerSession
6970
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
70-
from pymongo.asynchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload
71+
from pymongo.asynchronous.helpers import (
72+
_retry_overload,
73+
_RetryPolicy,
74+
_TokenBucket,
75+
)
7176
from pymongo.asynchronous.settings import TopologySettings
7277
from pymongo.asynchronous.topology import Topology, _ErrorContext
7378
from pymongo.client_options import ClientOptions
@@ -774,6 +779,7 @@ def __init__(
774779
self._timeout: float | None = None
775780
self._topology_settings: TopologySettings = None # type: ignore[assignment]
776781
self._event_listeners: _EventListeners | None = None
782+
self._retry_policy = _RetryPolicy(_TokenBucket())
777783

778784
# _pool_class, _monitor_class, and _condition_class are for deep
779785
# customization of PyMongo, e.g. Motor.
@@ -2740,7 +2746,7 @@ def __init__(
27402746
self._always_retryable = False
27412747
self._multiple_retries = _csot.get_timeout() is not None
27422748
self._client = mongo_client
2743-
2749+
self._retry_policy = mongo_client._retry_policy
27442750
self._func = func
27452751
self._bulk = bulk
27462752
self._session = session
@@ -2775,7 +2781,9 @@ async def run(self) -> T:
27752781
while True:
27762782
self._check_last_error(check_csot=True)
27772783
try:
2778-
return await self._read() if self._is_read else await self._write()
2784+
res = await self._read() if self._is_read else await self._write()
2785+
await self._retry_policy.record_success(self._attempt_number > 0)
2786+
return res
27792787
except ServerSelectionTimeoutError:
27802788
# The application may think the write was never attempted
27812789
# if we raise ServerSelectionTimeoutError on the retry
@@ -2846,13 +2854,14 @@ async def run(self) -> T:
28462854

28472855
self._always_retryable = always_retryable
28482856
if always_retryable:
2849-
if self._attempt_number > _MAX_RETRIES:
2857+
delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0
2858+
if not await self._retry_policy.should_retry(self._attempt_number, delay):
28502859
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
28512860
raise self._last_error from exc
28522861
else:
28532862
raise
28542863
if overloaded:
2855-
await _backoff(self._attempt_number)
2864+
await asyncio.sleep(delay)
28562865

28572866
def _is_not_eligible_for_retry(self) -> bool:
28582867
"""Checks if the exchange is not eligible for retry"""

pymongo/synchronous/collection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ def __init__(
256256
unicode_decode_error_handler="replace", document_class=dict
257257
)
258258
self._timeout = database.client.options.timeout
259+
self._retry_policy = database.client._retry_policy
259260

260261
if create or kwargs:
261262
if _IS_SYNC:

pymongo/synchronous/database.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def __init__(
136136
self._name = name
137137
self._client: MongoClient[_DocumentType] = client
138138
self._timeout = client.options.timeout
139+
self._retry_policy = client._retry_policy
139140

140141
@property
141142
def client(self) -> MongoClient[_DocumentType]:

pymongo/synchronous/helpers.py

Lines changed: 93 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,21 @@
2121
import random
2222
import socket
2323
import sys
24-
import time
24+
import time as time # noqa: PLC0414 # needed in sync version
2525
from typing import (
2626
Any,
2727
Callable,
2828
TypeVar,
2929
cast,
3030
)
3131

32+
from pymongo import _csot
3233
from pymongo.errors import (
3334
OperationFailure,
3435
PyMongoError,
3536
)
3637
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
38+
from pymongo.lock import _create_lock
3739

3840
_IS_SYNC = True
3941

@@ -78,34 +80,115 @@ def inner(*args: Any, **kwargs: Any) -> Any:
7880
_MAX_RETRIES = 3
7981
_BACKOFF_INITIAL = 0.05
8082
_BACKOFF_MAX = 10
81-
_TIME = time
83+
# DRIVERS-3240 will determine these defaults.
84+
DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
85+
DEFAULT_RETRY_TOKEN_RETURN = 0.1
8286

8387

8488
def _backoff(
8589
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
86-
) -> None:
90+
) -> float:
8791
jitter = random.random() # noqa: S311
88-
backoff = jitter * min(initial_delay * (2**attempt), max_delay)
89-
time.sleep(backoff)
92+
return jitter * min(initial_delay * (2**attempt), max_delay)
93+
94+
95+
class _TokenBucket:
96+
"""A token bucket implementation for rate limiting."""
97+
98+
def __init__(
99+
self,
100+
capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY,
101+
return_rate: float = DEFAULT_RETRY_TOKEN_RETURN,
102+
):
103+
self.lock = _create_lock()
104+
self.capacity = capacity
105+
# DRIVERS-3240 will determine how full the bucket should start.
106+
self.tokens = capacity
107+
self.return_rate = return_rate
108+
109+
def consume(self) -> bool:
110+
"""Consume a token from the bucket if available."""
111+
with self.lock:
112+
if self.tokens >= 1:
113+
self.tokens -= 1
114+
return True
115+
return False
116+
117+
def deposit(self, retry: bool = False) -> None:
118+
"""Deposit return_rate tokens, plus one extra token after a successful retry, capped at capacity."""
119+
retry_token = 1 if retry else 0
120+
with self.lock:
121+
self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate)
122+
123+
124+
class _RetryPolicy:
125+
"""A retry limiter that performs exponential backoff with jitter.
126+
127+
Retry attempts are limited by a token bucket to prevent overwhelming the server during
128+
a prolonged outage or high load.
129+
"""
130+
131+
def __init__(
132+
self,
133+
token_bucket: _TokenBucket,
134+
attempts: int = _MAX_RETRIES,
135+
backoff_initial: float = _BACKOFF_INITIAL,
136+
backoff_max: float = _BACKOFF_MAX,
137+
):
138+
self.token_bucket = token_bucket
139+
self.attempts = attempts
140+
self.backoff_initial = backoff_initial
141+
self.backoff_max = backoff_max
142+
143+
def record_success(self, retry: bool) -> None:
144+
"""Record a successful operation."""
145+
self.token_bucket.deposit(retry)
146+
147+
def backoff(self, attempt: int) -> float:
148+
"""Return the backoff duration for the given attempt."""
149+
return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)
150+
151+
def should_retry(self, attempt: int, delay: float) -> bool:
152+
"""Return whether a retry is allowed: attempts remain, the delay fits the deadline, and a token is available."""
153+
if attempt > self.attempts:
154+
return False
155+
156+
# If the delay would exceed the deadline, bail early before consuming a token.
157+
if _csot.get_timeout():
158+
if time.monotonic() + delay > _csot.get_deadline():
159+
return False
160+
161+
# Check token bucket last since we only want to consume a token if we actually retry.
162+
if not self.token_bucket.consume():
163+
# DRIVERS-3246 Improve diagnostics when this case happens.
164+
# We could add info to the exception and log.
165+
return False
166+
return True
90167

91168

92169
def _retry_overload(func: F) -> F:
93170
@functools.wraps(func)
94-
def inner(*args: Any, **kwargs: Any) -> Any:
171+
def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
172+
retry_policy = self._retry_policy
95173
attempt = 0
96174
while True:
97175
try:
98-
return func(*args, **kwargs)
176+
res = func(self, *args, **kwargs)
177+
retry_policy.record_success(retry=attempt > 0)
178+
return res
99179
except PyMongoError as exc:
100180
if not exc.has_error_label("Retryable"):
101181
raise
102182
attempt += 1
103-
if attempt > _MAX_RETRIES:
183+
delay = 0
184+
if exc.has_error_label("SystemOverloaded"):
185+
delay = retry_policy.backoff(attempt)
186+
if not retry_policy.should_retry(attempt, delay):
104187
raise
105188

106189
# Implement exponential backoff on retry.
107-
if exc.has_error_label("SystemOverloaded"):
108-
_backoff(attempt)
190+
if delay:
191+
time.sleep(delay)
109192
continue
110193

111194
return cast(F, inner)

0 commit comments

Comments
 (0)