Skip to content

Commit 7467aa6

Browse files
authored
PYTHON-2915 Fix bug when starting a transaction with a large bulk write (#743)
1 parent a80169d commit 7467aa6

File tree

3 files changed

+72
-40
lines changed

3 files changed

+72
-40
lines changed

pymongo/bulk.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -267,17 +267,18 @@ def _execute_command(self, generator, write_concern, session,
267267
# sock_info.write_command.
268268
sock_info.validate_session(client, session)
269269
while run:
270-
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
271-
('ordered', self.ordered)])
272-
if not write_concern.is_server_default:
273-
cmd['writeConcern'] = write_concern.document
274-
if self.bypass_doc_val:
275-
cmd['bypassDocumentValidation'] = True
270+
cmd_name = _COMMANDS[run.op_type]
276271
bwc = self.bulk_ctx_class(
277-
db_name, cmd, sock_info, op_id, listeners, session,
272+
db_name, cmd_name, sock_info, op_id, listeners, session,
278273
run.op_type, self.collection.codec_options)
279274

280275
while run.idx_offset < len(run.ops):
276+
cmd = SON([(cmd_name, self.collection.name),
277+
('ordered', self.ordered)])
278+
if not write_concern.is_server_default:
279+
cmd['writeConcern'] = write_concern.document
280+
if self.bypass_doc_val:
281+
cmd['bypassDocumentValidation'] = True
281282
if session:
282283
# Start a new retryable write unless one was already
283284
# started for this command.
@@ -287,9 +288,10 @@ def _execute_command(self, generator, write_concern, session,
287288
session._apply_to(cmd, retryable, ReadPreference.PRIMARY,
288289
sock_info)
289290
sock_info.send_cluster_time(cmd, session, client)
291+
sock_info.add_server_api(cmd)
290292
ops = islice(run.ops, run.idx_offset, None)
291293
# Run as many ops as possible in one command.
292-
result, to_send = bwc.execute(ops, client)
294+
result, to_send = bwc.execute(cmd, ops, client)
293295

294296
# Retryable writeConcernErrors halt the execution of this run.
295297
wce = result.get('writeConcernError', {})
@@ -359,17 +361,19 @@ def execute_op_msg_no_results(self, sock_info, generator):
359361
run = self.current_run
360362

361363
while run:
362-
cmd = SON([(_COMMANDS[run.op_type], self.collection.name),
363-
('ordered', False),
364-
('writeConcern', {'w': 0})])
364+
cmd_name = _COMMANDS[run.op_type]
365365
bwc = self.bulk_ctx_class(
366-
db_name, cmd, sock_info, op_id, listeners, None,
366+
db_name, cmd_name, sock_info, op_id, listeners, None,
367367
run.op_type, self.collection.codec_options)
368368

369369
while run.idx_offset < len(run.ops):
370+
cmd = SON([(cmd_name, self.collection.name),
371+
('ordered', False),
372+
('writeConcern', {'w': 0})])
373+
sock_info.add_server_api(cmd)
370374
ops = islice(run.ops, run.idx_offset, None)
371375
# Run as many ops as possible.
372-
to_send = bwc.execute_unack(ops, client)
376+
to_send = bwc.execute_unack(cmd, ops, client)
373377
run.idx_offset += len(to_send)
374378
self.current_run = run = next(generator, None)
375379

pymongo/message.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -704,50 +704,48 @@ def _get_more(collection_name, num_to_return, cursor_id, ctx=None):
704704
class _BulkWriteContext(object):
705705
"""A wrapper around SocketInfo for use with write splitting functions."""
706706

707-
__slots__ = ('db_name', 'command', 'sock_info', 'op_id',
707+
__slots__ = ('db_name', 'sock_info', 'op_id',
708708
'name', 'field', 'publish', 'start_time', 'listeners',
709709
'session', 'compress', 'op_type', 'codec')
710710

711-
def __init__(self, database_name, command, sock_info, operation_id,
711+
def __init__(self, database_name, cmd_name, sock_info, operation_id,
712712
listeners, session, op_type, codec):
713713
self.db_name = database_name
714-
self.command = command
715714
self.sock_info = sock_info
716715
self.op_id = operation_id
717716
self.listeners = listeners
718717
self.publish = listeners.enabled_for_commands
719-
self.name = next(iter(command))
718+
self.name = cmd_name
720719
self.field = _FIELD_MAP[self.name]
721720
self.start_time = datetime.datetime.now() if self.publish else None
722721
self.session = session
723722
self.compress = True if sock_info.compression_context else False
724723
self.op_type = op_type
725724
self.codec = codec
726-
sock_info.add_server_api(command)
727725

728-
def _batch_command(self, docs):
726+
def _batch_command(self, cmd, docs):
729727
namespace = self.db_name + '.$cmd'
730728
request_id, msg, to_send = _do_batched_op_msg(
731-
namespace, self.op_type, self.command, docs, self.check_keys,
729+
namespace, self.op_type, cmd, docs, self.check_keys,
732730
self.codec, self)
733731
if not to_send:
734732
raise InvalidOperation("cannot do an empty bulk write")
735733
return request_id, msg, to_send
736734

737-
def execute(self, docs, client):
738-
request_id, msg, to_send = self._batch_command(docs)
739-
result = self.write_command(request_id, msg, to_send)
735+
def execute(self, cmd, docs, client):
736+
request_id, msg, to_send = self._batch_command(cmd, docs)
737+
result = self.write_command(cmd, request_id, msg, to_send)
740738
client._process_response(result, self.session)
741739
return result, to_send
742740

743-
def execute_unack(self, docs, client):
744-
request_id, msg, to_send = self._batch_command(docs)
741+
def execute_unack(self, cmd, docs, client):
742+
request_id, msg, to_send = self._batch_command(cmd, docs)
745743
# Though this isn't strictly a "legacy" write, the helper
746744
# handles publishing commands and sending our message
747745
# without receiving a result. Send 0 for max_doc_size
748746
# to disable size checking. Size checking is handled while
749747
# the documents are encoded to BSON.
750-
self.unack_write(request_id, msg, 0, to_send)
748+
self.unack_write(cmd, request_id, msg, 0, to_send)
751749
return to_send
752750

753751
@property
@@ -778,12 +776,12 @@ def max_split_size(self):
778776
"""The maximum size of a BSON command before batch splitting."""
779777
return self.max_bson_size
780778

781-
def unack_write(self, request_id, msg, max_doc_size, docs):
779+
def unack_write(self, cmd, request_id, msg, max_doc_size, docs):
782780
"""A proxy for SocketInfo.unack_write that handles event publishing.
783781
"""
784782
if self.publish:
785783
duration = datetime.datetime.now() - self.start_time
786-
cmd = self._start(request_id, docs)
784+
cmd = self._start(cmd, request_id, docs)
787785
start = datetime.datetime.now()
788786
try:
789787
result = self.sock_info.unack_write(msg, max_doc_size)
@@ -811,12 +809,12 @@ def unack_write(self, request_id, msg, max_doc_size, docs):
811809
self.start_time = datetime.datetime.now()
812810
return result
813811

814-
def write_command(self, request_id, msg, docs):
812+
def write_command(self, cmd, request_id, msg, docs):
815813
"""A proxy for SocketInfo.write_command that handles event publishing.
816814
"""
817815
if self.publish:
818816
duration = datetime.datetime.now() - self.start_time
819-
self._start(request_id, docs)
817+
self._start(cmd, request_id, docs)
820818
start = datetime.datetime.now()
821819
try:
822820
reply = self.sock_info.write_command(request_id, msg)
@@ -836,9 +834,8 @@ def write_command(self, request_id, msg, docs):
836834
self.start_time = datetime.datetime.now()
837835
return reply
838836

839-
def _start(self, request_id, docs):
837+
def _start(self, cmd, request_id, docs):
840838
"""Publish a CommandStartedEvent."""
841-
cmd = self.command.copy()
842839
cmd[self.field] = docs
843840
self.listeners.publish_command_start(
844841
cmd, self.db_name,
@@ -871,10 +868,10 @@ def _fail(self, request_id, failure, duration):
871868
class _EncryptedBulkWriteContext(_BulkWriteContext):
872869
__slots__ = ()
873870

874-
def _batch_command(self, docs):
871+
def _batch_command(self, cmd, docs):
875872
namespace = self.db_name + '.$cmd'
876873
msg, to_send = _encode_batched_write_command(
877-
namespace, self.op_type, self.command, docs, self.check_keys,
874+
namespace, self.op_type, cmd, docs, self.check_keys,
878875
self.codec, self)
879876
if not to_send:
880877
raise InvalidOperation("cannot do an empty bulk write")
@@ -885,17 +882,18 @@ def _batch_command(self, docs):
885882
DEFAULT_RAW_BSON_OPTIONS)
886883
return cmd, to_send
887884

888-
def execute(self, docs, client):
889-
cmd, to_send = self._batch_command(docs)
885+
def execute(self, cmd, docs, client):
886+
batched_cmd, to_send = self._batch_command(cmd, docs)
890887
result = self.sock_info.command(
891-
self.db_name, cmd, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
888+
self.db_name, batched_cmd,
889+
codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
892890
session=self.session, client=client)
893891
return result, to_send
894892

895-
def execute_unack(self, docs, client):
896-
cmd, to_send = self._batch_command(docs)
893+
def execute_unack(self, cmd, docs, client):
894+
batched_cmd, to_send = self._batch_command(cmd, docs)
897895
self.sock_info.command(
898-
self.db_name, cmd, write_concern=WriteConcern(w=0),
896+
self.db_name, batched_cmd, write_concern=WriteConcern(w=0),
899897
session=self.session, client=client)
900898
return to_send
901899

test/test_transactions.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,36 @@ def gridfs_open_upload_stream(*args, **kwargs):
286286
):
287287
op(*args, session=s)
288288

289+
# Require 4.2+ for large (16MB+) transactions.
290+
@client_context.require_version_min(4, 2)
291+
@client_context.require_transactions
292+
def test_transaction_starts_with_batched_write(self):
293+
# Start a transaction with a batch of operations that needs to be
294+
# split.
295+
listener = OvertCommandListener()
296+
client = rs_client(event_listeners=[listener])
297+
coll = client[self.db.name].test
298+
coll.delete_many({})
299+
listener.reset()
300+
self.addCleanup(client.close)
301+
self.addCleanup(coll.drop)
302+
ops = [InsertOne({'a': '1'*(10*1024*1024)}) for _ in range(10)]
303+
with client.start_session() as session:
304+
with session.start_transaction():
305+
coll.bulk_write(ops, session=session)
306+
# Assert commands were constructed properly.
307+
self.assertEqual(['insert', 'insert', 'insert', 'commitTransaction'],
308+
listener.started_command_names())
309+
first_cmd = listener.results['started'][0].command
310+
self.assertTrue(first_cmd['startTransaction'])
311+
lsid = first_cmd['lsid']
312+
txn_number = first_cmd['txnNumber']
313+
for event in listener.results['started'][1:]:
314+
self.assertNotIn('startTransaction', event.command)
315+
self.assertEqual(lsid, event.command['lsid'])
316+
self.assertEqual(txn_number, event.command['txnNumber'])
317+
self.assertEqual(10, coll.count_documents({}))
318+
289319

290320
class PatchSessionTimeout(object):
291321
"""Patches the client_session's with_transaction timeout for testing."""

0 commit comments

Comments (0)