5 changes: 4 additions & 1 deletion pymongo/asynchronous/collection.py
@@ -58,6 +58,7 @@
AsyncCursor,
AsyncRawBatchCursor,
)
from pymongo.asynchronous.helpers import _retry_overload
from pymongo.collation import validate_collation_or_none
from pymongo.common import _ecoc_coll_name, _esc_coll_name
from pymongo.errors import (
@@ -2227,6 +2228,7 @@ async def create_indexes(
return await self._create_indexes(indexes, session, **kwargs)

@_csot.apply
@_retry_overload
async def _create_indexes(
self, indexes: Sequence[IndexModel], session: Optional[AsyncClientSession], **kwargs: Any
) -> list[str]:
@@ -2422,7 +2424,6 @@ async def drop_indexes(
kwargs["comment"] = comment
await self._drop_index("*", session=session, **kwargs)

@_csot.apply
async def drop_index(
self,
index_or_name: _IndexKeyHint,
@@ -2472,6 +2473,7 @@ async def drop_index(
await self._drop_index(index_or_name, session, comment, **kwargs)

@_csot.apply
@_retry_overload
async def _drop_index(
self,
index_or_name: _IndexKeyHint,
@@ -3079,6 +3081,7 @@ async def aggregate_raw_batches(
)

@_csot.apply
@_retry_overload
async def rename(
self,
new_name: str,
5 changes: 5 additions & 0 deletions pymongo/asynchronous/database.py
@@ -38,6 +38,7 @@
from pymongo.asynchronous.change_stream import AsyncDatabaseChangeStream
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.helpers import _retry_overload
from pymongo.common import _ecoc_coll_name, _esc_coll_name
from pymongo.database_shared import _check_name, _CodecDocumentType
from pymongo.errors import CollectionInvalid, InvalidOperation
@@ -477,6 +478,7 @@ async def watch(
return change_stream

@_csot.apply
@_retry_overload
async def create_collection(
self,
name: str,
@@ -816,6 +818,7 @@ async def command(
...

@_csot.apply
@_retry_overload
async def command(
self,
command: Union[str, MutableMapping[str, Any]],
@@ -947,6 +950,7 @@ async def command(
)

@_csot.apply
@_retry_overload
async def cursor_command(
self,
command: Union[str, MutableMapping[str, Any]],
@@ -1264,6 +1268,7 @@ async def _drop_helper(
)

@_csot.apply
@_retry_overload
async def drop_collection(
self,
name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]],
41 changes: 41 additions & 0 deletions pymongo/asynchronous/helpers.py
@@ -17,8 +17,11 @@

import asyncio
import builtins
import functools
import random
import socket
import sys
import time
from typing import (
Any,
Callable,
@@ -28,6 +31,7 @@

from pymongo.errors import (
OperationFailure,
PyMongoError,
)
from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE

@@ -38,6 +42,7 @@


def _handle_reauth(func: F) -> F:
@functools.wraps(func)
async def inner(*args: Any, **kwargs: Any) -> Any:
no_reauth = kwargs.pop("no_reauth", False)
from pymongo.asynchronous.pool import AsyncConnection
@@ -70,6 +75,42 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
return cast(F, inner)


_MAX_RETRIES = 3
_BACKOFF_INITIAL = 0.05
_BACKOFF_MAX = 10
_TIME = time


async def _backoff(
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
) -> None:
jitter = random.random() # noqa: S311
backoff = jitter * min(initial_delay * (2**attempt), max_delay)
await asyncio.sleep(backoff)


def _retry_overload(func: F) -> F:
@functools.wraps(func)
async def inner(*args: Any, **kwargs: Any) -> Any:
attempt = 0
while True:
try:
return await func(*args, **kwargs)
except PyMongoError as exc:
if not exc.has_error_label("Retryable"):
raise
attempt += 1
if attempt > _MAX_RETRIES:
raise

# Implement exponential backoff on retry.
if exc.has_error_label("SystemOverloaded"):
await _backoff(attempt)
continue

return cast(F, inner)


async def _getaddrinfo(
host: Any, port: Any, **kwargs: Any
) -> list[
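The new `_retry_overload` decorator and `_backoff` helper above retry an operation up to `_MAX_RETRIES` times when the raised error carries the `Retryable` label, sleeping with full-jitter exponential backoff (capped at `_BACKOFF_MAX` seconds) only when the `SystemOverloaded` label is also present. The following is a minimal, self-contained sketch of that pattern; `FlakyError`, `flaky_command`, and `call_with_overload_retry` are hypothetical stand-ins for illustration and are not part of pymongo.

```python
import asyncio
import random

MAX_RETRIES = 3         # mirrors _MAX_RETRIES above
BACKOFF_INITIAL = 0.05  # mirrors _BACKOFF_INITIAL
BACKOFF_MAX = 10        # mirrors _BACKOFF_MAX


class FlakyError(Exception):
    """Hypothetical error type exposing the same label API as PyMongoError."""

    def __init__(self, labels=()):
        super().__init__("simulated server error")
        self._labels = set(labels)

    def has_error_label(self, label: str) -> bool:
        return label in self._labels


async def backoff(attempt: int) -> None:
    # Full jitter: sleep a random fraction of the capped exponential delay.
    delay = min(BACKOFF_INITIAL * (2**attempt), BACKOFF_MAX)
    await asyncio.sleep(random.random() * delay)


async def call_with_overload_retry(func):
    attempt = 0
    while True:
        try:
            return await func()
        except FlakyError as exc:
            if not exc.has_error_label("Retryable"):
                raise
            attempt += 1
            if attempt > MAX_RETRIES:
                raise
            # Only overloaded errors add a backoff sleep; other retryable
            # errors are retried immediately, matching _retry_overload.
            if exc.has_error_label("SystemOverloaded"):
                await backoff(attempt)


async def main():
    calls = {"n": 0}

    async def flaky_command():
        # Fails twice with a retryable overload error, then succeeds.
        calls["n"] += 1
        if calls["n"] < 3:
            raise FlakyError({"Retryable", "SystemOverloaded"})
        return "ok"

    print(await call_with_overload_retry(flaky_command))  # -> ok


if __name__ == "__main__":
    asyncio.run(main())
```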
56 changes: 40 additions & 16 deletions pymongo/asynchronous/mongo_client.py
@@ -67,6 +67,7 @@
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
from pymongo.client_options import ClientOptions
@@ -2398,6 +2399,7 @@ async def list_database_names(
return [doc["name"] async for doc in res]

@_csot.apply
@_retry_overload
async def drop_database(
self,
name_or_database: Union[str, database.AsyncDatabase[_DocumentTypeArg]],
@@ -2735,6 +2737,7 @@ def __init__(
):
self._last_error: Optional[Exception] = None
self._retrying = False
self._always_retryable = False
self._multiple_retries = _csot.get_timeout() is not None
self._client = mongo_client

@@ -2783,14 +2786,22 @@ async def run(self) -> T:
# most likely be a waste of time.
raise
except PyMongoError as exc:
always_retryable = False

Contributor:
always_retryable represents what exactly? That an error is retryable under all conditions?

Member Author:
Yes it means it's safe to retry even if the operation is otherwise non-retryable, like an update_many. There's probably a better name for it.

Contributor:
overloaded still takes precedence over always_retryable, correct?

overloaded = False
exc_to_check = exc
# Execute specialized catch on read
if self._is_read:
if isinstance(exc, (ConnectionFailure, OperationFailure)):
# ConnectionFailures do not supply a code property
exc_code = getattr(exc, "code", None)
if self._is_not_eligible_for_retry() or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
always_retryable = exc.has_error_label("Retryable")
overloaded = exc.has_error_label("SystemOverloaded")
if not always_retryable and (
self._is_not_eligible_for_retry()
or (
isinstance(exc, OperationFailure)
and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES
)
):
raise
self._retrying = True
@@ -2801,19 +2812,22 @@

# Specialized catch on write operation
if not self._is_read:
if not self._retryable:
if isinstance(exc, ClientBulkWriteException) and isinstance(
exc.error, PyMongoError
):
exc_to_check = exc.error
retryable_write_label = exc_to_check.has_error_label("RetryableWriteError")
always_retryable = exc_to_check.has_error_label("Retryable")
overloaded = exc_to_check.has_error_label("SystemOverloaded")
if not self._retryable and not always_retryable:
raise
if isinstance(exc, ClientBulkWriteException) and exc.error:
retryable_write_error_exc = isinstance(
exc.error, PyMongoError
) and exc.error.has_error_label("RetryableWriteError")
else:
retryable_write_error_exc = exc.has_error_label("RetryableWriteError")
if retryable_write_error_exc:
if retryable_write_label or always_retryable:
assert self._session
await self._session._unpin()
if not retryable_write_error_exc or self._is_not_eligible_for_retry():
if exc.has_error_label("NoWritesPerformed") and self._last_error:
if not always_retryable and (
not retryable_write_label or self._is_not_eligible_for_retry()
):
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
@@ -2822,14 +2836,24 @@
self._bulk.retrying = True
else:
self._retrying = True
if not exc.has_error_label("NoWritesPerformed"):
if not exc_to_check.has_error_label("NoWritesPerformed"):
self._last_error = exc
if self._last_error is None:
self._last_error = exc

if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded:
self._deprioritized_servers.append(self._server)

self._always_retryable = always_retryable
if always_retryable:
if self._attempt_number > _MAX_RETRIES:
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
raise self._last_error from exc
else:
raise
if overloaded:
await _backoff(self._attempt_number)

def _is_not_eligible_for_retry(self) -> bool:
"""Checks if the exchange is not eligible for retry"""
return not self._retryable or (self._is_retrying() and not self._multiple_retries)
@@ -2891,7 +2915,7 @@ async def _write(self) -> T:
and conn.supports_sessions
)
is_mongos = conn.is_mongos
if not sessions_supported:
if not self._always_retryable and not sessions_supported:
# A retry is not possible because this server does
# not support sessions raise the last error.
self._check_last_error()
@@ -2923,7 +2947,7 @@ async def _read(self) -> T:
conn,
read_pref,
):
if self._retrying and not self._retryable:
if self._retrying and not self._retryable and not self._always_retryable:
self._check_last_error()
if self._retrying:
_debug_log(
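To answer the question raised in the review thread above about how the labels interact: a `Retryable` label lets the attempt be retried even when the write would otherwise be non-retryable (capped at `_MAX_RETRIES`), while `SystemOverloaded` additionally inserts the jittered backoff sleep before the next attempt. Below is a simplified, hypothetical decision helper that mirrors only that ordering of checks; it deliberately ignores the session unpinning, eligibility rules, and last-error bookkeeping the real `run()` performs, and `RetryDecision`/`decide_retry` are illustrative names, not part of this PR.

```python
from dataclasses import dataclass


@dataclass
class RetryDecision:
    retry: bool         # attempt the operation again?
    backoff: bool       # sleep with jittered exponential backoff first?
    reraise_last: bool  # surface the previously recorded error instead?


def decide_retry(exc_to_check, *, attempt_number: int, max_retries: int = 3,
                 normally_retryable: bool = False) -> RetryDecision:
    """Illustrative only: mirrors the label checks added to run().

    For client bulk writes the caller passes the wrapped exc.error rather
    than the outer ClientBulkWriteException, matching exc_to_check above.
    """
    always_retryable = exc_to_check.has_error_label("Retryable")
    overloaded = exc_to_check.has_error_label("SystemOverloaded")

    if not always_retryable and not normally_retryable:
        # Neither the label nor the usual retryable-write rules apply.
        return RetryDecision(retry=False, backoff=False, reraise_last=False)

    if always_retryable and attempt_number > max_retries:
        # Out of label-driven attempts; prefer the last recorded error when
        # this attempt performed no writes.
        return RetryDecision(
            retry=False,
            backoff=False,
            reraise_last=exc_to_check.has_error_label("NoWritesPerformed"),
        )

    # "Overloaded takes precedence" only in the sense that it adds a backoff
    # sleep before the retry; whether to retry at all is decided above.
    return RetryDecision(retry=True, backoff=overloaded, reraise_last=False)
```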
5 changes: 4 additions & 1 deletion pymongo/synchronous/collection.py
@@ -89,6 +89,7 @@
Cursor,
RawBatchCursor,
)
from pymongo.synchronous.helpers import _retry_overload
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean

@@ -2224,6 +2225,7 @@ def create_indexes(
return self._create_indexes(indexes, session, **kwargs)

@_csot.apply
@_retry_overload
def _create_indexes(
self, indexes: Sequence[IndexModel], session: Optional[ClientSession], **kwargs: Any
) -> list[str]:
@@ -2419,7 +2421,6 @@ def drop_indexes(
kwargs["comment"] = comment
self._drop_index("*", session=session, **kwargs)

@_csot.apply
def drop_index(
self,
index_or_name: _IndexKeyHint,
@@ -2469,6 +2470,7 @@ def drop_index(
self._drop_index(index_or_name, session, comment, **kwargs)

@_csot.apply
@_retry_overload
def _drop_index(
self,
index_or_name: _IndexKeyHint,
@@ -3072,6 +3074,7 @@ def aggregate_raw_batches(
)

@_csot.apply
@_retry_overload
def rename(
self,
new_name: str,
5 changes: 5 additions & 0 deletions pymongo/synchronous/database.py
@@ -43,6 +43,7 @@
from pymongo.synchronous.change_stream import DatabaseChangeStream
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.helpers import _retry_overload
from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline

if TYPE_CHECKING:
@@ -477,6 +478,7 @@ def watch(
return change_stream

@_csot.apply
@_retry_overload
def create_collection(
self,
name: str,
@@ -816,6 +818,7 @@ def command(
...

@_csot.apply
@_retry_overload
def command(
self,
command: Union[str, MutableMapping[str, Any]],
@@ -945,6 +948,7 @@ def command(
)

@_csot.apply
@_retry_overload
def cursor_command(
self,
command: Union[str, MutableMapping[str, Any]],
@@ -1257,6 +1261,7 @@ def _drop_helper(
)

@_csot.apply
@_retry_overload
def drop_collection(
self,
name_or_collection: Union[str, Collection[_DocumentTypeArg]],