Skip to content

Commit 559d8b1

Browse files
PYTHON-4596 Only encode each operation document once for MongoClient.bulk_write (mongodb#1797)
1 parent 768858e commit 559d8b1

File tree

1 file changed

+37
-25
lines changed

1 file changed

+37
-25
lines changed

pymongo/message.py

Lines changed: 37 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -985,11 +985,10 @@ def _start(
985985

986986

987987
def _client_construct_op_msg(
988-
command: Mapping[str, Any],
989-
to_send_ops: list[Mapping[str, Any]],
990-
to_send_ns: list[Mapping[str, Any]],
988+
command_encoded: bytes,
989+
to_send_ops_encoded: list[bytes],
990+
to_send_ns_encoded: list[bytes],
991991
ack: bool,
992-
opts: CodecOptions,
993992
buf: _BytesIO,
994993
) -> int:
995994
# Write flags
@@ -998,7 +997,7 @@ def _client_construct_op_msg(
998997

999998
# Type 0 Section
1000999
buf.write(b"\x00")
1001-
buf.write(_dict_to_bson(command, False, opts))
1000+
buf.write(command_encoded)
10021001

10031002
# Type 1 Section for ops
10041003
buf.write(b"\x01")
@@ -1007,8 +1006,8 @@ def _client_construct_op_msg(
10071006
buf.write(b"\x00\x00\x00\x00")
10081007
buf.write(b"ops\x00")
10091008
# Write all the ops documents
1010-
for op in to_send_ops:
1011-
buf.write(_dict_to_bson(op, False, opts))
1009+
for op_encoded in to_send_ops_encoded:
1010+
buf.write(op_encoded)
10121011
resume_location = buf.tell()
10131012
# Write type 1 section size
10141013
length = buf.tell()
@@ -1023,8 +1022,8 @@ def _client_construct_op_msg(
10231022
buf.write(b"\x00\x00\x00\x00")
10241023
buf.write(b"nsInfo\x00")
10251024
# Write all the nsInfo documents
1026-
for ns in to_send_ns:
1027-
buf.write(_dict_to_bson(ns, False, opts))
1025+
for ns_encoded in to_send_ns_encoded:
1026+
buf.write(ns_encoded)
10281027
# Write type 1 section size
10291028
length = buf.tell()
10301029
buf.seek(size_location)
@@ -1045,19 +1044,23 @@ def _client_batched_op_msg_impl(
10451044

10461045
def _check_doc_size_limits(
10471046
op_type: str,
1048-
document: Mapping[str, Any],
1047+
doc_size: int,
10491048
limit: int,
1050-
) -> int:
1051-
doc_size = len(_dict_to_bson(document, False, opts))
1049+
) -> None:
10521050
if doc_size > limit:
10531051
_raise_document_too_large(op_type, doc_size, limit)
1054-
return doc_size
10551052

10561053
max_bson_size = ctx.max_bson_size
10571054
max_write_batch_size = ctx.max_write_batch_size
10581055
max_message_size = ctx.max_message_size
10591056

1060-
# Don't include bulkWrite-command-agnostic fields in document size calculations.
1057+
command_encoded = _dict_to_bson(command, False, opts)
1058+
# When OP_MSG is used unacknowledged we have to check command
1059+
# document size client-side or applications won't be notified.
1060+
if not ack:
1061+
_check_doc_size_limits("bulkWrite", len(command_encoded), max_bson_size + _COMMAND_OVERHEAD)
1062+
1063+
# Don't include bulkWrite-command-agnostic fields in batch-splitting calculations.
10611064
abridged_keys = ["bulkWrite", "errorsOnly", "ordered"]
10621065
if command.get("bypassDocumentValidation"):
10631066
abridged_keys.append("bypassDocumentValidation")
@@ -1068,17 +1071,14 @@ def _check_doc_size_limits(
10681071
command_abridged = {key: command[key] for key in abridged_keys}
10691072
command_len_abridged = len(_dict_to_bson(command_abridged, False, opts))
10701073

1071-
# When OP_MSG is used unacknowledged we have to check command
1072-
# document size client-side or applications won't be notified.
1073-
if not ack:
1074-
_check_doc_size_limits("bulkWrite", command_abridged, max_bson_size + _COMMAND_OVERHEAD)
1075-
10761074
# Maximum combined size of the ops and nsInfo document sequences.
10771075
max_doc_sequences_bytes = max_message_size - (_OP_MSG_OVERHEAD + command_len_abridged)
10781076

10791077
ns_info = {}
10801078
to_send_ops: list[Mapping[str, Any]] = []
10811079
to_send_ns: list[Mapping[str, int]] = []
1080+
to_send_ops_encoded: list[bytes] = []
1081+
to_send_ns_encoded: list[bytes] = []
10821082
total_ops_length = 0
10831083
total_ns_length = 0
10841084
idx = 0
@@ -1088,11 +1088,13 @@ def _check_doc_size_limits(
10881088
# Check insert/replace document size if unacknowledged.
10891089
if real_op_type == "insert":
10901090
if not ack:
1091-
_check_doc_size_limits(real_op_type, op_doc["document"], max_bson_size)
1091+
doc_size = len(_dict_to_bson(op_doc["document"], False, opts))
1092+
_check_doc_size_limits(real_op_type, doc_size, max_bson_size)
10921093
if real_op_type == "replace":
10931094
op_type = "update"
10941095
if not ack:
1095-
_check_doc_size_limits(real_op_type, op_doc["updateMods"], max_bson_size)
1096+
doc_size = len(_dict_to_bson(op_doc["updateMods"], False, opts))
1097+
_check_doc_size_limits(real_op_type, doc_size, max_bson_size)
10961098

10971099
ns_doc_to_send = None
10981100
ns_length = 0
@@ -1108,30 +1110,40 @@ def _check_doc_size_limits(
11081110
op_doc_to_send[op_type] = ns_info[namespace] # type: ignore[index]
11091111

11101112
# Encode current operation doc and, if newly added, namespace doc.
1111-
op_length = len(_dict_to_bson(op_doc_to_send, False, opts))
1113+
op_doc_encoded = _dict_to_bson(op_doc_to_send, False, opts)
1114+
op_length = len(op_doc_encoded)
11121115
if ns_doc_to_send:
1113-
ns_length = len(_dict_to_bson(ns_doc_to_send, False, opts))
1116+
ns_doc_encoded = _dict_to_bson(ns_doc_to_send, False, opts)
1117+
ns_length = len(ns_doc_encoded)
11141118

11151119
# Check operation document size if unacknowledged.
11161120
if not ack:
1117-
_check_doc_size_limits(op_type, op_doc_to_send, max_bson_size + _COMMAND_OVERHEAD)
1121+
_check_doc_size_limits(op_type, op_length, max_bson_size + _COMMAND_OVERHEAD)
11181122

11191123
new_message_size = total_ops_length + total_ns_length + op_length + ns_length
11201124
# We have enough data, return this batch.
11211125
if new_message_size > max_doc_sequences_bytes:
11221126
break
1127+
1128+
# Add op and ns documents to this batch.
11231129
to_send_ops.append(op_doc_to_send)
1130+
to_send_ops_encoded.append(op_doc_encoded)
11241131
total_ops_length += op_length
11251132
if ns_doc_to_send:
11261133
to_send_ns.append(ns_doc_to_send)
1134+
to_send_ns_encoded.append(ns_doc_encoded)
11271135
total_ns_length += ns_length
1136+
11281137
idx += 1
1138+
11291139
# We have enough documents, return this batch.
11301140
if idx == max_write_batch_size:
11311141
break
11321142

11331143
# Construct the entire OP_MSG.
1134-
length = _client_construct_op_msg(command, to_send_ops, to_send_ns, ack, opts, buf)
1144+
length = _client_construct_op_msg(
1145+
command_encoded, to_send_ops_encoded, to_send_ns_encoded, ack, buf
1146+
)
11351147

11361148
return to_send_ops, to_send_ns, length
11371149

0 commit comments

Comments
 (0)