diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 741c11e551..dead0ed4dc 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -58,6 +58,7 @@ AsyncCursor, AsyncRawBatchCursor, ) +from pymongo.asynchronous.helpers import _retry_overload from pymongo.collation import validate_collation_or_none from pymongo.common import _ecoc_coll_name, _esc_coll_name from pymongo.errors import ( @@ -2227,6 +2228,7 @@ async def create_indexes( return await self._create_indexes(indexes, session, **kwargs) @_csot.apply + @_retry_overload async def _create_indexes( self, indexes: Sequence[IndexModel], session: Optional[AsyncClientSession], **kwargs: Any ) -> list[str]: @@ -2422,7 +2424,6 @@ async def drop_indexes( kwargs["comment"] = comment await self._drop_index("*", session=session, **kwargs) - @_csot.apply async def drop_index( self, index_or_name: _IndexKeyHint, @@ -2472,6 +2473,7 @@ async def drop_index( await self._drop_index(index_or_name, session, comment, **kwargs) @_csot.apply + @_retry_overload async def _drop_index( self, index_or_name: _IndexKeyHint, @@ -3079,6 +3081,7 @@ async def aggregate_raw_batches( ) @_csot.apply + @_retry_overload async def rename( self, new_name: str, diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index f70c2b403f..f3b35a0dcb 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -38,6 +38,7 @@ from pymongo.asynchronous.change_stream import AsyncDatabaseChangeStream from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.helpers import _retry_overload from pymongo.common import _ecoc_coll_name, _esc_coll_name from pymongo.database_shared import _check_name, _CodecDocumentType from pymongo.errors import CollectionInvalid, InvalidOperation @@ -477,6 +478,7 @@ async def watch( return change_stream @_csot.apply + @_retry_overload async def create_collection( self, name: str, @@ -816,6 +818,7 @@ async def command( ... @_csot.apply + @_retry_overload async def command( self, command: Union[str, MutableMapping[str, Any]], @@ -947,6 +950,7 @@ async def command( ) @_csot.apply + @_retry_overload async def cursor_command( self, command: Union[str, MutableMapping[str, Any]], @@ -1264,6 +1268,7 @@ async def _drop_helper( ) @_csot.apply + @_retry_overload async def drop_collection( self, name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]], diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 54fd64f74a..49d5ec604e 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -17,8 +17,11 @@ import asyncio import builtins +import functools +import random import socket import sys +import time from typing import ( Any, Callable, @@ -28,6 +31,7 @@ from pymongo.errors import ( OperationFailure, + PyMongoError, ) from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE @@ -38,6 +42,7 @@ def _handle_reauth(func: F) -> F: + @functools.wraps(func) async def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) from pymongo.asynchronous.pool import AsyncConnection @@ -70,6 +75,42 @@ async def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) +_MAX_RETRIES = 3 +_BACKOFF_INITIAL = 0.05 +_BACKOFF_MAX = 10 +_TIME = time + + +async def _backoff( + attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX +) -> None: + jitter = random.random() # noqa: S311 + backoff = jitter * min(initial_delay * (2**attempt), max_delay) + await asyncio.sleep(backoff) + + +def _retry_overload(func: F) -> F: + @functools.wraps(func) + async def inner(*args: Any, **kwargs: Any) -> Any: + attempt = 0 + while True: + try: + return await func(*args, **kwargs) + except PyMongoError as exc: + if not exc.has_error_label("Retryable"): + raise + attempt += 1 + if attempt > _MAX_RETRIES: + raise + + # Implement exponential backoff on retry. + if exc.has_error_label("SystemOverloaded"): + await _backoff(attempt) + continue + + return cast(F, inner) + + async def _getaddrinfo( host: Any, port: Any, **kwargs: Any ) -> list[ diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index b616647791..ae6e819334 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -67,6 +67,7 @@ from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_session import _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext from pymongo.client_options import ClientOptions @@ -2398,6 +2399,7 @@ async def list_database_names( return [doc["name"] async for doc in res] @_csot.apply + @_retry_overload async def drop_database( self, name_or_database: Union[str, database.AsyncDatabase[_DocumentTypeArg]], @@ -2735,6 +2737,7 @@ def __init__( ): self._last_error: Optional[Exception] = None self._retrying = False + self._always_retryable = False self._multiple_retries = _csot.get_timeout() is not None self._client = mongo_client @@ -2783,14 +2786,22 @@ async def run(self) -> T: # most likely be a waste of time. raise except PyMongoError as exc: + always_retryable = False + overloaded = False + exc_to_check = exc # Execute specialized catch on read if self._is_read: if isinstance(exc, (ConnectionFailure, OperationFailure)): # ConnectionFailures do not supply a code property exc_code = getattr(exc, "code", None) - if self._is_not_eligible_for_retry() or ( - isinstance(exc, OperationFailure) - and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + always_retryable = exc.has_error_label("Retryable") + overloaded = exc.has_error_label("SystemOverloaded") + if not always_retryable and ( + self._is_not_eligible_for_retry() + or ( + isinstance(exc, OperationFailure) + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + ) ): raise self._retrying = True @@ -2801,19 +2812,22 @@ async def run(self) -> T: # Specialized catch on write operation if not self._is_read: - if not self._retryable: + if isinstance(exc, ClientBulkWriteException) and isinstance( + exc.error, PyMongoError + ): + exc_to_check = exc.error + retryable_write_label = exc_to_check.has_error_label("RetryableWriteError") + always_retryable = exc_to_check.has_error_label("Retryable") + overloaded = exc_to_check.has_error_label("SystemOverloaded") + if not self._retryable and not always_retryable: raise - if isinstance(exc, ClientBulkWriteException) and exc.error: - retryable_write_error_exc = isinstance( - exc.error, PyMongoError - ) and exc.error.has_error_label("RetryableWriteError") - else: - retryable_write_error_exc = exc.has_error_label("RetryableWriteError") - if retryable_write_error_exc: + if retryable_write_label or always_retryable: assert self._session await self._session._unpin() - if not retryable_write_error_exc or self._is_not_eligible_for_retry(): - if exc.has_error_label("NoWritesPerformed") and self._last_error: + if not always_retryable and ( + not retryable_write_label or self._is_not_eligible_for_retry() + ): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise @@ -2822,7 +2836,7 @@ async def run(self) -> T: self._bulk.retrying = True else: self._retrying = True - if not exc.has_error_label("NoWritesPerformed"): + if not exc_to_check.has_error_label("NoWritesPerformed"): self._last_error = exc if self._last_error is None: self._last_error = exc @@ -2830,6 +2844,16 @@ async def run(self) -> T: if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded: self._deprioritized_servers.append(self._server) + self._always_retryable = always_retryable + if always_retryable: + if self._attempt_number > _MAX_RETRIES: + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: + raise self._last_error from exc + else: + raise + if overloaded: + await _backoff(self._attempt_number) + def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" return not self._retryable or (self._is_retrying() and not self._multiple_retries) @@ -2891,7 +2915,7 @@ async def _write(self) -> T: and conn.supports_sessions ) is_mongos = conn.is_mongos - if not sessions_supported: + if not self._always_retryable and not sessions_supported: # A retry is not possible because this server does # not support sessions raise the last error. self._check_last_error() @@ -2923,7 +2947,7 @@ async def _read(self) -> T: conn, read_pref, ): - if self._retrying and not self._retryable: + if self._retrying and not self._retryable and not self._always_retryable: self._check_last_error() if self._retrying: _debug_log( diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 9f32deb765..3df867f7bc 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -89,6 +89,7 @@ Cursor, RawBatchCursor, ) +from pymongo.synchronous.helpers import _retry_overload from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean @@ -2224,6 +2225,7 @@ def create_indexes( return self._create_indexes(indexes, session, **kwargs) @_csot.apply + @_retry_overload def _create_indexes( self, indexes: Sequence[IndexModel], session: Optional[ClientSession], **kwargs: Any ) -> list[str]: @@ -2419,7 +2421,6 @@ def drop_indexes( kwargs["comment"] = comment self._drop_index("*", session=session, **kwargs) - @_csot.apply def drop_index( self, index_or_name: _IndexKeyHint, @@ -2469,6 +2470,7 @@ def drop_index( self._drop_index(index_or_name, session, comment, **kwargs) @_csot.apply + @_retry_overload def _drop_index( self, index_or_name: _IndexKeyHint, @@ -3072,6 +3074,7 @@ def aggregate_raw_batches( ) @_csot.apply + @_retry_overload def rename( self, new_name: str, diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index e30f97817c..d8b9ae6a10 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -43,6 +43,7 @@ from pymongo.synchronous.change_stream import DatabaseChangeStream from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.helpers import _retry_overload from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline if TYPE_CHECKING: @@ -477,6 +478,7 @@ def watch( return change_stream @_csot.apply + @_retry_overload def create_collection( self, name: str, @@ -816,6 +818,7 @@ def command( ... @_csot.apply + @_retry_overload def command( self, command: Union[str, MutableMapping[str, Any]], @@ -945,6 +948,7 @@ def command( ) @_csot.apply + @_retry_overload def cursor_command( self, command: Union[str, MutableMapping[str, Any]], @@ -1257,6 +1261,7 @@ def _drop_helper( ) @_csot.apply + @_retry_overload def drop_collection( self, name_or_collection: Union[str, Collection[_DocumentTypeArg]], diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index bc69a49e80..889382b19c 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -17,8 +17,11 @@ import asyncio import builtins +import functools +import random import socket import sys +import time from typing import ( Any, Callable, @@ -28,6 +31,7 @@ from pymongo.errors import ( OperationFailure, + PyMongoError, ) from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE @@ -38,6 +42,7 @@ def _handle_reauth(func: F) -> F: + @functools.wraps(func) def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) from pymongo.message import _BulkWriteContext @@ -70,6 +75,42 @@ def inner(*args: Any, **kwargs: Any) -> Any: return cast(F, inner) +_MAX_RETRIES = 3 +_BACKOFF_INITIAL = 0.05 +_BACKOFF_MAX = 10 +_TIME = time + + +def _backoff( + attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX +) -> None: + jitter = random.random() # noqa: S311 + backoff = jitter * min(initial_delay * (2**attempt), max_delay) + time.sleep(backoff) + + +def _retry_overload(func: F) -> F: + @functools.wraps(func) + def inner(*args: Any, **kwargs: Any) -> Any: + attempt = 0 + while True: + try: + return func(*args, **kwargs) + except PyMongoError as exc: + if not exc.has_error_label("Retryable"): + raise + attempt += 1 + if attempt > _MAX_RETRIES: + raise + + # Implement exponential backoff on retry. + if exc.has_error_label("SystemOverloaded"): + _backoff(attempt) + continue + + return cast(F, inner) + + def _getaddrinfo( host: Any, port: Any, **kwargs: Any ) -> list[ diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index ef0663584c..dcd8c50cca 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -110,6 +110,7 @@ from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_session import _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor +from pymongo.synchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription @@ -2388,6 +2389,7 @@ def list_database_names( return [doc["name"] for doc in res] @_csot.apply + @_retry_overload def drop_database( self, name_or_database: Union[str, database.Database[_DocumentTypeArg]], @@ -2725,6 +2727,7 @@ def __init__( ): self._last_error: Optional[Exception] = None self._retrying = False + self._always_retryable = False self._multiple_retries = _csot.get_timeout() is not None self._client = mongo_client @@ -2773,14 +2776,22 @@ def run(self) -> T: # most likely be a waste of time. raise except PyMongoError as exc: + always_retryable = False + overloaded = False + exc_to_check = exc # Execute specialized catch on read if self._is_read: if isinstance(exc, (ConnectionFailure, OperationFailure)): # ConnectionFailures do not supply a code property exc_code = getattr(exc, "code", None) - if self._is_not_eligible_for_retry() or ( - isinstance(exc, OperationFailure) - and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + always_retryable = exc.has_error_label("Retryable") + overloaded = exc.has_error_label("SystemOverloaded") + if not always_retryable and ( + self._is_not_eligible_for_retry() + or ( + isinstance(exc, OperationFailure) + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES + ) ): raise self._retrying = True @@ -2791,19 +2802,22 @@ def run(self) -> T: # Specialized catch on write operation if not self._is_read: - if not self._retryable: + if isinstance(exc, ClientBulkWriteException) and isinstance( + exc.error, PyMongoError + ): + exc_to_check = exc.error + retryable_write_label = exc_to_check.has_error_label("RetryableWriteError") + always_retryable = exc_to_check.has_error_label("Retryable") + overloaded = exc_to_check.has_error_label("SystemOverloaded") + if not self._retryable and not always_retryable: raise - if isinstance(exc, ClientBulkWriteException) and exc.error: - retryable_write_error_exc = isinstance( - exc.error, PyMongoError - ) and exc.error.has_error_label("RetryableWriteError") - else: - retryable_write_error_exc = exc.has_error_label("RetryableWriteError") - if retryable_write_error_exc: + if retryable_write_label or always_retryable: assert self._session self._session._unpin() - if not retryable_write_error_exc or self._is_not_eligible_for_retry(): - if exc.has_error_label("NoWritesPerformed") and self._last_error: + if not always_retryable and ( + not retryable_write_label or self._is_not_eligible_for_retry() + ): + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise @@ -2812,7 +2826,7 @@ def run(self) -> T: self._bulk.retrying = True else: self._retrying = True - if not exc.has_error_label("NoWritesPerformed"): + if not exc_to_check.has_error_label("NoWritesPerformed"): self._last_error = exc if self._last_error is None: self._last_error = exc @@ -2820,6 +2834,16 @@ def run(self) -> T: if self._client.topology_description.topology_type == TOPOLOGY_TYPE.Sharded: self._deprioritized_servers.append(self._server) + self._always_retryable = always_retryable + if always_retryable: + if self._attempt_number > _MAX_RETRIES: + if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: + raise self._last_error from exc + else: + raise + if overloaded: + _backoff(self._attempt_number) + def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" return not self._retryable or (self._is_retrying() and not self._multiple_retries) @@ -2881,7 +2905,7 @@ def _write(self) -> T: and conn.supports_sessions ) is_mongos = conn.is_mongos - if not sessions_supported: + if not self._always_retryable and not sessions_supported: # A retry is not possible because this server does # not support sessions raise the last error. self._check_last_error() @@ -2913,7 +2937,7 @@ def _read(self) -> T: conn, read_pref, ): - if self._retrying and not self._retryable: + if self._retrying and not self._retryable and not self._always_retryable: self._check_last_error() if self._retrying: _debug_log( diff --git a/test/asynchronous/test_backpressure.py b/test/asynchronous/test_backpressure.py new file mode 100644 index 0000000000..a9a6fb56f5 --- /dev/null +++ b/test/asynchronous/test_backpressure.py @@ -0,0 +1,155 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test Client Backpressure spec.""" +from __future__ import annotations + +import sys + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest + +from pymongo.asynchronous.helpers import _MAX_RETRIES +from pymongo.errors import PyMongoError + +_IS_SYNC = False + +# Mock an system overload error. +mock_overload_error = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find", "insert", "update"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["Retryable"], + }, +} + + +class TestBackpressure(AsyncIntegrationTest): + RUN_ON_LOAD_BALANCER = True + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_command(self): + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + async with self.fail_point(fail_many): + await self.db.command("find", "t") + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.command("find", "t") + + self.assertIn("Retryable", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_find(self): + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + async with self.fail_point(fail_many): + await self.db.t.find_one() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.t.find_one() + + self.assertIn("Retryable", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_insert_one(self): + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + async with self.fail_point(fail_many): + await self.db.t.find_one() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.t.find_one() + + self.assertIn("Retryable", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_update_many(self): + # Even though update_many is not a retryable write operation, it will + # still be retried via the "Retryable" error label. + await self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + async with self.fail_point(fail_many): + await self.db.t.update_many({}, {"$set": {"x": 2}}) + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await self.db.t.update_many({}, {"$set": {"x": 2}}) + + self.assertIn("Retryable", str(error.exception)) + + @async_client_context.require_failCommand_appName + async def test_retry_overload_error_getMore(self): + coll = self.db.t + await coll.insert_many([{"x": 1} for _ in range(10)]) + + # Ensure command is retried on overload error. + fail_many = { + "configureFailPoint": "failCommand", + "mode": {"times": _MAX_RETRIES}, + "data": { + "failCommands": ["getMore"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["Retryable"], + }, + } + cursor = coll.find(batch_size=2) + await cursor.next() + async with self.fail_point(fail_many): + await cursor.to_list() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = fail_many.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + cursor = coll.find(batch_size=2) + await cursor.next() + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await cursor.to_list() + + self.assertIn("Retryable", str(error.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_backpressure.py b/test/test_backpressure.py new file mode 100644 index 0000000000..324dd6f15a --- /dev/null +++ b/test/test_backpressure.py @@ -0,0 +1,155 @@ +# Copyright 2025-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test Client Backpressure spec.""" +from __future__ import annotations + +import sys + +sys.path[0:0] = [""] + +from test import IntegrationTest, client_context, unittest + +from pymongo.errors import PyMongoError +from pymongo.synchronous.helpers import _MAX_RETRIES + +_IS_SYNC = True + +# Mock an system overload error. +mock_overload_error = { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["find", "insert", "update"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["Retryable"], + }, +} + + +class TestBackpressure(IntegrationTest): + RUN_ON_LOAD_BALANCER = True + + @client_context.require_failCommand_appName + def test_retry_overload_error_command(self): + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + with self.fail_point(fail_many): + self.db.command("find", "t") + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.command("find", "t") + + self.assertIn("Retryable", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_find(self): + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + with self.fail_point(fail_many): + self.db.t.find_one() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.t.find_one() + + self.assertIn("Retryable", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_insert_one(self): + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + with self.fail_point(fail_many): + self.db.t.find_one() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.t.find_one() + + self.assertIn("Retryable", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_update_many(self): + # Even though update_many is not a retryable write operation, it will + # still be retried via the "Retryable" error label. + self.db.t.insert_one({"x": 1}) + + # Ensure command is retried on overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": _MAX_RETRIES} + with self.fail_point(fail_many): + self.db.t.update_many({}, {"$set": {"x": 2}}) + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + self.db.t.update_many({}, {"$set": {"x": 2}}) + + self.assertIn("Retryable", str(error.exception)) + + @client_context.require_failCommand_appName + def test_retry_overload_error_getMore(self): + coll = self.db.t + coll.insert_many([{"x": 1} for _ in range(10)]) + + # Ensure command is retried on overload error. + fail_many = { + "configureFailPoint": "failCommand", + "mode": {"times": _MAX_RETRIES}, + "data": { + "failCommands": ["getMore"], + "errorCode": 462, # IngressRequestRateLimitExceeded + "errorLabels": ["Retryable"], + }, + } + cursor = coll.find(batch_size=2) + cursor.next() + with self.fail_point(fail_many): + cursor.to_list() + + # Ensure command stops retrying after _MAX_RETRIES. + fail_too_many = fail_many.copy() + fail_too_many["mode"] = {"times": _MAX_RETRIES + 1} + cursor = coll.find(batch_size=2) + cursor.next() + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + cursor.to_list() + + self.assertIn("Retryable", str(error.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/synchro.py b/tools/synchro.py index 9a760c0ad7..44698134cd 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -208,6 +208,7 @@ def async_only_test(f: str) -> bool: "test_auth_oidc.py", "test_auth_spec.py", "test_bulk.py", + "test_backpressure.py", "test_change_stream.py", "test_client.py", "test_client_bulk_write.py",