Skip to content

Commit e831685

Browse files
committed
PYTHON-5506 Prototype adaptive token bucket retry
1 parent 75eee91 commit e831685

File tree

10 files changed

+234
-20
lines changed

10 files changed

+234
-20
lines changed

pymongo/asynchronous/collection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,7 @@ def __init__(
253253
unicode_decode_error_handler="replace", document_class=dict
254254
)
255255
self._timeout = database.client.options.timeout
256+
self._retry_policy = database.client._retry_policy
256257

257258
if create or kwargs:
258259
if _IS_SYNC:

pymongo/asynchronous/database.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def __init__(
136136
self._name = name
137137
self._client: AsyncMongoClient[_DocumentType] = client
138138
self._timeout = client.options.timeout
139+
self._retry_policy = client._retry_policy
139140

140141
@property
141142
def client(self) -> AsyncMongoClient[_DocumentType]:

pymongo/asynchronous/helpers.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
PyMongoError,
3535
)
3636
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
37+
from pymongo.lock import _async_create_lock
3738

3839
_IS_SYNC = False
3940

@@ -78,7 +79,10 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
7879
_MAX_RETRIES = 3
7980
_BACKOFF_INITIAL = 0.05
8081
_BACKOFF_MAX = 10
81-
_TIME = time
82+
# DRIVERS-3240 will determine these defaults.
83+
DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
84+
DEFAULT_RETRY_TOKEN_RETURN = 0.1
85+
_TIME = time # Added so synchro script doesn't remove the time import.
8286

8387

8488
async def _backoff(
@@ -89,23 +93,95 @@ async def _backoff(
8993
await asyncio.sleep(backoff)
9094

9195

96+
class _TokenBucket:
    """Token bucket used to rate limit retry attempts.

    The bucket starts full. ``consume`` removes whole tokens and ``deposit``
    credits tokens back, with the balance capped at ``capacity``.
    """

    def __init__(
        self,
        capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY,
        return_rate: float = DEFAULT_RETRY_TOKEN_RETURN,
    ):
        self.capacity = capacity
        self.return_rate = return_rate
        # DRIVERS-3240 will determine how full the bucket should start.
        self.tokens = capacity
        # Guards every read-modify-write of ``tokens``.
        self.lock = _async_create_lock()

    async def consume(self) -> bool:
        """Remove one token when at least one whole token is available.

        :return: True if a token was taken, False if the balance was below one
            (in which case the balance is left unchanged).
        """
        async with self.lock:
            has_token = self.tokens >= 1
            if has_token:
                self.tokens -= 1
            return has_token

    async def deposit(self, retry: bool = False) -> None:
        """Credit the bucket without exceeding ``capacity``.

        Every call adds ``return_rate``; when ``retry`` is true one extra whole
        token is added as well.
        """
        bonus = 1 if retry else 0
        async with self.lock:
            refilled = self.tokens + bonus + self.return_rate
            self.tokens = min(self.capacity, refilled)
123+
124+
125+
class _RetryPolicy:
    """A retry limiter that performs exponential backoff with jitter.

    Retry attempts are limited by a token bucket to prevent overwhelming the server during
    a prolonged outage or high load.
    """

    def __init__(
        self,
        token_bucket: _TokenBucket,
        attempts: int = _MAX_RETRIES,
        backoff_initial: float = _BACKOFF_INITIAL,
        backoff_max: float = _BACKOFF_MAX,
    ):
        """Create a retry policy.

        :param token_bucket: Bucket that must yield a token for each retry.
        :param attempts: Maximum number of retry attempts allowed per operation.
        :param backoff_initial: Forwarded to ``_backoff`` as the initial backoff.
        :param backoff_max: Forwarded to ``_backoff`` as the maximum backoff.
        """
        self.token_bucket = token_bucket
        self.attempts = attempts
        self.backoff_initial = backoff_initial
        self.backoff_max = backoff_max

    async def record_success(self, retry: bool) -> None:
        """Record a successful operation by crediting the token bucket.

        :param retry: True if the operation succeeded on a retry rather than on
            its first attempt; forwarded to the bucket's ``deposit``.
        """
        await self.token_bucket.deposit(retry)

    async def backoff(self, attempt: int) -> None:
        """Wait for the backoff period that applies to the given attempt.

        This does not return a duration; it delegates the wait to ``_backoff``.

        :param attempt: 1-based retry attempt number. It is shifted down by one
            (and floored at 0) before being passed to ``_backoff``.
        """
        await _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)

    async def should_retry(self, attempt: int) -> bool:
        """Return whether there is budget left to perform retry number ``attempt``.

        A retry is allowed only if ``attempt`` does not exceed ``self.attempts``
        and a token can be taken from the bucket. A token is consumed only when
        this method returns True.
        """
        # TODO: Check CSOT deadline here.
        if attempt > self.attempts:
            return False
        # Check token bucket last since we only want to consume a token if we actually retry.
        if not await self.token_bucket.consume():
            # DRIVERS-3246 Improve diagnostics when this case happens.
            # We could add info to the exception and log.
            return False
        return True
163+
164+
92165
def _retry_overload(func: F) -> F:
    """Decorate a coroutine method so "Retryable" failures are retried.

    The decorated method's owner must expose a ``_retry_policy`` attribute.
    Each failure carrying the "Retryable" error label is retried as long as
    ``_retry_policy.should_retry`` allows it; any other ``PyMongoError`` is
    re-raised unchanged. Failures that also carry "SystemOverloaded" wait via
    ``_retry_policy.backoff`` before the next attempt.
    """

    @functools.wraps(func)
    async def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
        retry_policy = self._retry_policy
        attempt = 0
        while True:
            # Only the operation itself is guarded: an error raised by the
            # success bookkeeping below must never cause an operation that
            # already succeeded to be executed again.
            try:
                res = await func(self, *args, **kwargs)
            except PyMongoError as exc:
                if not exc.has_error_label("Retryable"):
                    raise
                attempt += 1
                if not await retry_policy.should_retry(attempt):
                    raise

                # Implement exponential backoff on retry.
                if exc.has_error_label("SystemOverloaded"):
                    await retry_policy.backoff(attempt)
            else:
                await retry_policy.record_success(retry=attempt > 0)
                return res

    return cast(F, inner)

pymongo/asynchronous/mongo_client.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,11 @@
6767
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
6868
from pymongo.asynchronous.client_session import _EmptyServerSession
6969
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
70-
from pymongo.asynchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload
70+
from pymongo.asynchronous.helpers import (
71+
_retry_overload,
72+
_RetryPolicy,
73+
_TokenBucket,
74+
)
7175
from pymongo.asynchronous.settings import TopologySettings
7276
from pymongo.asynchronous.topology import Topology, _ErrorContext
7377
from pymongo.client_options import ClientOptions
@@ -774,6 +778,7 @@ def __init__(
774778
self._timeout: float | None = None
775779
self._topology_settings: TopologySettings = None # type: ignore[assignment]
776780
self._event_listeners: _EventListeners | None = None
781+
self._retry_policy = _RetryPolicy(_TokenBucket())
777782

778783
# _pool_class, _monitor_class, and _condition_class are for deep
779784
# customization of PyMongo, e.g. Motor.
@@ -2740,7 +2745,7 @@ def __init__(
27402745
self._always_retryable = False
27412746
self._multiple_retries = _csot.get_timeout() is not None
27422747
self._client = mongo_client
2743-
2748+
self._retry_policy = mongo_client._retry_policy
27442749
self._func = func
27452750
self._bulk = bulk
27462751
self._session = session
@@ -2775,7 +2780,9 @@ async def run(self) -> T:
27752780
while True:
27762781
self._check_last_error(check_csot=True)
27772782
try:
2778-
return await self._read() if self._is_read else await self._write()
2783+
res = await self._read() if self._is_read else await self._write()
2784+
await self._retry_policy.record_success(self._attempt_number > 0)
2785+
return res
27792786
except ServerSelectionTimeoutError:
27802787
# The application may think the write was never attempted
27812788
# if we raise ServerSelectionTimeoutError on the retry
@@ -2846,13 +2853,13 @@ async def run(self) -> T:
28462853

28472854
self._always_retryable = always_retryable
28482855
if always_retryable:
2849-
if self._attempt_number > _MAX_RETRIES:
2856+
if not await self._retry_policy.should_retry(self._attempt_number):
28502857
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
28512858
raise self._last_error from exc
28522859
else:
28532860
raise
28542861
if overloaded:
2855-
await _backoff(self._attempt_number)
2862+
await self._retry_policy.backoff(self._attempt_number)
28562863

28572864
def _is_not_eligible_for_retry(self) -> bool:
28582865
"""Checks if the exchange is not eligible for retry"""

pymongo/synchronous/collection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ def __init__(
256256
unicode_decode_error_handler="replace", document_class=dict
257257
)
258258
self._timeout = database.client.options.timeout
259+
self._retry_policy = database.client._retry_policy
259260

260261
if create or kwargs:
261262
if _IS_SYNC:

pymongo/synchronous/database.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def __init__(
136136
self._name = name
137137
self._client: MongoClient[_DocumentType] = client
138138
self._timeout = client.options.timeout
139+
self._retry_policy = client._retry_policy
139140

140141
@property
141142
def client(self) -> MongoClient[_DocumentType]:

pymongo/synchronous/helpers.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
PyMongoError,
3535
)
3636
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
37+
from pymongo.lock import _create_lock
3738

3839
_IS_SYNC = True
3940

@@ -78,7 +79,10 @@ def inner(*args: Any, **kwargs: Any) -> Any:
7879
_MAX_RETRIES = 3
7980
_BACKOFF_INITIAL = 0.05
8081
_BACKOFF_MAX = 10
81-
_TIME = time
82+
# DRIVERS-3240 will determine these defaults.
83+
DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
84+
DEFAULT_RETRY_TOKEN_RETURN = 0.1
85+
_TIME = time # Added so synchro script doesn't remove the time import.
8286

8387

8488
def _backoff(
@@ -89,23 +93,95 @@ def _backoff(
8993
time.sleep(backoff)
9094

9195

96+
class _TokenBucket:
    """Token bucket used to rate limit retry attempts.

    The bucket starts full. ``consume`` removes whole tokens and ``deposit``
    credits tokens back, with the balance capped at ``capacity``.
    """

    def __init__(
        self,
        capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY,
        return_rate: float = DEFAULT_RETRY_TOKEN_RETURN,
    ):
        self.capacity = capacity
        self.return_rate = return_rate
        # DRIVERS-3240 will determine how full the bucket should start.
        self.tokens = capacity
        # Guards every read-modify-write of ``tokens``.
        self.lock = _create_lock()

    def consume(self) -> bool:
        """Remove one token when at least one whole token is available.

        :return: True if a token was taken, False if the balance was below one
            (in which case the balance is left unchanged).
        """
        with self.lock:
            has_token = self.tokens >= 1
            if has_token:
                self.tokens -= 1
            return has_token

    def deposit(self, retry: bool = False) -> None:
        """Credit the bucket without exceeding ``capacity``.

        Every call adds ``return_rate``; when ``retry`` is true one extra whole
        token is added as well.
        """
        bonus = 1 if retry else 0
        with self.lock:
            refilled = self.tokens + bonus + self.return_rate
            self.tokens = min(self.capacity, refilled)
123+
124+
125+
class _RetryPolicy:
    """A retry limiter that performs exponential backoff with jitter.

    Retry attempts are limited by a token bucket to prevent overwhelming the server during
    a prolonged outage or high load.
    """

    def __init__(
        self,
        token_bucket: _TokenBucket,
        attempts: int = _MAX_RETRIES,
        backoff_initial: float = _BACKOFF_INITIAL,
        backoff_max: float = _BACKOFF_MAX,
    ):
        """Create a retry policy.

        :param token_bucket: Bucket that must yield a token for each retry.
        :param attempts: Maximum number of retry attempts allowed per operation.
        :param backoff_initial: Forwarded to ``_backoff`` as the initial backoff.
        :param backoff_max: Forwarded to ``_backoff`` as the maximum backoff.
        """
        self.token_bucket = token_bucket
        self.attempts = attempts
        self.backoff_initial = backoff_initial
        self.backoff_max = backoff_max

    def record_success(self, retry: bool) -> None:
        """Record a successful operation by crediting the token bucket.

        :param retry: True if the operation succeeded on a retry rather than on
            its first attempt; forwarded to the bucket's ``deposit``.
        """
        self.token_bucket.deposit(retry)

    def backoff(self, attempt: int) -> None:
        """Wait for the backoff period that applies to the given attempt.

        This does not return a duration; it delegates the wait to ``_backoff``.

        :param attempt: 1-based retry attempt number. It is shifted down by one
            (and floored at 0) before being passed to ``_backoff``.
        """
        _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)

    def should_retry(self, attempt: int) -> bool:
        """Return whether there is budget left to perform retry number ``attempt``.

        A retry is allowed only if ``attempt`` does not exceed ``self.attempts``
        and a token can be taken from the bucket. A token is consumed only when
        this method returns True.
        """
        # TODO: Check CSOT deadline here.
        if attempt > self.attempts:
            return False
        # Check token bucket last since we only want to consume a token if we actually retry.
        if not self.token_bucket.consume():
            # DRIVERS-3246 Improve diagnostics when this case happens.
            # We could add info to the exception and log.
            return False
        return True
163+
164+
92165
def _retry_overload(func: F) -> F:
    """Decorate a method so "Retryable" failures are retried.

    The decorated method's owner must expose a ``_retry_policy`` attribute.
    Each failure carrying the "Retryable" error label is retried as long as
    ``_retry_policy.should_retry`` allows it; any other ``PyMongoError`` is
    re-raised unchanged. Failures that also carry "SystemOverloaded" wait via
    ``_retry_policy.backoff`` before the next attempt.
    """

    @functools.wraps(func)
    def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
        retry_policy = self._retry_policy
        attempt = 0
        while True:
            # Only the operation itself is guarded: an error raised by the
            # success bookkeeping below must never cause an operation that
            # already succeeded to be executed again.
            try:
                res = func(self, *args, **kwargs)
            except PyMongoError as exc:
                if not exc.has_error_label("Retryable"):
                    raise
                attempt += 1
                if not retry_policy.should_retry(attempt):
                    raise

                # Implement exponential backoff on retry.
                if exc.has_error_label("SystemOverloaded"):
                    retry_policy.backoff(attempt)
            else:
                retry_policy.record_success(retry=attempt > 0)
                return res

    return cast(F, inner)

pymongo/synchronous/mongo_client.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,11 @@
110110
from pymongo.synchronous.client_bulk import _ClientBulk
111111
from pymongo.synchronous.client_session import _EmptyServerSession
112112
from pymongo.synchronous.command_cursor import CommandCursor
113-
from pymongo.synchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload
113+
from pymongo.synchronous.helpers import (
114+
_retry_overload,
115+
_RetryPolicy,
116+
_TokenBucket,
117+
)
114118
from pymongo.synchronous.settings import TopologySettings
115119
from pymongo.synchronous.topology import Topology, _ErrorContext
116120
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
@@ -774,6 +778,7 @@ def __init__(
774778
self._timeout: float | None = None
775779
self._topology_settings: TopologySettings = None # type: ignore[assignment]
776780
self._event_listeners: _EventListeners | None = None
781+
self._retry_policy = _RetryPolicy(_TokenBucket())
777782

778783
# _pool_class, _monitor_class, and _condition_class are for deep
779784
# customization of PyMongo, e.g. Motor.
@@ -2730,7 +2735,7 @@ def __init__(
27302735
self._always_retryable = False
27312736
self._multiple_retries = _csot.get_timeout() is not None
27322737
self._client = mongo_client
2733-
2738+
self._retry_policy = mongo_client._retry_policy
27342739
self._func = func
27352740
self._bulk = bulk
27362741
self._session = session
@@ -2765,7 +2770,9 @@ def run(self) -> T:
27652770
while True:
27662771
self._check_last_error(check_csot=True)
27672772
try:
2768-
return self._read() if self._is_read else self._write()
2773+
res = self._read() if self._is_read else self._write()
2774+
self._retry_policy.record_success(self._attempt_number > 0)
2775+
return res
27692776
except ServerSelectionTimeoutError:
27702777
# The application may think the write was never attempted
27712778
# if we raise ServerSelectionTimeoutError on the retry
@@ -2836,13 +2843,13 @@ def run(self) -> T:
28362843

28372844
self._always_retryable = always_retryable
28382845
if always_retryable:
2839-
if self._attempt_number > _MAX_RETRIES:
2846+
if not self._retry_policy.should_retry(self._attempt_number):
28402847
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
28412848
raise self._last_error from exc
28422849
else:
28432850
raise
28442851
if overloaded:
2845-
_backoff(self._attempt_number)
2852+
self._retry_policy.backoff(self._attempt_number)
28462853

28472854
def _is_not_eligible_for_retry(self) -> bool:
28482855
"""Checks if the exchange is not eligible for retry"""

0 commit comments

Comments
 (0)