Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pymongo/asynchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def __init__(
unicode_decode_error_handler="replace", document_class=dict
)
self._timeout = database.client.options.timeout
self._retry_policy = database.client._retry_policy

if create or kwargs:
if _IS_SYNC:
Expand Down
1 change: 1 addition & 0 deletions pymongo/asynchronous/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(
self._name = name
self._client: AsyncMongoClient[_DocumentType] = client
self._timeout = client.options.timeout
self._retry_policy = client._retry_policy

@property
def client(self) -> AsyncMongoClient[_DocumentType]:
Expand Down
105 changes: 94 additions & 11 deletions pymongo/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,21 @@
import random
import socket
import sys
import time
import time as time # noqa: PLC0414 # needed in sync version
from typing import (
Any,
Callable,
TypeVar,
cast,
)

from pymongo import _csot
from pymongo.errors import (
OperationFailure,
PyMongoError,
)
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
from pymongo.lock import _async_create_lock

_IS_SYNC = False

Expand Down Expand Up @@ -78,34 +80,115 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
_MAX_RETRIES = 3
_BACKOFF_INITIAL = 0.05
_BACKOFF_MAX = 10
_TIME = time
# DRIVERS-3240 will determine these defaults.
DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
DEFAULT_RETRY_TOKEN_RETURN = 0.1


async def _backoff(
def _backoff(
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
) -> None:
) -> float:
jitter = random.random() # noqa: S311
backoff = jitter * min(initial_delay * (2**attempt), max_delay)
await asyncio.sleep(backoff)
return jitter * min(initial_delay * (2**attempt), max_delay)


class _TokenBucket:
    """A token bucket implementation for rate limiting."""

    def __init__(
        self,
        capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY,
        return_rate: float = DEFAULT_RETRY_TOKEN_RETURN,
    ):
        self.lock = _async_create_lock()
        self.capacity = capacity
        # DRIVERS-3240 will determine how full the bucket should start.
        self.tokens = capacity
        self.return_rate = return_rate

    async def consume(self) -> bool:
        """Take one token out of the bucket; return False when none remain."""
        async with self.lock:
            if self.tokens < 1:
                return False
            self.tokens -= 1
            return True

    async def deposit(self, retry: bool = False) -> None:
        """Credit the bucket, never exceeding its capacity.

        The standard ``return_rate`` is always credited; one extra full token
        is credited when the completed operation was a retry.
        """
        credit = self.return_rate + (1 if retry else 0)
        async with self.lock:
            self.tokens = min(self.capacity, self.tokens + credit)


class _RetryPolicy:
    """A retry limiter that performs exponential backoff with jitter.

    Retry attempts are limited by a token bucket to prevent overwhelming the server during
    a prolonged outage or high load.

    :param token_bucket: the shared budget of retry tokens.
    :param attempts: maximum number of retry attempts allowed.
    :param backoff_initial: base delay in seconds for the first backoff.
    :param backoff_max: upper bound in seconds on any single backoff delay.
    """

    def __init__(
        self,
        token_bucket: _TokenBucket,
        attempts: int = _MAX_RETRIES,
        backoff_initial: float = _BACKOFF_INITIAL,
        backoff_max: float = _BACKOFF_MAX,
    ):
        self.token_bucket = token_bucket
        self.attempts = attempts
        self.backoff_initial = backoff_initial
        self.backoff_max = backoff_max

    async def record_success(self, retry: bool) -> None:
        """Record a successful operation by returning budget to the token bucket.

        :param retry: True when the success came on a retry attempt, in which
          case a full retry token is deposited in addition to the return rate.
        """
        await self.token_bucket.deposit(retry)

    def backoff(self, attempt: int) -> float:
        """Return the jittered backoff duration in seconds for the given attempt.

        ``attempt`` is 1-based; it is clamped so the first retry uses the
        initial delay.
        """
        return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)

    async def should_retry(self, attempt: int, delay: float) -> bool:
        """Return True if there is budget to retry after backing off for ``delay`` seconds."""
        if attempt > self.attempts:
            return False

        # If the delay would exceed the deadline, bail early before consuming a token.
        if _csot.get_timeout():
            if time.monotonic() + delay > _csot.get_deadline():
                return False

        # Check token bucket last since we only want to consume a token if we actually retry.
        if not await self.token_bucket.consume():
            # DRIVERS-3246 Improve diagnostics when this case happens.
            # We could add info to the exception and log.
            return False
        return True


def _retry_overload(func: F) -> F:
@functools.wraps(func)
async def inner(*args: Any, **kwargs: Any) -> Any:
async def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
retry_policy = self._retry_policy
attempt = 0
while True:
try:
return await func(*args, **kwargs)
res = await func(self, *args, **kwargs)
await retry_policy.record_success(retry=attempt > 0)
return res
except PyMongoError as exc:
if not exc.has_error_label("Retryable"):
raise
attempt += 1
if attempt > _MAX_RETRIES:
delay = 0
if exc.has_error_label("SystemOverloaded"):
delay = retry_policy.backoff(attempt)
if not await retry_policy.should_retry(attempt, delay):
raise

# Implement exponential backoff on retry.
if exc.has_error_label("SystemOverloaded"):
await _backoff(attempt)
if delay:
await asyncio.sleep(delay)
continue

return cast(F, inner)
Expand Down
19 changes: 14 additions & 5 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import asyncio
import contextlib
import os
import time as time # noqa: PLC0414 # needed in sync version
import warnings
import weakref
from collections import defaultdict
Expand Down Expand Up @@ -67,7 +68,11 @@
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload
from pymongo.asynchronous.helpers import (
_retry_overload,
_RetryPolicy,
_TokenBucket,
)
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
from pymongo.client_options import ClientOptions
Expand Down Expand Up @@ -774,6 +779,7 @@ def __init__(
self._timeout: float | None = None
self._topology_settings: TopologySettings = None # type: ignore[assignment]
self._event_listeners: _EventListeners | None = None
self._retry_policy = _RetryPolicy(_TokenBucket())

# _pool_class, _monitor_class, and _condition_class are for deep
# customization of PyMongo, e.g. Motor.
Expand Down Expand Up @@ -2740,7 +2746,7 @@ def __init__(
self._always_retryable = False
self._multiple_retries = _csot.get_timeout() is not None
self._client = mongo_client

self._retry_policy = mongo_client._retry_policy
self._func = func
self._bulk = bulk
self._session = session
Expand Down Expand Up @@ -2775,7 +2781,9 @@ async def run(self) -> T:
while True:
self._check_last_error(check_csot=True)
try:
return await self._read() if self._is_read else await self._write()
res = await self._read() if self._is_read else await self._write()
await self._retry_policy.record_success(self._attempt_number > 0)
return res
except ServerSelectionTimeoutError:
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
Expand Down Expand Up @@ -2846,13 +2854,14 @@ async def run(self) -> T:

self._always_retryable = always_retryable
if always_retryable:
if self._attempt_number > _MAX_RETRIES:
delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0
if not await self._retry_policy.should_retry(self._attempt_number, delay):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
if overloaded:
await _backoff(self._attempt_number)
await asyncio.sleep(delay)

def _is_not_eligible_for_retry(self) -> bool:
"""Checks if the exchange is not eligible for retry"""
Expand Down
1 change: 1 addition & 0 deletions pymongo/synchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def __init__(
unicode_decode_error_handler="replace", document_class=dict
)
self._timeout = database.client.options.timeout
self._retry_policy = database.client._retry_policy

if create or kwargs:
if _IS_SYNC:
Expand Down
1 change: 1 addition & 0 deletions pymongo/synchronous/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(
self._name = name
self._client: MongoClient[_DocumentType] = client
self._timeout = client.options.timeout
self._retry_policy = client._retry_policy

@property
def client(self) -> MongoClient[_DocumentType]:
Expand Down
103 changes: 93 additions & 10 deletions pymongo/synchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,21 @@
import random
import socket
import sys
import time
import time as time # noqa: PLC0414 # needed in sync version
from typing import (
Any,
Callable,
TypeVar,
cast,
)

from pymongo import _csot
from pymongo.errors import (
OperationFailure,
PyMongoError,
)
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
from pymongo.lock import _create_lock

_IS_SYNC = True

Expand Down Expand Up @@ -78,34 +80,115 @@ def inner(*args: Any, **kwargs: Any) -> Any:
_MAX_RETRIES = 3
_BACKOFF_INITIAL = 0.05
_BACKOFF_MAX = 10
_TIME = time
# DRIVERS-3240 will determine these defaults.
DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
DEFAULT_RETRY_TOKEN_RETURN = 0.1


def _backoff(
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
) -> None:
) -> float:
jitter = random.random() # noqa: S311
backoff = jitter * min(initial_delay * (2**attempt), max_delay)
time.sleep(backoff)
return jitter * min(initial_delay * (2**attempt), max_delay)


class _TokenBucket:
    """A token bucket implementation for rate limiting."""

    def __init__(
        self,
        capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY,
        return_rate: float = DEFAULT_RETRY_TOKEN_RETURN,
    ):
        self.lock = _create_lock()
        self.capacity = capacity
        # DRIVERS-3240 will determine how full the bucket should start.
        self.tokens = capacity
        self.return_rate = return_rate

    def consume(self) -> bool:
        """Take one token out of the bucket; return False when none remain."""
        with self.lock:
            if self.tokens < 1:
                return False
            self.tokens -= 1
            return True

    def deposit(self, retry: bool = False) -> None:
        """Credit the bucket, never exceeding its capacity.

        The standard ``return_rate`` is always credited; one extra full token
        is credited when the completed operation was a retry.
        """
        credit = self.return_rate + (1 if retry else 0)
        with self.lock:
            self.tokens = min(self.capacity, self.tokens + credit)


class _RetryPolicy:
    """A retry limiter that performs exponential backoff with jitter.

    Retry attempts are limited by a token bucket to prevent overwhelming the server during
    a prolonged outage or high load.

    :param token_bucket: the shared budget of retry tokens.
    :param attempts: maximum number of retry attempts allowed.
    :param backoff_initial: base delay in seconds for the first backoff.
    :param backoff_max: upper bound in seconds on any single backoff delay.
    """

    def __init__(
        self,
        token_bucket: _TokenBucket,
        attempts: int = _MAX_RETRIES,
        backoff_initial: float = _BACKOFF_INITIAL,
        backoff_max: float = _BACKOFF_MAX,
    ):
        self.token_bucket = token_bucket
        self.attempts = attempts
        self.backoff_initial = backoff_initial
        self.backoff_max = backoff_max

    def record_success(self, retry: bool) -> None:
        """Record a successful operation by returning budget to the token bucket.

        :param retry: True when the success came on a retry attempt, in which
          case a full retry token is deposited in addition to the return rate.
        """
        self.token_bucket.deposit(retry)

    def backoff(self, attempt: int) -> float:
        """Return the jittered backoff duration in seconds for the given attempt.

        ``attempt`` is 1-based; it is clamped so the first retry uses the
        initial delay.
        """
        return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)

    def should_retry(self, attempt: int, delay: float) -> bool:
        """Return True if there is budget to retry after backing off for ``delay`` seconds."""
        if attempt > self.attempts:
            return False

        # If the delay would exceed the deadline, bail early before consuming a token.
        if _csot.get_timeout():
            if time.monotonic() + delay > _csot.get_deadline():
                return False

        # Check token bucket last since we only want to consume a token if we actually retry.
        if not self.token_bucket.consume():
            # DRIVERS-3246 Improve diagnostics when this case happens.
            # We could add info to the exception and log.
            return False
        return True


def _retry_overload(func: F) -> F:
@functools.wraps(func)
def inner(*args: Any, **kwargs: Any) -> Any:
def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
retry_policy = self._retry_policy
attempt = 0
while True:
try:
return func(*args, **kwargs)
res = func(self, *args, **kwargs)
retry_policy.record_success(retry=attempt > 0)
return res
except PyMongoError as exc:
if not exc.has_error_label("Retryable"):
raise
attempt += 1
if attempt > _MAX_RETRIES:
delay = 0
if exc.has_error_label("SystemOverloaded"):
delay = retry_policy.backoff(attempt)
if not retry_policy.should_retry(attempt, delay):
raise

# Implement exponential backoff on retry.
if exc.has_error_label("SystemOverloaded"):
_backoff(attempt)
if delay:
time.sleep(delay)
continue

return cast(F, inner)
Expand Down
Loading
Loading