Skip to content

Commit 0ea2979

Browse files
committed
PYTHON-2915 Fix bug when starting a transaction with a large bulk write (#743)
(cherry picked from commit 7467aa6)
1 parent a1cd624 commit 0ea2979

File tree

5 files changed

+110
-51
lines changed

5 files changed

+110
-51
lines changed

doc/changelog.rst

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,27 @@
11
Changelog
22
=========
33

4+
Changes in Version 3.12.1
5+
-------------------------
6+
7+
Issues Resolved
8+
...............
9+
10+
Version 3.12.1 fixes a number of bugs:
11+
12+
- Fixed a bug that caused a multi-document transaction to fail when the first
13+
operation was a large bulk write (>48MB) that required splitting a batched
14+
write command (`PYTHON-2915`_).
15+
- Fixed a bug that caused the ``tlsDisableOCSPEndpointCheck`` URI option to
16+
be applied incorrectly (`PYTHON-2866`_).
17+
18+
See the `PyMongo 3.12.1 release notes in JIRA`_ for the list of resolved issues
19+
in this release.
20+
21+
.. _PYTHON-2915: https://jira.mongodb.org/browse/PYTHON-2915
22+
.. _PYTHON-2866: https://jira.mongodb.org/browse/PYTHON-2866
23+
.. _PyMongo 3.12.1 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=31527
24+
425
Changes in Version 3.12.0
526
-------------------------
627

pymongo/bulk.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -285,17 +285,18 @@ def _execute_command(self, generator, write_concern, session,
285285
# sock_info.write_command.
286286
sock_info.validate_session(client, session)
287287
while run:
288-
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
289-
('ordered', self.ordered)])
290-
if not write_concern.is_server_default:
291-
cmd['writeConcern'] = write_concern.document
292-
if self.bypass_doc_val and sock_info.max_wire_version >= 4:
293-
cmd['bypassDocumentValidation'] = True
288+
cmd_name = _COMMANDS[run.op_type]
294289
bwc = self.bulk_ctx_class(
295-
db_name, cmd, sock_info, op_id, listeners, session,
290+
db_name, cmd_name, sock_info, op_id, listeners, session,
296291
run.op_type, self.collection.codec_options)
297292

298293
while run.idx_offset < len(run.ops):
294+
cmd = SON([(cmd_name, self.collection.name),
295+
('ordered', self.ordered)])
296+
if not write_concern.is_server_default:
297+
cmd['writeConcern'] = write_concern.document
298+
if self.bypass_doc_val and sock_info.max_wire_version >= 4:
299+
cmd['bypassDocumentValidation'] = True
299300
if session:
300301
# Start a new retryable write unless one was already
301302
# started for this command.
@@ -305,9 +306,10 @@ def _execute_command(self, generator, write_concern, session,
305306
session._apply_to(cmd, retryable, ReadPreference.PRIMARY,
306307
sock_info)
307308
sock_info.send_cluster_time(cmd, session, client)
309+
sock_info.add_server_api(cmd)
308310
ops = islice(run.ops, run.idx_offset, None)
309311
# Run as many ops as possible in one command.
310-
result, to_send = bwc.execute(ops, client)
312+
result, to_send = bwc.execute(cmd, ops, client)
311313

312314
# Retryable writeConcernErrors halt the execution of this run.
313315
wce = result.get('writeConcernError', {})
@@ -367,16 +369,16 @@ def retryable_bulk(session, sock_info, retryable):
367369
def execute_insert_no_results(self, sock_info, run, op_id, acknowledged):
368370
"""Execute insert, returning no results.
369371
"""
370-
command = SON([('insert', self.collection.name),
371-
('ordered', self.ordered)])
372+
db = self.collection.database
372373
concern = {'w': int(self.ordered)}
373-
command['writeConcern'] = concern
374+
cmd = SON([('insert', self.collection.name),
375+
('ordered', self.ordered),
376+
('writeConcern', concern)])
374377
if self.bypass_doc_val and sock_info.max_wire_version >= 4:
375-
command['bypassDocumentValidation'] = True
376-
db = self.collection.database
378+
cmd['bypassDocumentValidation'] = True
377379
bwc = _BulkWriteContext(
378-
db.name, command, sock_info, op_id, db.client._event_listeners,
379-
None, _INSERT, self.collection.codec_options)
380+
db.name, 'insert', sock_info, op_id, db.client._event_listeners,
381+
None, _INSERT, self.collection.codec_options, cmd_legacy=cmd)
380382
# Legacy batched OP_INSERT.
381383
_do_batched_insert(
382384
self.collection.full_name, run.ops, True, acknowledged, concern,
@@ -395,17 +397,19 @@ def execute_op_msg_no_results(self, sock_info, generator):
395397
run = self.current_run
396398

397399
while run:
398-
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
399-
('ordered', False),
400-
('writeConcern', {'w': 0})])
400+
cmd_name = _COMMANDS[run.op_type]
401401
bwc = self.bulk_ctx_class(
402-
db_name, cmd, sock_info, op_id, listeners, None,
402+
db_name, cmd_name, sock_info, op_id, listeners, None,
403403
run.op_type, self.collection.codec_options)
404404

405405
while run.idx_offset < len(run.ops):
406+
cmd = SON([(cmd_name, self.collection.name),
407+
('ordered', False),
408+
('writeConcern', {'w': 0})])
409+
sock_info.add_server_api(cmd)
406410
ops = islice(run.ops, run.idx_offset, None)
407411
# Run as many ops as possible.
408-
to_send = bwc.execute_unack(ops, client)
412+
to_send = bwc.execute_unack(cmd, ops, client)
409413
run.idx_offset += len(to_send)
410414
self.current_run = run = next(generator, None)
411415

pymongo/message.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -916,50 +916,49 @@ def kill_cursors(cursor_ids):
916916
class _BulkWriteContext(object):
917917
"""A wrapper around SocketInfo for use with write splitting functions."""
918918

919-
__slots__ = ('db_name', 'command', 'sock_info', 'op_id',
919+
__slots__ = ('db_name', 'sock_info', 'op_id',
920920
'name', 'field', 'publish', 'start_time', 'listeners',
921-
'session', 'compress', 'op_type', 'codec')
921+
'session', 'compress', 'op_type', 'codec', 'cmd_legacy')
922922

923-
def __init__(self, database_name, command, sock_info, operation_id,
924-
listeners, session, op_type, codec):
923+
def __init__(self, database_name, cmd_name, sock_info, operation_id,
924+
listeners, session, op_type, codec, cmd_legacy=None):
925925
self.db_name = database_name
926-
self.command = command
927926
self.sock_info = sock_info
928927
self.op_id = operation_id
929928
self.listeners = listeners
930929
self.publish = listeners.enabled_for_commands
931-
self.name = next(iter(command))
930+
self.name = cmd_name
932931
self.field = _FIELD_MAP[self.name]
933932
self.start_time = datetime.datetime.now() if self.publish else None
934933
self.session = session
935934
self.compress = True if sock_info.compression_context else False
936935
self.op_type = op_type
937936
self.codec = codec
938-
sock_info.add_server_api(command)
937+
self.cmd_legacy = cmd_legacy
939938

940-
def _batch_command(self, docs):
939+
def _batch_command(self, cmd, docs):
941940
namespace = self.db_name + '.$cmd'
942941
request_id, msg, to_send = _do_bulk_write_command(
943-
namespace, self.op_type, self.command, docs, self.check_keys,
942+
namespace, self.op_type, cmd, docs, self.check_keys,
944943
self.codec, self)
945944
if not to_send:
946945
raise InvalidOperation("cannot do an empty bulk write")
947946
return request_id, msg, to_send
948947

949-
def execute(self, docs, client):
950-
request_id, msg, to_send = self._batch_command(docs)
951-
result = self.write_command(request_id, msg, to_send)
948+
def execute(self, cmd, docs, client):
949+
request_id, msg, to_send = self._batch_command(cmd, docs)
950+
result = self.write_command(cmd, request_id, msg, to_send)
952951
client._process_response(result, self.session)
953952
return result, to_send
954953

955-
def execute_unack(self, docs, client):
956-
request_id, msg, to_send = self._batch_command(docs)
954+
def execute_unack(self, cmd, docs, client):
955+
request_id, msg, to_send = self._batch_command(cmd, docs)
957956
# Though this isn't strictly a "legacy" write, the helper
958957
# handles publishing commands and sending our message
959958
# without receiving a result. Send 0 for max_doc_size
960959
# to disable size checking. Size checking is handled while
961960
# the documents are encoded to BSON.
962-
self.legacy_write(request_id, msg, 0, False, to_send)
961+
self.legacy_write(cmd, request_id, msg, 0, False, to_send)
963962
return to_send
964963

965964
@property
@@ -996,14 +995,16 @@ def legacy_bulk_insert(
996995
request_id, msg = _compress(
997996
2002, msg, self.sock_info.compression_context)
998997
return self.legacy_write(
999-
request_id, msg, max_doc_size, acknowledged, docs)
998+
self.cmd_legacy.copy(), request_id, msg, max_doc_size,
999+
acknowledged, docs)
10001000

1001-
def legacy_write(self, request_id, msg, max_doc_size, acknowledged, docs):
1001+
def legacy_write(self, cmd, request_id, msg, max_doc_size, acknowledged,
1002+
docs):
10021003
"""A proxy for SocketInfo.legacy_write that handles event publishing.
10031004
"""
10041005
if self.publish:
10051006
duration = datetime.datetime.now() - self.start_time
1006-
cmd = self._start(request_id, docs)
1007+
cmd = self._start(cmd, request_id, docs)
10071008
start = datetime.datetime.now()
10081009
try:
10091010
result = self.sock_info.legacy_write(
@@ -1032,12 +1033,12 @@ def legacy_write(self, request_id, msg, max_doc_size, acknowledged, docs):
10321033
self.start_time = datetime.datetime.now()
10331034
return result
10341035

1035-
def write_command(self, request_id, msg, docs):
1036+
def write_command(self, cmd, request_id, msg, docs):
10361037
"""A proxy for SocketInfo.write_command that handles event publishing.
10371038
"""
10381039
if self.publish:
10391040
duration = datetime.datetime.now() - self.start_time
1040-
self._start(request_id, docs)
1041+
self._start(cmd, request_id, docs)
10411042
start = datetime.datetime.now()
10421043
try:
10431044
reply = self.sock_info.write_command(request_id, msg)
@@ -1057,9 +1058,8 @@ def write_command(self, request_id, msg, docs):
10571058
self.start_time = datetime.datetime.now()
10581059
return reply
10591060

1060-
def _start(self, request_id, docs):
1061+
def _start(self, cmd, request_id, docs):
10611062
"""Publish a CommandStartedEvent."""
1062-
cmd = self.command.copy()
10631063
cmd[self.field] = docs
10641064
self.listeners.publish_command_start(
10651065
cmd, self.db_name,
@@ -1092,10 +1092,10 @@ def _fail(self, request_id, failure, duration):
10921092
class _EncryptedBulkWriteContext(_BulkWriteContext):
10931093
__slots__ = ()
10941094

1095-
def _batch_command(self, docs):
1095+
def _batch_command(self, cmd, docs):
10961096
namespace = self.db_name + '.$cmd'
10971097
msg, to_send = _encode_batched_write_command(
1098-
namespace, self.op_type, self.command, docs, self.check_keys,
1098+
namespace, self.op_type, cmd, docs, self.check_keys,
10991099
self.codec, self)
11001100
if not to_send:
11011101
raise InvalidOperation("cannot do an empty bulk write")
@@ -1106,17 +1106,18 @@ def _batch_command(self, docs):
11061106
DEFAULT_RAW_BSON_OPTIONS)
11071107
return cmd, to_send
11081108

1109-
def execute(self, docs, client):
1110-
cmd, to_send = self._batch_command(docs)
1109+
def execute(self, cmd, docs, client):
1110+
batched_cmd, to_send = self._batch_command(cmd, docs)
11111111
result = self.sock_info.command(
1112-
self.db_name, cmd, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
1112+
self.db_name, batched_cmd,
1113+
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
11131114
session=self.session, client=client)
11141115
return result, to_send
11151116

1116-
def execute_unack(self, docs, client):
1117-
cmd, to_send = self._batch_command(docs)
1117+
def execute_unack(self, cmd, docs, client):
1118+
batched_cmd, to_send = self._batch_command(cmd, docs)
11181119
self.sock_info.command(
1119-
self.db_name, cmd, write_concern=WriteConcern(w=0),
1120+
self.db_name, batched_cmd, write_concern=WriteConcern(w=0),
11201121
session=self.session, client=client)
11211122
return to_send
11221123

test/test_transactions.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,36 @@ def gridfs_open_upload_stream(*args, **kwargs):
288288
):
289289
op(*args, session=s)
290290

291+
# Require 4.2+ for large (16MB+) transactions.
292+
@client_context.require_version_min(4, 2)
293+
@client_context.require_transactions
294+
def test_transaction_starts_with_batched_write(self):
295+
# Start a transaction with a batch of operations that needs to be
296+
# split.
297+
listener = OvertCommandListener()
298+
client = rs_client(event_listeners=[listener])
299+
coll = client[self.db.name].test
300+
coll.delete_many({})
301+
listener.reset()
302+
self.addCleanup(client.close)
303+
self.addCleanup(coll.drop)
304+
ops = [InsertOne({'a': '1'*(10*1024*1024)}) for _ in range(10)]
305+
with client.start_session() as session:
306+
with session.start_transaction():
307+
coll.bulk_write(ops, session=session)
308+
# Assert commands were constructed properly.
309+
self.assertEqual(['insert', 'insert', 'insert', 'commitTransaction'],
310+
listener.started_command_names())
311+
first_cmd = listener.results['started'][0].command
312+
self.assertTrue(first_cmd['startTransaction'])
313+
lsid = first_cmd['lsid']
314+
txn_number = first_cmd['txnNumber']
315+
for event in listener.results['started'][1:]:
316+
self.assertNotIn('startTransaction', event.command)
317+
self.assertEqual(lsid, event.command['lsid'])
318+
self.assertEqual(txn_number, event.command['txnNumber'])
319+
self.assertEqual(10, coll.count_documents({}))
320+
291321

292322
class PatchSessionTimeout(object):
293323
"""Patches the client_session's with_transaction timeout for testing."""
@@ -336,6 +366,7 @@ def callback(session):
336366
def test_callback_not_retried_after_timeout(self):
337367
listener = OvertCommandListener()
338368
client = rs_client(event_listeners=[listener])
369+
self.addCleanup(client.close)
339370
coll = client[self.db.name].test
340371

341372
def callback(session):
@@ -365,6 +396,7 @@ def callback(session):
365396
def test_callback_not_retried_after_commit_timeout(self):
366397
listener = OvertCommandListener()
367398
client = rs_client(event_listeners=[listener])
399+
self.addCleanup(client.close)
368400
coll = client[self.db.name].test
369401

370402
def callback(session):
@@ -395,6 +427,7 @@ def callback(session):
395427
def test_commit_not_retried_after_timeout(self):
396428
listener = OvertCommandListener()
397429
client = rs_client(event_listeners=[listener])
430+
self.addCleanup(client.close)
398431
coll = client[self.db.name].test
399432

400433
def callback(session):

test/unified_format.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -985,9 +985,9 @@ def _testOperation_targetedFailPoint(self, spec):
985985
"session %s" % (spec['session'],))
986986

987987
client = single_client('%s:%s' % session._pinned_address)
988+
self.addCleanup(client.close)
988989
self.__set_fail_point(
989990
client=client, command_args=spec['failPoint'])
990-
self.addCleanup(client.close)
991991

992992
def _testOperation_assertSessionTransactionState(self, spec):
993993
session = self.entity_map[spec['session']]

0 commit comments

Comments
 (0)