diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index dead0ed4dc..6ff62e9fe3 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -253,6 +253,7 @@ def __init__( unicode_decode_error_handler="replace", document_class=dict ) self._timeout = database.client.options.timeout + self._retry_policy = database.client._retry_policy if create or kwargs: if _IS_SYNC: diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index f3b35a0dcb..8abc7059d0 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -136,6 +136,7 @@ def __init__( self._name = name self._client: AsyncMongoClient[_DocumentType] = client self._timeout = client.options.timeout + self._retry_policy = client._retry_policy @property def client(self) -> AsyncMongoClient[_DocumentType]: diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 49d5ec604e..6ef3beacf5 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -21,7 +21,7 @@ import random import socket import sys -import time +import time as time # noqa: PLC0414 # needed in sync version from typing import ( Any, Callable, @@ -29,11 +29,13 @@ cast, ) +from pymongo import _csot from pymongo.errors import ( OperationFailure, PyMongoError, ) from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE +from pymongo.lock import _async_create_lock _IS_SYNC = False @@ -78,34 +80,115 @@ async def inner(*args: Any, **kwargs: Any) -> Any: _MAX_RETRIES = 3 _BACKOFF_INITIAL = 0.05 _BACKOFF_MAX = 10 -_TIME = time +# DRIVERS-3240 will determine these defaults. +DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0 +DEFAULT_RETRY_TOKEN_RETURN = 0.1 -async def _backoff( +def _backoff( attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX -) -> None: +) -> float: jitter = random.random() # noqa: S311 - backoff = jitter * min(initial_delay * (2**attempt), max_delay) - await asyncio.sleep(backoff) + return jitter * min(initial_delay * (2**attempt), max_delay) + + +class _TokenBucket: + """A token bucket implementation for rate limiting.""" + + def __init__( + self, + capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY, + return_rate: float = DEFAULT_RETRY_TOKEN_RETURN, + ): + self.lock = _async_create_lock() + self.capacity = capacity + # DRIVERS-3240 will determine how full the bucket should start. + self.tokens = capacity + self.return_rate = return_rate + + async def consume(self) -> bool: + """Consume a token from the bucket if available.""" + async with self.lock: + if self.tokens >= 1: + self.tokens -= 1 + return True + return False + + async def deposit(self, retry: bool = False) -> None: + """Deposit a token back into the bucket.""" + retry_token = 1 if retry else 0 + async with self.lock: + self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate) + + +class _RetryPolicy: + """A retry limiter that performs exponential backoff with jitter. + + Retry attempts are limited by a token bucket to prevent overwhelming the server during + a prolonged outage or high load. + """ + + def __init__( + self, + token_bucket: _TokenBucket, + attempts: int = _MAX_RETRIES, + backoff_initial: float = _BACKOFF_INITIAL, + backoff_max: float = _BACKOFF_MAX, + ): + self.token_bucket = token_bucket + self.attempts = attempts + self.backoff_initial = backoff_initial + self.backoff_max = backoff_max + + async def record_success(self, retry: bool) -> None: + """Record a successful operation.""" + await self.token_bucket.deposit(retry) + + def backoff(self, attempt: int) -> float: + """Return the backoff duration for the given .""" + return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) + + async def should_retry(self, attempt: int, delay: float) -> bool: + """Return if we have budget to retry and how long to backoff.""" + if attempt > self.attempts: + return False + + # If the delay would exceed the deadline, bail early before consuming a token. + if _csot.get_timeout(): + if time.monotonic() + delay > _csot.get_deadline(): + return False + + # Check token bucket last since we only want to consume a token if we actually retry. + if not await self.token_bucket.consume(): + # DRIVERS-3246 Improve diagnostics when this case happens. + # We could add info to the exception and log. + return False + return True def _retry_overload(func: F) -> F: @functools.wraps(func) - async def inner(*args: Any, **kwargs: Any) -> Any: + async def inner(self: Any, *args: Any, **kwargs: Any) -> Any: + retry_policy = self._retry_policy attempt = 0 while True: try: - return await func(*args, **kwargs) + res = await func(self, *args, **kwargs) + await retry_policy.record_success(retry=attempt > 0) + return res except PyMongoError as exc: if not exc.has_error_label("Retryable"): raise attempt += 1 - if attempt > _MAX_RETRIES: + delay = 0 + if exc.has_error_label("SystemOverloaded"): + delay = retry_policy.backoff(attempt) + if not await retry_policy.should_retry(attempt, delay): raise # Implement exponential backoff on retry. - if exc.has_error_label("SystemOverloaded"): - await _backoff(attempt) + if delay: + await asyncio.sleep(delay) continue return cast(F, inner) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index ae6e819334..d9994e9902 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -35,6 +35,7 @@ import asyncio import contextlib import os +import time as time # noqa: PLC0414 # needed in sync version import warnings import weakref from collections import defaultdict @@ -67,7 +68,11 @@ from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_session import _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload +from pymongo.asynchronous.helpers import ( + _retry_overload, + _RetryPolicy, + _TokenBucket, +) from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext from pymongo.client_options import ClientOptions @@ -774,6 +779,7 @@ def __init__( self._timeout: float | None = None self._topology_settings: TopologySettings = None # type: ignore[assignment] self._event_listeners: _EventListeners | None = None + self._retry_policy = _RetryPolicy(_TokenBucket()) # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -2740,7 +2746,7 @@ def __init__( self._always_retryable = False self._multiple_retries = _csot.get_timeout() is not None self._client = mongo_client - + self._retry_policy = mongo_client._retry_policy self._func = func self._bulk = bulk self._session = session @@ -2775,7 +2781,9 @@ async def run(self) -> T: while True: self._check_last_error(check_csot=True) try: - return await self._read() if self._is_read else await self._write() + res = await self._read() if self._is_read else await self._write() + await self._retry_policy.record_success(self._attempt_number > 0) + return res except ServerSelectionTimeoutError: # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry @@ -2846,13 +2854,14 @@ async def run(self) -> T: self._always_retryable = always_retryable if always_retryable: - if self._attempt_number > _MAX_RETRIES: + delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0 + if not await self._retry_policy.should_retry(self._attempt_number, delay): if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise if overloaded: - await _backoff(self._attempt_number) + await asyncio.sleep(delay) def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 3df867f7bc..324139d40a 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -256,6 +256,7 @@ def __init__( unicode_decode_error_handler="replace", document_class=dict ) self._timeout = database.client.options.timeout + self._retry_policy = database.client._retry_policy if create or kwargs: if _IS_SYNC: diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index d8b9ae6a10..62f8f69067 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -136,6 +136,7 @@ def __init__( self._name = name self._client: MongoClient[_DocumentType] = client self._timeout = client.options.timeout + self._retry_policy = client._retry_policy @property def client(self) -> MongoClient[_DocumentType]: diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 889382b19c..0a2cd71062 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -21,7 +21,7 @@ import random import socket import sys -import time +import time as time # noqa: PLC0414 # needed in sync version from typing import ( Any, Callable, @@ -29,11 +29,13 @@ cast, ) +from pymongo import _csot from pymongo.errors import ( OperationFailure, PyMongoError, ) from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE +from pymongo.lock import _create_lock _IS_SYNC = True @@ -78,34 +80,115 @@ def inner(*args: Any, **kwargs: Any) -> Any: _MAX_RETRIES = 3 _BACKOFF_INITIAL = 0.05 _BACKOFF_MAX = 10 -_TIME = time +# DRIVERS-3240 will determine these defaults. +DEFAULT_RETRY_TOKEN_CAPACITY = 1000.0 +DEFAULT_RETRY_TOKEN_RETURN = 0.1 def _backoff( attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX -) -> None: +) -> float: jitter = random.random() # noqa: S311 - backoff = jitter * min(initial_delay * (2**attempt), max_delay) - time.sleep(backoff) + return jitter * min(initial_delay * (2**attempt), max_delay) + + +class _TokenBucket: + """A token bucket implementation for rate limiting.""" + + def __init__( + self, + capacity: float = DEFAULT_RETRY_TOKEN_CAPACITY, + return_rate: float = DEFAULT_RETRY_TOKEN_RETURN, + ): + self.lock = _create_lock() + self.capacity = capacity + # DRIVERS-3240 will determine how full the bucket should start. + self.tokens = capacity + self.return_rate = return_rate + + def consume(self) -> bool: + """Consume a token from the bucket if available.""" + with self.lock: + if self.tokens >= 1: + self.tokens -= 1 + return True + return False + + def deposit(self, retry: bool = False) -> None: + """Deposit a token back into the bucket.""" + retry_token = 1 if retry else 0 + with self.lock: + self.tokens = min(self.capacity, self.tokens + retry_token + self.return_rate) + + +class _RetryPolicy: + """A retry limiter that performs exponential backoff with jitter. + + Retry attempts are limited by a token bucket to prevent overwhelming the server during + a prolonged outage or high load. + """ + + def __init__( + self, + token_bucket: _TokenBucket, + attempts: int = _MAX_RETRIES, + backoff_initial: float = _BACKOFF_INITIAL, + backoff_max: float = _BACKOFF_MAX, + ): + self.token_bucket = token_bucket + self.attempts = attempts + self.backoff_initial = backoff_initial + self.backoff_max = backoff_max + + def record_success(self, retry: bool) -> None: + """Record a successful operation.""" + self.token_bucket.deposit(retry) + + def backoff(self, attempt: int) -> float: + """Return the backoff duration for the given .""" + return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max) + + def should_retry(self, attempt: int, delay: float) -> bool: + """Return if we have budget to retry and how long to backoff.""" + if attempt > self.attempts: + return False + + # If the delay would exceed the deadline, bail early before consuming a token. + if _csot.get_timeout(): + if time.monotonic() + delay > _csot.get_deadline(): + return False + + # Check token bucket last since we only want to consume a token if we actually retry. + if not self.token_bucket.consume(): + # DRIVERS-3246 Improve diagnostics when this case happens. + # We could add info to the exception and log. + return False + return True def _retry_overload(func: F) -> F: @functools.wraps(func) - def inner(*args: Any, **kwargs: Any) -> Any: + def inner(self: Any, *args: Any, **kwargs: Any) -> Any: + retry_policy = self._retry_policy attempt = 0 while True: try: - return func(*args, **kwargs) + res = func(self, *args, **kwargs) + retry_policy.record_success(retry=attempt > 0) + return res except PyMongoError as exc: if not exc.has_error_label("Retryable"): raise attempt += 1 - if attempt > _MAX_RETRIES: + delay = 0 + if exc.has_error_label("SystemOverloaded"): + delay = retry_policy.backoff(attempt) + if not retry_policy.should_retry(attempt, delay): raise # Implement exponential backoff on retry. - if exc.has_error_label("SystemOverloaded"): - _backoff(attempt) + if delay: + time.sleep(delay) continue return cast(F, inner) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index dcd8c50cca..9beda807ef 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -35,6 +35,7 @@ import asyncio import contextlib import os +import time as time # noqa: PLC0414 # needed in sync version import warnings import weakref from collections import defaultdict @@ -110,7 +111,11 @@ from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_session import _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.helpers import _MAX_RETRIES, _backoff, _retry_overload +from pymongo.synchronous.helpers import ( + _retry_overload, + _RetryPolicy, + _TokenBucket, +) from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription @@ -774,6 +779,7 @@ def __init__( self._timeout: float | None = None self._topology_settings: TopologySettings = None # type: ignore[assignment] self._event_listeners: _EventListeners | None = None + self._retry_policy = _RetryPolicy(_TokenBucket()) # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. @@ -2730,7 +2736,7 @@ def __init__( self._always_retryable = False self._multiple_retries = _csot.get_timeout() is not None self._client = mongo_client - + self._retry_policy = mongo_client._retry_policy self._func = func self._bulk = bulk self._session = session @@ -2765,7 +2771,9 @@ def run(self) -> T: while True: self._check_last_error(check_csot=True) try: - return self._read() if self._is_read else self._write() + res = self._read() if self._is_read else self._write() + self._retry_policy.record_success(self._attempt_number > 0) + return res except ServerSelectionTimeoutError: # The application may think the write was never attempted # if we raise ServerSelectionTimeoutError on the retry @@ -2836,13 +2844,14 @@ def run(self) -> T: self._always_retryable = always_retryable if always_retryable: - if self._attempt_number > _MAX_RETRIES: + delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0 + if not self._retry_policy.should_retry(self._attempt_number, delay): if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error: raise self._last_error from exc else: raise if overloaded: - _backoff(self._attempt_number) + time.sleep(delay) def _is_not_eligible_for_retry(self) -> bool: """Checks if the exchange is not eligible for retry""" diff --git a/test/asynchronous/test_backpressure.py b/test/asynchronous/test_backpressure.py index a9a6fb56f5..598236dbfe 100644 --- a/test/asynchronous/test_backpressure.py +++ b/test/asynchronous/test_backpressure.py @@ -15,13 +15,22 @@ """Test Client Backpressure spec.""" from __future__ import annotations +import asyncio import sys +import pymongo + sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + async_client_context, + unittest, +) -from pymongo.asynchronous.helpers import _MAX_RETRIES +from pymongo.asynchronous import helpers +from pymongo.asynchronous.helpers import _MAX_RETRIES, _RetryPolicy, _TokenBucket from pymongo.errors import PyMongoError _IS_SYNC = False @@ -150,6 +159,72 @@ async def test_retry_overload_error_getMore(self): self.assertIn("Retryable", str(error.exception)) + @async_client_context.require_failCommand_appName + async def test_limit_retry_command(self): + client = await self.async_rs_or_single_client() + client._retry_policy.token_bucket.tokens = 1 + db = client.pymongo_test + await db.t.insert_one({"x": 1}) + + # Ensure command is retried once overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": 1} + async with self.fail_point(fail_many): + await db.command("find", "t") + + # Ensure command stops retrying when there are no tokens left. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": 2} + async with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + await db.command("find", "t") + + self.assertIn("Retryable", str(error.exception)) + + +class TestRetryPolicy(AsyncPyMongoTestCase): + async def test_retry_policy(self): + capacity = 10 + retry_policy = _RetryPolicy(_TokenBucket(capacity=capacity)) + self.assertEqual(retry_policy.attempts, helpers._MAX_RETRIES) + self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL) + self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX) + for i in range(1, helpers._MAX_RETRIES + 1): + self.assertTrue(await retry_policy.should_retry(i, 0)) + self.assertFalse(await retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0)) + for i in range(capacity - helpers._MAX_RETRIES): + self.assertTrue(await retry_policy.should_retry(1, 0)) + # No tokens left, should not retry. + self.assertFalse(await retry_policy.should_retry(1, 0)) + self.assertEqual(retry_policy.token_bucket.tokens, 0) + + # record_success should generate tokens. + for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)): + await retry_policy.record_success(retry=False) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2) + for i in range(2): + self.assertTrue(await retry_policy.should_retry(1, 0)) + self.assertFalse(await retry_policy.should_retry(1, 0)) + + # Recording a successful retry should return 1 additional token. + await retry_policy.record_success(retry=True) + self.assertAlmostEqual( + retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN + ) + self.assertTrue(await retry_policy.should_retry(1, 0)) + self.assertFalse(await retry_policy.should_retry(1, 0)) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN) + + async def test_retry_policy_csot(self): + retry_policy = _RetryPolicy(_TokenBucket()) + self.assertTrue(await retry_policy.should_retry(1, 0.5)) + with pymongo.timeout(0.5): + self.assertTrue(await retry_policy.should_retry(1, 0)) + self.assertTrue(await retry_policy.should_retry(1, 0.1)) + # Would exceed the timeout, should not retry. + self.assertFalse(await retry_policy.should_retry(1, 1.0)) + self.assertTrue(await retry_policy.should_retry(1, 1.0)) + if __name__ == "__main__": unittest.main() diff --git a/test/test_backpressure.py b/test/test_backpressure.py index 324dd6f15a..182ce424a9 100644 --- a/test/test_backpressure.py +++ b/test/test_backpressure.py @@ -15,14 +15,23 @@ """Test Client Backpressure spec.""" from __future__ import annotations +import asyncio import sys +import pymongo + sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import ( + IntegrationTest, + PyMongoTestCase, + client_context, + unittest, +) from pymongo.errors import PyMongoError -from pymongo.synchronous.helpers import _MAX_RETRIES +from pymongo.synchronous import helpers +from pymongo.synchronous.helpers import _MAX_RETRIES, _RetryPolicy, _TokenBucket _IS_SYNC = True @@ -150,6 +159,72 @@ def test_retry_overload_error_getMore(self): self.assertIn("Retryable", str(error.exception)) + @client_context.require_failCommand_appName + def test_limit_retry_command(self): + client = self.rs_or_single_client() + client._retry_policy.token_bucket.tokens = 1 + db = client.pymongo_test + db.t.insert_one({"x": 1}) + + # Ensure command is retried once overload error. + fail_many = mock_overload_error.copy() + fail_many["mode"] = {"times": 1} + with self.fail_point(fail_many): + db.command("find", "t") + + # Ensure command stops retrying when there are no tokens left. + fail_too_many = mock_overload_error.copy() + fail_too_many["mode"] = {"times": 2} + with self.fail_point(fail_too_many): + with self.assertRaises(PyMongoError) as error: + db.command("find", "t") + + self.assertIn("Retryable", str(error.exception)) + + +class TestRetryPolicy(PyMongoTestCase): + def test_retry_policy(self): + capacity = 10 + retry_policy = _RetryPolicy(_TokenBucket(capacity=capacity)) + self.assertEqual(retry_policy.attempts, helpers._MAX_RETRIES) + self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL) + self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX) + for i in range(1, helpers._MAX_RETRIES + 1): + self.assertTrue(retry_policy.should_retry(i, 0)) + self.assertFalse(retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0)) + for i in range(capacity - helpers._MAX_RETRIES): + self.assertTrue(retry_policy.should_retry(1, 0)) + # No tokens left, should not retry. + self.assertFalse(retry_policy.should_retry(1, 0)) + self.assertEqual(retry_policy.token_bucket.tokens, 0) + + # record_success should generate tokens. + for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)): + retry_policy.record_success(retry=False) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2) + for i in range(2): + self.assertTrue(retry_policy.should_retry(1, 0)) + self.assertFalse(retry_policy.should_retry(1, 0)) + + # Recording a successful retry should return 1 additional token. + retry_policy.record_success(retry=True) + self.assertAlmostEqual( + retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN + ) + self.assertTrue(retry_policy.should_retry(1, 0)) + self.assertFalse(retry_policy.should_retry(1, 0)) + self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN) + + def test_retry_policy_csot(self): + retry_policy = _RetryPolicy(_TokenBucket()) + self.assertTrue(retry_policy.should_retry(1, 0.5)) + with pymongo.timeout(0.5): + self.assertTrue(retry_policy.should_retry(1, 0)) + self.assertTrue(retry_policy.should_retry(1, 0.1)) + # Would exceed the timeout, should not retry. + self.assertFalse(retry_policy.should_retry(1, 1.0)) + self.assertTrue(retry_policy.should_retry(1, 1.0)) + if __name__ == "__main__": unittest.main()