Skip to content

Commit 752d8c0

Browse files
committed
PYTHON-5506 Check CSOT deadline before consuming a token
1 parent 6cdf7cb commit 752d8c0

File tree

6 files changed

+91
-41
lines changed

6 files changed

+91
-41
lines changed

pymongo/asynchronous/helpers.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
cast,
3030
)
3131

32+
from pymongo import _csot
3233
from pymongo.errors import (
3334
OperationFailure,
3435
PyMongoError,
@@ -85,12 +86,11 @@ async def inner(*args: Any, **kwargs: Any) -> Any:
8586
_TIME = time # Added so synchro script doesn't remove the time import.
8687

8788

88-
async def _backoff(
89+
def _backoff(
8990
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
90-
) -> None:
91+
) -> float:
9192
jitter = random.random() # noqa: S311
92-
backoff = jitter * min(initial_delay * (2**attempt), max_delay)
93-
await asyncio.sleep(backoff)
93+
return jitter * min(initial_delay * (2**attempt), max_delay)
9494

9595

9696
class _TokenBucket:
@@ -145,15 +145,20 @@ async def record_success(self, retry: bool) -> None:
145145
"""Record a successful operation."""
146146
await self.token_bucket.deposit(retry)
147147

148-
async def backoff(self, attempt: int) -> None:
148+
def backoff(self, attempt: int) -> float:
149149
"""Return the backoff duration for the given ."""
150-
await _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)
150+
return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)
151151

152-
async def should_retry(self, attempt: int) -> bool:
152+
async def should_retry(self, attempt: int, delay: float) -> bool:
153153
"""Return if we have budget to retry and how long to backoff."""
154-
# TODO: Check CSOT deadline here.
155154
if attempt > self.attempts:
156155
return False
156+
157+
# If the delay would exceed the deadline, bail early before consuming a token.
158+
if _csot.get_timeout():
159+
if time.monotonic() + delay > _csot.get_deadline():
160+
return False
161+
157162
# Check token bucket last since we only want to consume a token if we actually retry.
158163
if not await self.token_bucket.consume():
159164
# DRIVERS-3246 Improve diagnostics when this case happens.
@@ -176,12 +181,15 @@ async def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
176181
if not exc.has_error_label("Retryable"):
177182
raise
178183
attempt += 1
179-
if not await retry_policy.should_retry(attempt):
184+
delay = 0
185+
if exc.has_error_label("SystemOverloaded"):
186+
delay = retry_policy.backoff(attempt)
187+
if not await retry_policy.should_retry(attempt, delay):
180188
raise
181189

182190
# Implement exponential backoff on retry.
183-
if exc.has_error_label("SystemOverloaded"):
184-
await retry_policy.backoff(attempt)
191+
if delay:
192+
await asyncio.sleep(delay)
185193
continue
186194

187195
return cast(F, inner)

pymongo/asynchronous/mongo_client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import asyncio
3636
import contextlib
3737
import os
38+
import time
3839
import warnings
3940
import weakref
4041
from collections import defaultdict
@@ -174,6 +175,8 @@
174175
UpdateMany,
175176
]
176177

178+
_TIME = time # Added so synchro script doesn't remove the time import.
179+
177180

178181
class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
179182
HOST = "localhost"
@@ -2853,13 +2856,14 @@ async def run(self) -> T:
28532856

28542857
self._always_retryable = always_retryable
28552858
if always_retryable:
2856-
if not await self._retry_policy.should_retry(self._attempt_number):
2859+
delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0
2860+
if not await self._retry_policy.should_retry(self._attempt_number, delay):
28572861
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
28582862
raise self._last_error from exc
28592863
else:
28602864
raise
28612865
if overloaded:
2862-
await self._retry_policy.backoff(self._attempt_number)
2866+
await asyncio.sleep(delay)
28632867

28642868
def _is_not_eligible_for_retry(self) -> bool:
28652869
"""Checks if the exchange is not eligible for retry"""

pymongo/synchronous/helpers.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
cast,
3030
)
3131

32+
from pymongo import _csot
3233
from pymongo.errors import (
3334
OperationFailure,
3435
PyMongoError,
@@ -87,10 +88,9 @@ def inner(*args: Any, **kwargs: Any) -> Any:
8788

8889
def _backoff(
8990
attempt: int, initial_delay: float = _BACKOFF_INITIAL, max_delay: float = _BACKOFF_MAX
90-
) -> None:
91+
) -> float:
9192
jitter = random.random() # noqa: S311
92-
backoff = jitter * min(initial_delay * (2**attempt), max_delay)
93-
time.sleep(backoff)
93+
return jitter * min(initial_delay * (2**attempt), max_delay)
9494

9595

9696
class _TokenBucket:
@@ -145,15 +145,20 @@ def record_success(self, retry: bool) -> None:
145145
"""Record a successful operation."""
146146
self.token_bucket.deposit(retry)
147147

148-
def backoff(self, attempt: int) -> None:
148+
def backoff(self, attempt: int) -> float:
149149
"""Return the backoff duration for the given ."""
150-
_backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)
150+
return _backoff(max(0, attempt - 1), self.backoff_initial, self.backoff_max)
151151

152-
def should_retry(self, attempt: int) -> bool:
152+
def should_retry(self, attempt: int, delay: float) -> bool:
153153
"""Return if we have budget to retry and how long to backoff."""
154-
# TODO: Check CSOT deadline here.
155154
if attempt > self.attempts:
156155
return False
156+
157+
# If the delay would exceed the deadline, bail early before consuming a token.
158+
if _csot.get_timeout():
159+
if time.monotonic() + delay > _csot.get_deadline():
160+
return False
161+
157162
# Check token bucket last since we only want to consume a token if we actually retry.
158163
if not self.token_bucket.consume():
159164
# DRIVERS-3246 Improve diagnostics when this case happens.
@@ -176,12 +181,15 @@ def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
176181
if not exc.has_error_label("Retryable"):
177182
raise
178183
attempt += 1
179-
if not retry_policy.should_retry(attempt):
184+
delay = 0
185+
if exc.has_error_label("SystemOverloaded"):
186+
delay = retry_policy.backoff(attempt)
187+
if not retry_policy.should_retry(attempt, delay):
180188
raise
181189

182190
# Implement exponential backoff on retry.
183-
if exc.has_error_label("SystemOverloaded"):
184-
retry_policy.backoff(attempt)
191+
if delay:
192+
time.sleep(delay)
185193
continue
186194

187195
return cast(F, inner)

pymongo/synchronous/mongo_client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import asyncio
3636
import contextlib
3737
import os
38+
import time
3839
import warnings
3940
import weakref
4041
from collections import defaultdict
@@ -171,6 +172,8 @@
171172
UpdateMany,
172173
]
173174

175+
_TIME = time # Added so synchro script doesn't remove the time import.
176+
174177

175178
class MongoClient(common.BaseObject, Generic[_DocumentType]):
176179
HOST = "localhost"
@@ -2843,13 +2846,14 @@ def run(self) -> T:
28432846

28442847
self._always_retryable = always_retryable
28452848
if always_retryable:
2846-
if not self._retry_policy.should_retry(self._attempt_number):
2849+
delay = self._retry_policy.backoff(self._attempt_number) if overloaded else 0
2850+
if not self._retry_policy.should_retry(self._attempt_number, delay):
28472851
if exc_to_check.has_error_label("NoWritesPerformed") and self._last_error:
28482852
raise self._last_error from exc
28492853
else:
28502854
raise
28512855
if overloaded:
2852-
self._retry_policy.backoff(self._attempt_number)
2856+
time.sleep(delay)
28532857

28542858
def _is_not_eligible_for_retry(self) -> bool:
28552859
"""Checks if the exchange is not eligible for retry"""

test/asynchronous/test_backpressure.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
"""Test Client Backpressure spec."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import sys
1920

21+
import pymongo
22+
2023
sys.path[0:0] = [""]
2124

2225
from test.asynchronous import (
@@ -187,31 +190,41 @@ async def test_retry_policy(self):
187190
self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL)
188191
self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX)
189192
for i in range(1, helpers._MAX_RETRIES + 1):
190-
self.assertTrue(await retry_policy.should_retry(i))
191-
self.assertFalse(await retry_policy.should_retry(helpers._MAX_RETRIES + 1))
193+
self.assertTrue(await retry_policy.should_retry(i, 0))
194+
self.assertFalse(await retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0))
192195
for i in range(capacity - helpers._MAX_RETRIES):
193-
self.assertTrue(await retry_policy.should_retry(1))
196+
self.assertTrue(await retry_policy.should_retry(1, 0))
194197
# No tokens left, should not retry.
195-
self.assertFalse(await retry_policy.should_retry(1))
198+
self.assertFalse(await retry_policy.should_retry(1, 0))
196199
self.assertEqual(retry_policy.token_bucket.tokens, 0)
197200

198201
# record_success should generate tokens.
199202
for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)):
200203
await retry_policy.record_success(retry=False)
201204
self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2)
202205
for i in range(2):
203-
self.assertTrue(await retry_policy.should_retry(1))
204-
self.assertFalse(await retry_policy.should_retry(1))
206+
self.assertTrue(await retry_policy.should_retry(1, 0))
207+
self.assertFalse(await retry_policy.should_retry(1, 0))
205208

206209
# Recording a successful retry should return 1 additional token.
207210
await retry_policy.record_success(retry=True)
208211
self.assertAlmostEqual(
209212
retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN
210213
)
211-
self.assertTrue(await retry_policy.should_retry(1))
212-
self.assertFalse(await retry_policy.should_retry(1))
214+
self.assertTrue(await retry_policy.should_retry(1, 0))
215+
self.assertFalse(await retry_policy.should_retry(1, 0))
213216
self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN)
214217

218+
async def test_retry_policy_csot(self):
219+
retry_policy = _RetryPolicy(_TokenBucket())
220+
self.assertTrue(await retry_policy.should_retry(1, 0.5))
221+
with pymongo.timeout(0.5):
222+
self.assertTrue(await retry_policy.should_retry(1, 0))
223+
self.assertTrue(await retry_policy.should_retry(1, 0.1))
224+
# Would exceed the timeout, should not retry.
225+
self.assertFalse(await retry_policy.should_retry(1, 1.0))
226+
self.assertTrue(await retry_policy.should_retry(1, 1.0))
227+
215228

216229
if __name__ == "__main__":
217230
unittest.main()

test/test_backpressure.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,11 @@
1515
"""Test Client Backpressure spec."""
1616
from __future__ import annotations
1717

18+
import asyncio
1819
import sys
1920

21+
import pymongo
22+
2023
sys.path[0:0] = [""]
2124

2225
from test import (
@@ -187,31 +190,41 @@ def test_retry_policy(self):
187190
self.assertEqual(retry_policy.backoff_initial, helpers._BACKOFF_INITIAL)
188191
self.assertEqual(retry_policy.backoff_max, helpers._BACKOFF_MAX)
189192
for i in range(1, helpers._MAX_RETRIES + 1):
190-
self.assertTrue(retry_policy.should_retry(i))
191-
self.assertFalse(retry_policy.should_retry(helpers._MAX_RETRIES + 1))
193+
self.assertTrue(retry_policy.should_retry(i, 0))
194+
self.assertFalse(retry_policy.should_retry(helpers._MAX_RETRIES + 1, 0))
192195
for i in range(capacity - helpers._MAX_RETRIES):
193-
self.assertTrue(retry_policy.should_retry(1))
196+
self.assertTrue(retry_policy.should_retry(1, 0))
194197
# No tokens left, should not retry.
195-
self.assertFalse(retry_policy.should_retry(1))
198+
self.assertFalse(retry_policy.should_retry(1, 0))
196199
self.assertEqual(retry_policy.token_bucket.tokens, 0)
197200

198201
# record_success should generate tokens.
199202
for _ in range(int(2 / helpers.DEFAULT_RETRY_TOKEN_RETURN)):
200203
retry_policy.record_success(retry=False)
201204
self.assertAlmostEqual(retry_policy.token_bucket.tokens, 2)
202205
for i in range(2):
203-
self.assertTrue(retry_policy.should_retry(1))
204-
self.assertFalse(retry_policy.should_retry(1))
206+
self.assertTrue(retry_policy.should_retry(1, 0))
207+
self.assertFalse(retry_policy.should_retry(1, 0))
205208

206209
# Recording a successful retry should return 1 additional token.
207210
retry_policy.record_success(retry=True)
208211
self.assertAlmostEqual(
209212
retry_policy.token_bucket.tokens, 1 + helpers.DEFAULT_RETRY_TOKEN_RETURN
210213
)
211-
self.assertTrue(retry_policy.should_retry(1))
212-
self.assertFalse(retry_policy.should_retry(1))
214+
self.assertTrue(retry_policy.should_retry(1, 0))
215+
self.assertFalse(retry_policy.should_retry(1, 0))
213216
self.assertAlmostEqual(retry_policy.token_bucket.tokens, helpers.DEFAULT_RETRY_TOKEN_RETURN)
214217

218+
def test_retry_policy_csot(self):
219+
retry_policy = _RetryPolicy(_TokenBucket())
220+
self.assertTrue(retry_policy.should_retry(1, 0.5))
221+
with pymongo.timeout(0.5):
222+
self.assertTrue(retry_policy.should_retry(1, 0))
223+
self.assertTrue(retry_policy.should_retry(1, 0.1))
224+
# Would exceed the timeout, should not retry.
225+
self.assertFalse(retry_policy.should_retry(1, 1.0))
226+
self.assertTrue(retry_policy.should_retry(1, 1.0))
227+
215228

216229
if __name__ == "__main__":
217230
unittest.main()

0 commit comments

Comments
 (0)