Skip to content

Commit 4024a1b

Browse files
PYTHON-4668 Improve performance of client.bulk_write (mongodb#1800)
1 parent c03721c commit 4024a1b

File tree

7 files changed

+103
-63
lines changed

7 files changed

+103
-63
lines changed

doc/examples/client_bulk.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ For example, a duplicate key error on the third operation below aborts the remai
145145
'idx': 2,
146146
'code': 11000,
147147
'errmsg': 'E11000 duplicate key error ... dup key: { _id: 3 }', ...
148-
'op': {'insert': 'db.test_three', 'document': {'_id': 3}}}]
148+
'op': {'insert': 0, 'document': {'_id': 3}}}]
149149
>>> exception.partial_result.inserted_count
150150
2
151151
>>> exception.partial_result.deleted_count
@@ -181,7 +181,7 @@ For example, the fourth and fifth write operations below get executed successful
181181
'idx': 2,
182182
'code': 11000,
183183
'errmsg': 'E11000 duplicate key error ... dup key: { _id: 5 }', ...
184-
'op': {'insert': 'db.test_five', 'document': {'_id': 5}}}]
184+
'op': {'insert': 0, 'document': {'_id': 5}}}]
185185
>>> exception.partial_result.inserted_count
186186
3
187187
>>> exception.partial_result.deleted_count

pymongo/asynchronous/client_bulk.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def __init__(
108108
self.verbose_results = verbose_results
109109

110110
self.ops: list[tuple[str, Mapping[str, Any]]] = []
111+
self.namespaces: list[str] = []
111112
self.idx_offset: int = 0
112113
self.total_ops: int = 0
113114

@@ -132,8 +133,9 @@ def add_insert(self, namespace: str, document: _DocumentOut) -> None:
132133
# Generate ObjectId client side.
133134
if not (isinstance(document, RawBSONDocument) or "_id" in document):
134135
document["_id"] = ObjectId()
135-
cmd = {"insert": namespace, "document": document}
136+
cmd = {"insert": -1, "document": document}
136137
self.ops.append(("insert", cmd))
138+
self.namespaces.append(namespace)
137139
self.total_ops += 1
138140

139141
def add_update(
@@ -150,7 +152,7 @@ def add_update(
150152
"""Create an update document and add it to the list of ops."""
151153
validate_ok_for_update(update)
152154
cmd = {
153-
"update": namespace,
155+
"update": -1,
154156
"filter": selector,
155157
"updateMods": update,
156158
"multi": multi,
@@ -171,6 +173,7 @@ def add_update(
171173
# A bulk_write containing an update_many is not retryable.
172174
self.is_retryable = False
173175
self.ops.append(("update", cmd))
176+
self.namespaces.append(namespace)
174177
self.total_ops += 1
175178

176179
def add_replace(
@@ -185,7 +188,7 @@ def add_replace(
185188
"""Create a replace document and add it to the list of ops."""
186189
validate_ok_for_replace(replacement)
187190
cmd = {
188-
"update": namespace,
191+
"update": -1,
189192
"filter": selector,
190193
"updateMods": replacement,
191194
"multi": False,
@@ -200,6 +203,7 @@ def add_replace(
200203
self.uses_collation = True
201204
cmd["collation"] = collation
202205
self.ops.append(("replace", cmd))
206+
self.namespaces.append(namespace)
203207
self.total_ops += 1
204208

205209
def add_delete(
@@ -211,7 +215,7 @@ def add_delete(
211215
hint: Union[str, dict[str, Any], None] = None,
212216
) -> None:
213217
"""Create a delete document and add it to the list of ops."""
214-
cmd = {"delete": namespace, "filter": selector, "multi": multi}
218+
cmd = {"delete": -1, "filter": selector, "multi": multi}
215219
if hint is not None:
216220
self.uses_hint_delete = True
217221
cmd["hint"] = hint
@@ -222,6 +226,7 @@ def add_delete(
222226
# A bulk_write containing a delete_many is not retryable.
223227
self.is_retryable = False
224228
self.ops.append(("delete", cmd))
229+
self.namespaces.append(namespace)
225230
self.total_ops += 1
226231

227232
@_handle_reauth
@@ -407,9 +412,10 @@ async def _execute_batch_unack(
407412
bwc: _ClientBulkWriteContext,
408413
cmd: dict[str, Any],
409414
ops: list[tuple[str, Mapping[str, Any]]],
415+
namespaces: list[str],
410416
) -> tuple[list[Mapping[str, Any]], list[Mapping[str, Any]]]:
411417
"""Executes a batch of bulkWrite server commands (unack)."""
412-
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops)
418+
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
413419
await self.unack_write(bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client) # type: ignore[arg-type]
414420
return to_send_ops, to_send_ns
415421

@@ -418,9 +424,10 @@ async def _execute_batch(
418424
bwc: _ClientBulkWriteContext,
419425
cmd: dict[str, Any],
420426
ops: list[tuple[str, Mapping[str, Any]]],
427+
namespaces: list[str],
421428
) -> tuple[dict[str, Any], list[Mapping[str, Any]], list[Mapping[str, Any]]]:
422429
"""Executes a batch of bulkWrite server commands (ack)."""
423-
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops)
430+
request_id, msg, to_send_ops, to_send_ns = bwc.batch_command(cmd, ops, namespaces)
424431
result = await self.write_command(
425432
bwc, cmd, request_id, msg, to_send_ops, to_send_ns, self.client
426433
) # type: ignore[arg-type]
@@ -540,11 +547,12 @@ async def _execute_command(
540547
# CSOT: apply timeout before encoding the command.
541548
conn.apply_timeout(self.client, cmd)
542549
ops = islice(self.ops, self.idx_offset, None)
550+
namespaces = islice(self.namespaces, self.idx_offset, None)
543551

544552
# Run as many ops as possible in one server command.
545553
if write_concern.acknowledged:
546-
raw_result, to_send_ops, _ = await self._execute_batch(bwc, cmd, ops) # type: ignore[arg-type]
547-
result = copy.deepcopy(raw_result)
554+
raw_result, to_send_ops, _ = await self._execute_batch(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
555+
result = raw_result
548556

549557
# Top-level server/network error.
550558
if result.get("error"):
@@ -600,7 +608,7 @@ async def _execute_command(
600608
self.started_retryable_write = False
601609

602610
else:
603-
to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops) # type: ignore[arg-type]
611+
to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
604612

605613
self.idx_offset += len(to_send_ops)
606614

@@ -697,9 +705,10 @@ async def execute_command_unack_unordered(
697705

698706
conn.add_server_api(cmd)
699707
ops = islice(self.ops, self.idx_offset, None)
708+
namespaces = islice(self.namespaces, self.idx_offset, None)
700709

701710
# Run as many ops as possible in one server command.
702-
to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops) # type: ignore[arg-type]
711+
to_send_ops, _ = await self._execute_batch_unack(bwc, cmd, ops, namespaces) # type: ignore[arg-type]
703712

704713
self.idx_offset += len(to_send_ops)
705714

pymongo/message.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
"""
2222
from __future__ import annotations
2323

24-
import copy
2524
import datetime
2625
import random
2726
import struct
@@ -950,10 +949,13 @@ def __init__(
950949
)
951950

952951
def batch_command(
953-
self, cmd: MutableMapping[str, Any], operations: list[tuple[str, Mapping[str, Any]]]
952+
self,
953+
cmd: MutableMapping[str, Any],
954+
operations: list[tuple[str, Mapping[str, Any]]],
955+
namespaces: list[str],
954956
) -> tuple[int, Union[bytes, dict[str, Any]], list[Mapping[str, Any]], list[Mapping[str, Any]]]:
955957
request_id, msg, to_send_ops, to_send_ns = _client_do_batched_op_msg(
956-
cmd, operations, self.codec, self
958+
cmd, operations, namespaces, self.codec, self
957959
)
958960
if not to_send_ops:
959961
raise InvalidOperation("cannot do an empty bulk write")
@@ -1035,6 +1037,7 @@ def _client_construct_op_msg(
10351037
def _client_batched_op_msg_impl(
10361038
command: Mapping[str, Any],
10371039
operations: list[tuple[str, Mapping[str, Any]]],
1040+
namespaces: list[str],
10381041
ack: bool,
10391042
opts: CodecOptions,
10401043
ctx: _ClientBulkWriteContext,
@@ -1076,14 +1079,14 @@ def _check_doc_size_limits(
10761079

10771080
ns_info = {}
10781081
to_send_ops: list[Mapping[str, Any]] = []
1079-
to_send_ns: list[Mapping[str, int]] = []
1082+
to_send_ns: list[Mapping[str, str]] = []
10801083
to_send_ops_encoded: list[bytes] = []
10811084
to_send_ns_encoded: list[bytes] = []
10821085
total_ops_length = 0
10831086
total_ns_length = 0
10841087
idx = 0
10851088

1086-
for real_op_type, op_doc in operations:
1089+
for (real_op_type, op_doc), namespace in zip(operations, namespaces):
10871090
op_type = real_op_type
10881091
# Check insert/replace document size if unacknowledged.
10891092
if real_op_type == "insert":
@@ -1096,24 +1099,23 @@ def _check_doc_size_limits(
10961099
doc_size = len(_dict_to_bson(op_doc["updateMods"], False, opts))
10971100
_check_doc_size_limits(real_op_type, doc_size, max_bson_size)
10981101

1099-
ns_doc_to_send = None
1102+
ns_doc = None
11001103
ns_length = 0
1101-
namespace = op_doc[op_type]
1104+
11021105
if namespace not in ns_info:
1103-
ns_doc_to_send = {"ns": namespace}
1106+
ns_doc = {"ns": namespace}
11041107
new_ns_index = len(to_send_ns)
11051108
ns_info[namespace] = new_ns_index
11061109

11071110
# First entry in the operation doc has the operation type as its
11081111
# key and the index of its namespace within ns_info as its value.
1109-
op_doc_to_send = copy.deepcopy(op_doc)
1110-
op_doc_to_send[op_type] = ns_info[namespace] # type: ignore[index]
1112+
op_doc[op_type] = ns_info[namespace] # type: ignore[index]
11111113

11121114
# Encode current operation doc and, if newly added, namespace doc.
1113-
op_doc_encoded = _dict_to_bson(op_doc_to_send, False, opts)
1115+
op_doc_encoded = _dict_to_bson(op_doc, False, opts)
11141116
op_length = len(op_doc_encoded)
1115-
if ns_doc_to_send:
1116-
ns_doc_encoded = _dict_to_bson(ns_doc_to_send, False, opts)
1117+
if ns_doc:
1118+
ns_doc_encoded = _dict_to_bson(ns_doc, False, opts)
11171119
ns_length = len(ns_doc_encoded)
11181120

11191121
# Check operation document size if unacknowledged.
@@ -1128,11 +1130,11 @@ def _check_doc_size_limits(
11281130
break
11291131

11301132
# Add op and ns documents to this batch.
1131-
to_send_ops.append(op_doc_to_send)
1133+
to_send_ops.append(op_doc)
11321134
to_send_ops_encoded.append(op_doc_encoded)
11331135
total_ops_length += op_length
1134-
if ns_doc_to_send:
1135-
to_send_ns.append(ns_doc_to_send)
1136+
if ns_doc:
1137+
to_send_ns.append(ns_doc)
11361138
to_send_ns_encoded.append(ns_doc_encoded)
11371139
total_ns_length += ns_length
11381140

@@ -1153,6 +1155,7 @@ def _check_doc_size_limits(
11531155
def _client_encode_batched_op_msg(
11541156
command: Mapping[str, Any],
11551157
operations: list[tuple[str, Mapping[str, Any]]],
1158+
namespaces: list[str],
11561159
ack: bool,
11571160
opts: CodecOptions,
11581161
ctx: _ClientBulkWriteContext,
@@ -1163,14 +1166,15 @@ def _client_encode_batched_op_msg(
11631166
buf = _BytesIO()
11641167

11651168
to_send_ops, to_send_ns, _ = _client_batched_op_msg_impl(
1166-
command, operations, ack, opts, ctx, buf
1169+
command, operations, namespaces, ack, opts, ctx, buf
11671170
)
11681171
return buf.getvalue(), to_send_ops, to_send_ns
11691172

11701173

11711174
def _client_batched_op_msg_compressed(
11721175
command: Mapping[str, Any],
11731176
operations: list[tuple[str, Mapping[str, Any]]],
1177+
namespaces: list[str],
11741178
ack: bool,
11751179
opts: CodecOptions,
11761180
ctx: _ClientBulkWriteContext,
@@ -1179,7 +1183,7 @@ def _client_batched_op_msg_compressed(
11791183
with OP_MSG, compressed.
11801184
"""
11811185
data, to_send_ops, to_send_ns = _client_encode_batched_op_msg(
1182-
command, operations, ack, opts, ctx
1186+
command, operations, namespaces, ack, opts, ctx
11831187
)
11841188

11851189
assert ctx.conn.compression_context is not None
@@ -1190,6 +1194,7 @@ def _client_batched_op_msg_compressed(
11901194
def _client_batched_op_msg(
11911195
command: Mapping[str, Any],
11921196
operations: list[tuple[str, Mapping[str, Any]]],
1197+
namespaces: list[str],
11931198
ack: bool,
11941199
opts: CodecOptions,
11951200
ctx: _ClientBulkWriteContext,
@@ -1203,7 +1208,7 @@ def _client_batched_op_msg(
12031208
buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00")
12041209

12051210
to_send_ops, to_send_ns, length = _client_batched_op_msg_impl(
1206-
command, operations, ack, opts, ctx, buf
1211+
command, operations, namespaces, ack, opts, ctx, buf
12071212
)
12081213

12091214
# Header - request id and message length
@@ -1219,6 +1224,7 @@ def _client_batched_op_msg(
12191224
def _client_do_batched_op_msg(
12201225
command: MutableMapping[str, Any],
12211226
operations: list[tuple[str, Mapping[str, Any]]],
1227+
namespaces: list[str],
12221228
opts: CodecOptions,
12231229
ctx: _ClientBulkWriteContext,
12241230
) -> tuple[int, bytes, list[Mapping[str, Any]], list[Mapping[str, Any]]]:
@@ -1231,8 +1237,8 @@ def _client_do_batched_op_msg(
12311237
else:
12321238
ack = True
12331239
if ctx.conn.compression_context:
1234-
return _client_batched_op_msg_compressed(command, operations, ack, opts, ctx)
1235-
return _client_batched_op_msg(command, operations, ack, opts, ctx)
1240+
return _client_batched_op_msg_compressed(command, operations, namespaces, ack, opts, ctx)
1241+
return _client_batched_op_msg(command, operations, namespaces, ack, opts, ctx)
12361242

12371243

12381244
# End OP_MSG -----------------------------------------------------

0 commit comments

Comments
 (0)