Skip to content

Commit a62171c

Browse files
committed
PYTHON-4924 - PoolClearedError should have TransientTransactionError label appended to it
1 parent a1b4a74 commit a62171c

File tree

4 files changed

+52
-14
lines changed

4 files changed

+52
-14
lines changed

pymongo/asynchronous/pool.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,10 @@ def _raise_connection_failure(
190190
) -> NoReturn:
191191
"""Convert a socket.error to ConnectionFailure and raise it."""
192192
host, port = address
193+
if isinstance(error, PyMongoError) and error._error_labels:
194+
labels = error._error_labels
195+
else:
196+
labels = None
193197
# If connecting to a Unix socket, port will be None.
194198
if port is not None:
195199
msg = "%s:%d: %s" % (host, port, error)
@@ -200,15 +204,15 @@ def _raise_connection_failure(
200204
if "configured timeouts" not in msg:
201205
msg += format_timeout_details(timeout_details)
202206
if isinstance(error, socket.timeout):
203-
raise NetworkTimeout(msg) from error
207+
raise NetworkTimeout(msg, errors={"errorLabels": labels}) from error
204208
elif isinstance(error, SSLError) and "timed out" in str(error):
205209
# Eventlet does not distinguish TLS network timeouts from other
206210
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
207211
# Luckily, we can work around this limitation because the phrase
208212
# 'timed out' appears in all the timeout related SSLErrors raised.
209-
raise NetworkTimeout(msg) from error
213+
raise NetworkTimeout(msg, errors={"errorLabels": labels}) from error
210214
else:
211-
raise AutoReconnect(msg) from error
215+
raise AutoReconnect(msg, errors={"errorLabels": labels}) from error
212216

213217

214218
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
@@ -1420,9 +1424,9 @@ def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) ->
14201424
)
14211425

14221426
details = _get_timeout_details(self.opts)
1423-
_raise_connection_failure(
1424-
self.address, AutoReconnect("connection pool paused"), timeout_details=details
1425-
)
1427+
error = AutoReconnect("connection pool paused")
1428+
error._add_error_label("TransientTransactionError")
1429+
_raise_connection_failure(self.address, error, timeout_details=details)
14261430

14271431
async def _get_conn(
14281432
self, checkout_started_time: float, handler: Optional[_MongoClientErrorHandler] = None

pymongo/synchronous/pool.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,10 @@ def _raise_connection_failure(
190190
) -> NoReturn:
191191
"""Convert a socket.error to ConnectionFailure and raise it."""
192192
host, port = address
193+
if isinstance(error, PyMongoError) and error._error_labels:
194+
labels = error._error_labels
195+
else:
196+
labels = None
193197
# If connecting to a Unix socket, port will be None.
194198
if port is not None:
195199
msg = "%s:%d: %s" % (host, port, error)
@@ -200,15 +204,15 @@ def _raise_connection_failure(
200204
if "configured timeouts" not in msg:
201205
msg += format_timeout_details(timeout_details)
202206
if isinstance(error, socket.timeout):
203-
raise NetworkTimeout(msg) from error
207+
raise NetworkTimeout(msg, errors={"errorLabels": labels}) from error
204208
elif isinstance(error, SSLError) and "timed out" in str(error):
205209
# Eventlet does not distinguish TLS network timeouts from other
206210
# SSLErrors (https://github.com/eventlet/eventlet/issues/692).
207211
# Luckily, we can work around this limitation because the phrase
208212
# 'timed out' appears in all the timeout related SSLErrors raised.
209-
raise NetworkTimeout(msg) from error
213+
raise NetworkTimeout(msg, errors={"errorLabels": labels}) from error
210214
else:
211-
raise AutoReconnect(msg) from error
215+
raise AutoReconnect(msg, errors={"errorLabels": labels}) from error
212216

213217

214218
def _get_timeout_details(options: PoolOptions) -> dict[str, float]:
@@ -1414,9 +1418,9 @@ def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) ->
14141418
)
14151419

14161420
details = _get_timeout_details(self.opts)
1417-
_raise_connection_failure(
1418-
self.address, AutoReconnect("connection pool paused"), timeout_details=details
1419-
)
1421+
error = AutoReconnect("connection pool paused")
1422+
error._add_error_label("TransientTransactionError")
1423+
_raise_connection_failure(self.address, error, timeout_details=details)
14201424

14211425
def _get_conn(
14221426
self, checkout_started_time: float, handler: Optional[_MongoClientErrorHandler] = None

test/asynchronous/test_pooling.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from test.asynchronous.helpers import ConcurrentRunner
3737
from test.utils_shared import delay
3838

39-
from pymongo.asynchronous.pool import Pool, PoolOptions
39+
from pymongo.asynchronous.pool import Pool, PoolOptions, PoolState
4040
from pymongo.socket_checker import SocketChecker
4141

4242
_IS_SYNC = False
@@ -608,6 +608,21 @@ async def test_max_pool_size_with_connection_failure(self):
608608
# seems error-prone, so check the message too.
609609
self.assertNotIn("waiting for socket from pool", str(context.exception))
610610

611+
async def test_pool_cleared_error_labelled_transient(self):
612+
test_pool = Pool(
613+
("localhost", 27017),
614+
PoolOptions(max_pool_size=1),
615+
)
616+
# Pause the pool, causing it to fail connection checkout.
617+
test_pool.state = PoolState.PAUSED
618+
619+
with self.assertRaises(AutoReconnect) as context:
620+
async with test_pool.checkout():
621+
pass
622+
623+
# Verify that the TransientTransactionError label is present in the error.
624+
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
625+
611626

612627
if __name__ == "__main__":
613628
unittest.main()

test/test_pooling.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from test.utils_shared import delay
3838

3939
from pymongo.socket_checker import SocketChecker
40-
from pymongo.synchronous.pool import Pool, PoolOptions
40+
from pymongo.synchronous.pool import Pool, PoolOptions, PoolState
4141

4242
_IS_SYNC = True
4343

@@ -606,6 +606,21 @@ def test_max_pool_size_with_connection_failure(self):
606606
# seems error-prone, so check the message too.
607607
self.assertNotIn("waiting for socket from pool", str(context.exception))
608608

609+
def test_pool_cleared_error_labelled_transient(self):
610+
test_pool = Pool(
611+
("localhost", 27017),
612+
PoolOptions(max_pool_size=1),
613+
)
614+
# Pause the pool, causing it to fail connection checkout.
615+
test_pool.state = PoolState.PAUSED
616+
617+
with self.assertRaises(AutoReconnect) as context:
618+
with test_pool.checkout():
619+
pass
620+
621+
# Verify that the TransientTransactionError label is present in the error.
622+
self.assertTrue(context.exception.has_error_label("TransientTransactionError"))
623+
609624

610625
if __name__ == "__main__":
611626
unittest.main()

0 commit comments

Comments
 (0)