Skip to content

Commit 45a7963

Browse files
committed
PYTHON-2082 Retryable writes use the RetryableWriteError label
Use retryable write logic for transaction commit/abort. Do not assign the TransientTransactionError label to errors outside a transaction.
1 parent 48df9b0 commit 45a7963

34 files changed

+2536
-192
lines changed

pymongo/client_session.py

Lines changed: 19 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -98,13 +98,11 @@
9898
"""
9999

100100
import collections
101-
import os
102-
import sys
103101
import uuid
104102

105103
from bson.binary import Binary
106104
from bson.int64 import Int64
107-
from bson.py3compat import abc, integer_types, reraise_instance
105+
from bson.py3compat import abc, integer_types
108106
from bson.son import SON
109107
from bson.timestamp import Timestamp
110108

@@ -114,7 +112,6 @@
114112
InvalidOperation,
115113
OperationFailure,
116114
PyMongoError,
117-
ServerSelectionTimeoutError,
118115
WTimeoutError)
119116
from pymongo.helpers import _RETRYABLE_ERROR_CODES
120117
from pymongo.read_concern import ReadConcern
@@ -295,6 +292,7 @@ def __init__(self, opts):
295292
self.sharded = False
296293
self.pinned_address = None
297294
self.recovery_token = None
295+
self.attempt = 0
298296

299297
def active(self):
300298
return self.state in (_TxnState.STARTING, _TxnState.IN_PROGRESS)
@@ -304,12 +302,13 @@ def reset(self):
304302
self.sharded = False
305303
self.pinned_address = None
306304
self.recovery_token = None
305+
self.attempt = 0
307306

308307

309308
def _reraise_with_unknown_commit(exc):
310309
"""Re-raise an exception with the UnknownTransactionCommitResult label."""
311310
exc._add_error_label("UnknownTransactionCommitResult")
312-
reraise_instance(exc, trace=sys.exc_info()[2])
311+
raise
313312

314313

315314
def _max_time_expired_error(exc):
@@ -579,7 +578,6 @@ def commit_transaction(self):
579578
.. versionadded:: 3.7
580579
"""
581580
self._check_ended()
582-
retry = False
583581
state = self._transaction.state
584582
if state is _TxnState.NONE:
585583
raise InvalidOperation("No transaction started")
@@ -594,10 +592,9 @@ def commit_transaction(self):
594592
# We're explicitly retrying the commit, move the state back to
595593
# "in progress" so that in_transaction returns true.
596594
self._transaction.state = _TxnState.IN_PROGRESS
597-
retry = True
598595

599596
try:
600-
self._finish_transaction_with_retry("commitTransaction", retry)
597+
self._finish_transaction_with_retry("commitTransaction")
601598
except ConnectionFailure as exc:
602599
# We do not know if the commit was successfully applied on the
603600
# server or if it satisfied the provided write concern, set the
@@ -640,44 +637,25 @@ def abort_transaction(self):
640637
"Cannot call abortTransaction after calling commitTransaction")
641638

642639
try:
643-
self._finish_transaction_with_retry("abortTransaction", False)
640+
self._finish_transaction_with_retry("abortTransaction")
644641
except (OperationFailure, ConnectionFailure):
645642
# The transactions spec says to ignore abortTransaction errors.
646643
pass
647644
finally:
648645
self._transaction.state = _TxnState.ABORTED
649646

650-
def _finish_transaction_with_retry(self, command_name, explict_retry):
647+
def _finish_transaction_with_retry(self, command_name):
651648
"""Run commit or abort with one retry after any retryable error.
652649
653650
:Parameters:
654651
- `command_name`: Either "commitTransaction" or "abortTransaction".
655-
- `explict_retry`: True when this is an explicit commit retry attempt,
656-
i.e. the application called session.commit_transaction() twice.
657652
"""
658-
# This can be refactored with MongoClient._retry_with_session.
659-
try:
660-
return self._finish_transaction(command_name, explict_retry)
661-
except ServerSelectionTimeoutError:
662-
raise
663-
except ConnectionFailure as exc:
664-
try:
665-
return self._finish_transaction(command_name, True)
666-
except ServerSelectionTimeoutError:
667-
# Raise the original error so the application can infer that
668-
# an attempt was made.
669-
raise exc
670-
except OperationFailure as exc:
671-
if exc.code not in _RETRYABLE_ERROR_CODES:
672-
raise
673-
try:
674-
return self._finish_transaction(command_name, True)
675-
except ServerSelectionTimeoutError:
676-
# Raise the original error so the application can infer that
677-
# an attempt was made.
678-
raise exc
653+
def func(session, sock_info, retryable):
654+
return self._finish_transaction(sock_info, command_name)
655+
return self._client._retry_internal(True, func, self, None)
679656

680-
def _finish_transaction(self, command_name, retrying):
657+
def _finish_transaction(self, sock_info, command_name):
658+
self._transaction.attempt += 1
681659
opts = self._transaction.opts
682660
wc = opts.write_concern
683661
cmd = SON([(command_name, 1)])
@@ -688,7 +666,7 @@ def _finish_transaction(self, command_name, retrying):
688666
# Transaction spec says that after the initial commit attempt,
689667
# subsequent commitTransaction commands should be upgraded to use
690668
# w:"majority" and set a default value of 10 seconds for wtimeout.
691-
if retrying:
669+
if self._transaction.attempt > 1:
692670
wc_doc = wc.document
693671
wc_doc["w"] = "majority"
694672
wc_doc.setdefault("wtimeout", 10000)
@@ -697,13 +675,12 @@ def _finish_transaction(self, command_name, retrying):
697675
if self._transaction.recovery_token:
698676
cmd['recoveryToken'] = self._transaction.recovery_token
699677

700-
with self._client._socket_for_writes(self) as sock_info:
701-
return self._client.admin._command(
702-
sock_info,
703-
cmd,
704-
session=self,
705-
write_concern=wc,
706-
parse_write_concern_error=True)
678+
return self._client.admin._command(
679+
sock_info,
680+
cmd,
681+
session=self,
682+
write_concern=wc,
683+
parse_write_concern_error=True)
707684

708685
def _advance_cluster_time(self, cluster_time):
709686
"""Internal cluster time helper."""

pymongo/errors.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _add_error_label(self, label):
4848

4949
def _remove_error_label(self, label):
5050
"""Remove the given label from this error."""
51-
self._error_labels.remove(label)
51+
self._error_labels.discard(label)
5252

5353
if sys.version_info[0] == 2:
5454
def __str__(self):
@@ -68,12 +68,6 @@ class ProtocolError(PyMongoError):
6868

6969
class ConnectionFailure(PyMongoError):
7070
"""Raised when a connection to the database cannot be made or is lost."""
71-
def __init__(self, message='', error_labels=None):
72-
if error_labels is None:
73-
# Connection errors are transient errors by default.
74-
error_labels = ("TransientTransactionError",)
75-
super(ConnectionFailure, self).__init__(
76-
message, error_labels=error_labels)
7771

7872

7973
class AutoReconnect(ConnectionFailure):
@@ -89,7 +83,10 @@ class AutoReconnect(ConnectionFailure):
8983
Subclass of :exc:`~pymongo.errors.ConnectionFailure`.
9084
"""
9185
def __init__(self, message='', errors=None):
92-
super(AutoReconnect, self).__init__(message)
86+
error_labels = None
87+
if errors is not None and isinstance(errors, dict):
88+
error_labels = errors.get('errorLabels')
89+
super(AutoReconnect, self).__init__(message, error_labels)
9390
self.errors = self.details = errors or []
9491

9592

pymongo/mongo_client.py

Lines changed: 77 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
ConfigurationError,
6060
ConnectionFailure,
6161
InvalidOperation,
62+
NotMasterError,
6263
OperationFailure,
6364
PyMongoError,
6465
ServerSelectionTimeoutError)
@@ -1265,7 +1266,9 @@ def _select_server(self, server_selector, session, address=None):
12651266
session._pin_mongos(server)
12661267
return server
12671268
except PyMongoError as exc:
1268-
if session and exc.has_error_label("TransientTransactionError"):
1269+
# Server selection errors in a transaction are transient.
1270+
if session and session.in_transaction:
1271+
exc._add_error_label("TransientTransactionError")
12691272
session._unpin_mongos()
12701273
raise
12711274

@@ -1361,6 +1364,11 @@ def _retry_with_session(self, retryable, func, session, bulk):
13611364
"""
13621365
retryable = (retryable and self.retry_writes
13631366
and session and not session.in_transaction)
1367+
return self._retry_internal(retryable, func, session, bulk)
1368+
1369+
def _retry_internal(self, retryable, func, session, bulk):
1370+
"""Internal retryable write helper."""
1371+
max_wire_version = 0
13641372
last_error = None
13651373
retrying = False
13661374

@@ -1369,7 +1377,7 @@ def is_retrying():
13691377
# Increment the transaction id up front to ensure any retry attempt
13701378
# will use the proper txnNumber, even if server or socket selection
13711379
# fails before the command can be sent.
1372-
if retryable:
1380+
if retryable and session and not session.in_transaction:
13731381
session._start_retryable_write()
13741382
if bulk:
13751383
bulk.started_retryable_write = True
@@ -1381,6 +1389,7 @@ def is_retrying():
13811389
session is not None and
13821390
server.description.retryable_writes_supported)
13831391
with self._get_socket(server, session) as sock_info:
1392+
max_wire_version = sock_info.max_wire_version
13841393
if retryable and not supports_session:
13851394
if is_retrying():
13861395
# A retry is not possible because this server does
@@ -1398,40 +1407,12 @@ def is_retrying():
13981407
# be a persistent outage. Attempting to retry in this case will
13991408
# most likely be a waste of time.
14001409
raise
1401-
except ConnectionFailure as exc:
1402-
if not retryable or is_retrying():
1410+
except Exception as exc:
1411+
if not retryable:
14031412
raise
1404-
if bulk:
1405-
bulk.retrying = True
1406-
else:
1407-
retrying = True
1408-
last_error = exc
1409-
except BulkWriteError as exc:
1410-
if not retryable or is_retrying():
1411-
raise
1412-
# Check the last writeConcernError to determine if this
1413-
# BulkWriteError is retryable.
1414-
wces = exc.details['writeConcernErrors']
1415-
wce = wces[-1] if wces else {}
1416-
if wce.get('code', 0) not in helpers._RETRYABLE_ERROR_CODES:
1417-
raise
1418-
if bulk:
1419-
bulk.retrying = True
1420-
else:
1421-
retrying = True
1422-
last_error = exc
1423-
except OperationFailure as exc:
1424-
# retryWrites on MMAPv1 should raise an actionable error.
1425-
if (exc.code == 20 and
1426-
str(exc).startswith("Transaction numbers")):
1427-
errmsg = (
1428-
"This MongoDB deployment does not support "
1429-
"retryable writes. Please add retryWrites=false "
1430-
"to your connection string.")
1431-
raise OperationFailure(errmsg, exc.code, exc.details)
1432-
if not retryable or is_retrying():
1433-
raise
1434-
if exc.code not in helpers._RETRYABLE_ERROR_CODES:
1413+
# Add the RetryableWriteError label.
1414+
if (not _retryable_writes_error(exc, max_wire_version)
1415+
or is_retrying()):
14351416
raise
14361417
if bulk:
14371418
bulk.retrying = True
@@ -2162,26 +2143,66 @@ def __next__(self):
21622143
next = __next__
21632144

21642145

2146+
def _retryable_error_doc(exc):
2147+
"""Return the server response from a PyMongo exception or None."""
2148+
if isinstance(exc, BulkWriteError):
2149+
# Check the last writeConcernError to determine if this
2150+
# BulkWriteError is retryable.
2151+
wces = exc.details['writeConcernErrors']
2152+
wce = wces[-1] if wces else None
2153+
return wce
2154+
if isinstance(exc, (NotMasterError, OperationFailure)):
2155+
return exc.details
2156+
return None
2157+
2158+
2159+
def _retryable_writes_error(exc, max_wire_version):
2160+
doc = _retryable_error_doc(exc)
2161+
if doc:
2162+
code = doc.get('code', 0)
2163+
# retryWrites on MMAPv1 should raise an actionable error.
2164+
if (code == 20 and
2165+
str(exc).startswith("Transaction numbers")):
2166+
errmsg = (
2167+
"This MongoDB deployment does not support "
2168+
"retryable writes. Please add retryWrites=false "
2169+
"to your connection string.")
2170+
raise OperationFailure(errmsg, code, exc.details)
2171+
if max_wire_version >= 9:
2172+
# MongoDB 4.4+ utilizes RetryableWriteError.
2173+
return 'RetryableWriteError' in doc.get('errorLabels', [])
2174+
else:
2175+
if code in helpers._RETRYABLE_ERROR_CODES:
2176+
exc._add_error_label("RetryableWriteError")
2177+
return True
2178+
return False
2179+
2180+
if isinstance(exc, ConnectionFailure):
2181+
exc._add_error_label("RetryableWriteError")
2182+
return True
2183+
return False
2184+
2185+
21652186
class _MongoClientErrorHandler(object):
2166-
"""Error handler for MongoClient."""
2167-
__slots__ = ('_client', '_server_address', '_session',
2168-
'_max_wire_version', '_sock_generation')
2187+
"""Handle errors raised when executing an operation."""
2188+
__slots__ = ('client', 'server_address', 'session', 'max_wire_version',
2189+
'sock_generation')
21692190

21702191
def __init__(self, client, server, session):
2171-
self._client = client
2172-
self._server_address = server.description.address
2173-
self._session = session
2174-
self._max_wire_version = common.MIN_WIRE_VERSION
2192+
self.client = client
2193+
self.server_address = server.description.address
2194+
self.session = session
2195+
self.max_wire_version = common.MIN_WIRE_VERSION
21752196
# XXX: When get_socket fails, this generation could be out of date:
21762197
# "Note that when a network error occurs before the handshake
21772198
# completes then the error's generation number is the generation
21782199
# of the pool at the time the connection attempt was started."
2179-
self._sock_generation = server.pool.generation
2200+
self.sock_generation = server.pool.generation
21802201

21812202
def contribute_socket(self, sock_info):
21822203
"""Provide socket information to the error handler."""
2183-
self._max_wire_version = sock_info.max_wire_version
2184-
self._sock_generation = sock_info.generation
2204+
self.max_wire_version = sock_info.max_wire_version
2205+
self.sock_generation = sock_info.generation
21852206

21862207
def __enter__(self):
21872208
return self
@@ -2190,15 +2211,16 @@ def __exit__(self, exc_type, exc_val, exc_tb):
21902211
if exc_type is None:
21912212
return
21922213

2193-
err_ctx = _ErrorContext(
2194-
exc_val, self._max_wire_version, self._sock_generation)
2195-
self._client._topology.handle_error(self._server_address, err_ctx)
2214+
if self.session:
2215+
if issubclass(exc_type, ConnectionFailure):
2216+
if self.session.in_transaction:
2217+
exc_val._add_error_label("TransientTransactionError")
2218+
self.session._server_session.mark_dirty()
21962219

2197-
if issubclass(exc_type, PyMongoError):
2198-
if self._session and exc_val.has_error_label(
2199-
"TransientTransactionError"):
2200-
self._session._unpin_mongos()
2220+
if issubclass(exc_type, PyMongoError):
2221+
if exc_val.has_error_label("TransientTransactionError"):
2222+
self.session._unpin_mongos()
22012223

2202-
if issubclass(exc_type, ConnectionFailure):
2203-
if self._session:
2204-
self._session._server_session.mark_dirty()
2224+
err_ctx = _ErrorContext(
2225+
exc_val, self.max_wire_version, self.sock_generation)
2226+
self.client._topology.handle_error(self.server_address, err_ctx)

0 commit comments

Comments
 (0)