Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pymongo/asynchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def __init__(
unicode_decode_error_handler="replace", document_class=dict
)
self._timeout = database.client.options.timeout
self._retry_policy = database.client._retry_policy

if create or kwargs:
if _IS_SYNC:
Expand Down
1 change: 1 addition & 0 deletions pymongo/asynchronous/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(
self._name = name
self._client: AsyncMongoClient[_DocumentType] = client
self._timeout = client.options.timeout
self._retry_policy = client._retry_policy

@property
def client(self) -> AsyncMongoClient[_DocumentType]:
Expand Down
86 changes: 81 additions & 5 deletions pymongo/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
PyMongoError,
)
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
from pymongo.lock import _async_create_lock

_IS_SYNC = False

Expand Down Expand Up @@ -78,7 +79,10 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
_MAX_RETRIES = 3
_BACKOFF_INITIAL = 0.05
_BACKOFF_MAX = 10
_TIME = time
# DRIVERS-3240 will determine these defaults.
DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
DEFAULT_RETRY_TOKEN_RETURN = 0.1
_TIME = time # Added so synchro script doesn't remove the time import.


async def _backoff(
Expand All @@ -89,23 +93,95 @@ async def _backoff(
await asyncio.sleep(backoff)


class _TokenBucket:
"""A token bucket implementation for rate limiting."""

def __init__(
self,
capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY,
return_rate: float = DEFAULT_RETRY_TOKEN_RETURN,
):
self.lock = _async_create_lock()
self.capacity = capacity
# DRIVERS-3240 will determine how full the bucket should start.
self.tokens = capacity
self.return_rate = return_rate

async def consume(self) -> bool:
"""Consume a token from the bucket if available."""
async with self.lock:
if self.tokens >= 1:
self.tokens -= 1
return True
return False

async def deposit(self, retry: bool = False) -> None:
"""Deposit a token back into the bucket."""
retry_token = 1 if retry else 0
async with self.lock:
self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate)


class _RetryPolicy:
"""A retry limiter that performs exponential backoff with jitter.

Retry attempts are limited by a token bucket to prevent overwhelming the server during
a prolonged outage or high load.
"""

def __init__(
self,
token_bucket: _TokenBucket,
attempts: int = _MAX_RETRIES,
backoff_initial: float = _BACKOFF_INITIAL,
backoff_max: float = _BACKOFF_MAX,
):
self.token_bucket = token_bucket
self.attempts = attempts
self.backoff_initial = backoff_initial
self.backoff_max = backoff_max

async def record_success(self, retry: bool) -> None:
"""Record a successful operation."""
await self.token_bucket.deposit(retry)

async def backoff(self, attempt: int) -> None:
"""Return the backoff duration for the given ."""
await _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)

async def should_retry(self, attempt: int) -> bool:
"""Return if we have budget to retry and how long to backoff."""
# TODO: Check CSOT deadline here.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this for a follow up PR?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implemented it now.

if attempt > self.attempts:
return False
# Check token bucket last since we only want to consume a token if we actually retry.
if not await self.token_bucket.consume():
# DRIVERS-3246 Improve diagnostics when this case happens.
# We could add info to the exception and log.
return False
return True


def _retry_overload(func: F) -> F:
@functools.wraps(func)
async def inner(*args: Any, **kwargs: Any) -> Any:
async def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
retry_policy = self._retry_policy
attempt = 0
while True:
try:
return await func(*args, **kwargs)
res = await func(self, *args, **kwargs)
await retry_policy.record_success(retry=attempt > 0)
return res
except PyMongoError as exc:
if not exc.has_error_label("Retryable"):
raise
attempt += 1
if attempt > _MAX_RETRIES:
if not await retry_policy.should_retry(attempt):
raise

# Implement exponential backoff on retry.
if exc.has_error_label("SystemOverloaded"):
await _backoff(attempt)
await retry_policy.backoff(attempt)
continue

return cast(F, inner)
Expand Down
17 changes: 12 additions & 5 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload
from pymongo.asynchronous.helpers import (
_retry_overload,
_RetryPolicy,
_TokenBucket,
)
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
from pymongo.client_options import ClientOptions
Expand Down Expand Up @@ -774,6 +778,7 @@ def __init__(
self._timeout: float | None = None
self._topology_settings: TopologySettings = None # type: ignore[assignment]
self._event_listeners: _EventListeners | None = None
self._retry_policy = _RetryPolicy(_TokenBucket())

# _pool_class, _monitor_class, and _condition_class are for deep
# customization of PyMongo, e.g. Motor.
Expand Down Expand Up @@ -2740,7 +2745,7 @@ def __init__(
self._always_retryable = False
self._multiple_retries = _csot.get_timeout() is not None
self._client = mongo_client

self._retry_policy = mongo_client._retry_policy
self._func = func
self._bulk = bulk
self._session = session
Expand Down Expand Up @@ -2775,7 +2780,9 @@ async def run(self) -> T:
while True:
self._check_last_error(check_csot=True)
try:
return await self._read() if self._is_read else await self._write()
res = await self._read() if self._is_read else await self._write()
await self._retry_policy.record_success(self._attempt_number > 0)
return res
except ServerSelectionTimeoutError:
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
Expand Down Expand Up @@ -2846,13 +2853,13 @@ async def run(self) -> T:

self._always_retryable = always_retryable
if always_retryable:
if self._attempt_number > _MAX_RETRIES:
if not await self._retry_policy.should_retry(self._attempt_number):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
if overloaded:
await _backoff(self._attempt_number)
await self._retry_policy.backoff(self._attempt_number)

def _is_not_eligible_for_retry(self) -> bool:
"""Checks if the exchange is not eligible for retry"""
Expand Down
1 change: 1 addition & 0 deletions pymongo/synchronous/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def __init__(
unicode_decode_error_handler="replace", document_class=dict
)
self._timeout = database.client.options.timeout
self._retry_policy = database.client._retry_policy

if create or kwargs:
if _IS_SYNC:
Expand Down
1 change: 1 addition & 0 deletions pymongo/synchronous/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ def __init__(
self._name = name
self._client: MongoClient[_DocumentType] = client
self._timeout = client.options.timeout
self._retry_policy = client._retry_policy

@property
def client(self) -> MongoClient[_DocumentType]:
Expand Down
86 changes: 81 additions & 5 deletions pymongo/synchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
PyMongoError,
)
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE
from pymongo.lock import _create_lock

_IS_SYNC = True

Expand Down Expand Up @@ -78,7 +79,10 @@ def inner(*args: Any, **kwargs: Any) -> Any:
_MAX_RETRIES = 3
_BACKOFF_INITIAL = 0.05
_BACKOFF_MAX = 10
_TIME = time
# DRIVERS-3240 will determine these defaults.
DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0
DEFAULT_RETRY_TOKEN_RETURN = 0.1
_TIME = time # Added so synchro script doesn't remove the time import.


def _backoff(
Expand All @@ -89,23 +93,95 @@ def _backoff(
time.sleep(backoff)


class _TokenBucket:
"""A token bucket implementation for rate limiting."""

def __init__(
self,
capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY,
return_rate: float = DEFAULT_RETRY_TOKEN_RETURN,
):
self.lock = _create_lock()
self.capacity = capacity
# DRIVERS-3240 will determine how full the bucket should start.
self.tokens = capacity
self.return_rate = return_rate

def consume(self) -> bool:
"""Consume a token from the bucket if available."""
with self.lock:
if self.tokens >= 1:
self.tokens -= 1
return True
return False

def deposit(self, retry: bool = False) -> None:
"""Deposit a token back into the bucket."""
retry_token = 1 if retry else 0
with self.lock:
self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate)


class _RetryPolicy:
"""A retry limiter that performs exponential backoff with jitter.

Retry attempts are limited by a token bucket to prevent overwhelming the server during
a prolonged outage or high load.
"""

def __init__(
self,
token_bucket: _TokenBucket,
attempts: int = _MAX_RETRIES,
backoff_initial: float = _BACKOFF_INITIAL,
backoff_max: float = _BACKOFF_MAX,
):
self.token_bucket = token_bucket
self.attempts = attempts
self.backoff_initial = backoff_initial
self.backoff_max = backoff_max

def record_success(self, retry: bool) -> None:
"""Record a successful operation."""
self.token_bucket.deposit(retry)

def backoff(self, attempt: int) -> None:
"""Return the backoff duration for the given ."""
_backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)

def should_retry(self, attempt: int) -> bool:
"""Return if we have budget to retry and how long to backoff."""
# TODO: Check CSOT deadline here.
if attempt > self.attempts:
return False
# Check token bucket last since we only want to consume a token if we actually retry.
if not self.token_bucket.consume():
# DRIVERS-3246 Improve diagnostics when this case happens.
# We could add info to the exception and log.
return False
return True


def _retry_overload(func: F) -> F:
@functools.wraps(func)
def inner(*args: Any, **kwargs: Any) -> Any:
def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
retry_policy = self._retry_policy
attempt = 0
while True:
try:
return func(*args, **kwargs)
res = func(self, *args, **kwargs)
retry_policy.record_success(retry=attempt > 0)
return res
except PyMongoError as exc:
if not exc.has_error_label("Retryable"):
raise
attempt += 1
if attempt > _MAX_RETRIES:
if not retry_policy.should_retry(attempt):
raise

# Implement exponential backoff on retry.
if exc.has_error_label("SystemOverloaded"):
_backoff(attempt)
retry_policy.backoff(attempt)
continue

return cast(F, inner)
Expand Down
17 changes: 12 additions & 5 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,11 @@
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _EmptyServerSession
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload
from pymongo.synchronous.helpers import (
_retry_overload,
_RetryPolicy,
_TokenBucket,
)
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext
from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription
Expand Down Expand Up @@ -774,6 +778,7 @@ def __init__(
self._timeout: float | None = None
self._topology_settings: TopologySettings = None # type: ignore[assignment]
self._event_listeners: _EventListeners | None = None
self._retry_policy = _RetryPolicy(_TokenBucket())

# _pool_class, _monitor_class, and _condition_class are for deep
# customization of PyMongo, e.g. Motor.
Expand Down Expand Up @@ -2730,7 +2735,7 @@ def __init__(
self._always_retryable = False
self._multiple_retries = _csot.get_timeout() is not None
self._client = mongo_client

self._retry_policy = mongo_client._retry_policy
self._func = func
self._bulk = bulk
self._session = session
Expand Down Expand Up @@ -2765,7 +2770,9 @@ def run(self) -> T:
while True:
self._check_last_error(check_csot=True)
try:
return self._read() if self._is_read else self._write()
res = self._read() if self._is_read else self._write()
self._retry_policy.record_success(self._attempt_number > 0)
return res
except ServerSelectionTimeoutError:
# The application may think the write was never attempted
# if we raise ServerSelectionTimeoutError on the retry
Expand Down Expand Up @@ -2836,13 +2843,13 @@ def run(self) -> T:

self._always_retryable = always_retryable
if always_retryable:
if self._attempt_number > _MAX_RETRIES:
if not self._retry_policy.should_retry(self._attempt_number):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
if overloaded:
_backoff(self._attempt_number)
self._retry_policy.backoff(self._attempt_number)

def _is_not_eligible_for_retry(self) -> bool:
"""Checks if the exchange is not eligible for retry"""
Expand Down
Loading
Loading