21
21
"""
22
22
from __future__ import annotations
23
23
24
- import copy
25
24
import datetime
26
25
import random
27
26
import struct
@@ -950,10 +949,13 @@ def __init__(
950
949
)
951
950
952
951
def batch_command (
953
- self , cmd : MutableMapping [str , Any ], operations : list [tuple [str , Mapping [str , Any ]]]
952
+ self ,
953
+ cmd : MutableMapping [str , Any ],
954
+ operations : list [tuple [str , Mapping [str , Any ]]],
955
+ namespaces : list [str ],
954
956
) -> tuple [int , Union [bytes , dict [str , Any ]], list [Mapping [str , Any ]], list [Mapping [str , Any ]]]:
955
957
request_id , msg , to_send_ops , to_send_ns = _client_do_batched_op_msg (
956
- cmd , operations , self .codec , self
958
+ cmd , operations , namespaces , self .codec , self
957
959
)
958
960
if not to_send_ops :
959
961
raise InvalidOperation ("cannot do an empty bulk write" )
@@ -1035,6 +1037,7 @@ def _client_construct_op_msg(
1035
1037
def _client_batched_op_msg_impl (
1036
1038
command : Mapping [str , Any ],
1037
1039
operations : list [tuple [str , Mapping [str , Any ]]],
1040
+ namespaces : list [str ],
1038
1041
ack : bool ,
1039
1042
opts : CodecOptions ,
1040
1043
ctx : _ClientBulkWriteContext ,
@@ -1076,14 +1079,14 @@ def _check_doc_size_limits(
1076
1079
1077
1080
ns_info = {}
1078
1081
to_send_ops : list [Mapping [str , Any ]] = []
1079
- to_send_ns : list [Mapping [str , int ]] = []
1082
+ to_send_ns : list [Mapping [str , str ]] = []
1080
1083
to_send_ops_encoded : list [bytes ] = []
1081
1084
to_send_ns_encoded : list [bytes ] = []
1082
1085
total_ops_length = 0
1083
1086
total_ns_length = 0
1084
1087
idx = 0
1085
1088
1086
- for real_op_type , op_doc in operations :
1089
+ for ( real_op_type , op_doc ), namespace in zip ( operations , namespaces ) :
1087
1090
op_type = real_op_type
1088
1091
# Check insert/replace document size if unacknowledged.
1089
1092
if real_op_type == "insert" :
@@ -1096,24 +1099,23 @@ def _check_doc_size_limits(
1096
1099
doc_size = len (_dict_to_bson (op_doc ["updateMods" ], False , opts ))
1097
1100
_check_doc_size_limits (real_op_type , doc_size , max_bson_size )
1098
1101
1099
- ns_doc_to_send = None
1102
+ ns_doc = None
1100
1103
ns_length = 0
1101
- namespace = op_doc [ op_type ]
1104
+
1102
1105
if namespace not in ns_info :
1103
- ns_doc_to_send = {"ns" : namespace }
1106
+ ns_doc = {"ns" : namespace }
1104
1107
new_ns_index = len (to_send_ns )
1105
1108
ns_info [namespace ] = new_ns_index
1106
1109
1107
1110
# First entry in the operation doc has the operation type as its
1108
1111
# key and the index of its namespace within ns_info as its value.
1109
- op_doc_to_send = copy .deepcopy (op_doc )
1110
- op_doc_to_send [op_type ] = ns_info [namespace ] # type: ignore[index]
1112
+ op_doc [op_type ] = ns_info [namespace ] # type: ignore[index]
1111
1113
1112
1114
# Encode current operation doc and, if newly added, namespace doc.
1113
- op_doc_encoded = _dict_to_bson (op_doc_to_send , False , opts )
1115
+ op_doc_encoded = _dict_to_bson (op_doc , False , opts )
1114
1116
op_length = len (op_doc_encoded )
1115
- if ns_doc_to_send :
1116
- ns_doc_encoded = _dict_to_bson (ns_doc_to_send , False , opts )
1117
+ if ns_doc :
1118
+ ns_doc_encoded = _dict_to_bson (ns_doc , False , opts )
1117
1119
ns_length = len (ns_doc_encoded )
1118
1120
1119
1121
# Check operation document size if unacknowledged.
@@ -1128,11 +1130,11 @@ def _check_doc_size_limits(
1128
1130
break
1129
1131
1130
1132
# Add op and ns documents to this batch.
1131
- to_send_ops .append (op_doc_to_send )
1133
+ to_send_ops .append (op_doc )
1132
1134
to_send_ops_encoded .append (op_doc_encoded )
1133
1135
total_ops_length += op_length
1134
- if ns_doc_to_send :
1135
- to_send_ns .append (ns_doc_to_send )
1136
+ if ns_doc :
1137
+ to_send_ns .append (ns_doc )
1136
1138
to_send_ns_encoded .append (ns_doc_encoded )
1137
1139
total_ns_length += ns_length
1138
1140
@@ -1153,6 +1155,7 @@ def _check_doc_size_limits(
1153
1155
def _client_encode_batched_op_msg (
1154
1156
command : Mapping [str , Any ],
1155
1157
operations : list [tuple [str , Mapping [str , Any ]]],
1158
+ namespaces : list [str ],
1156
1159
ack : bool ,
1157
1160
opts : CodecOptions ,
1158
1161
ctx : _ClientBulkWriteContext ,
@@ -1163,14 +1166,15 @@ def _client_encode_batched_op_msg(
1163
1166
buf = _BytesIO ()
1164
1167
1165
1168
to_send_ops , to_send_ns , _ = _client_batched_op_msg_impl (
1166
- command , operations , ack , opts , ctx , buf
1169
+ command , operations , namespaces , ack , opts , ctx , buf
1167
1170
)
1168
1171
return buf .getvalue (), to_send_ops , to_send_ns
1169
1172
1170
1173
1171
1174
def _client_batched_op_msg_compressed (
1172
1175
command : Mapping [str , Any ],
1173
1176
operations : list [tuple [str , Mapping [str , Any ]]],
1177
+ namespaces : list [str ],
1174
1178
ack : bool ,
1175
1179
opts : CodecOptions ,
1176
1180
ctx : _ClientBulkWriteContext ,
@@ -1179,7 +1183,7 @@ def _client_batched_op_msg_compressed(
1179
1183
with OP_MSG, compressed.
1180
1184
"""
1181
1185
data , to_send_ops , to_send_ns = _client_encode_batched_op_msg (
1182
- command , operations , ack , opts , ctx
1186
+ command , operations , namespaces , ack , opts , ctx
1183
1187
)
1184
1188
1185
1189
assert ctx .conn .compression_context is not None
@@ -1190,6 +1194,7 @@ def _client_batched_op_msg_compressed(
1190
1194
def _client_batched_op_msg (
1191
1195
command : Mapping [str , Any ],
1192
1196
operations : list [tuple [str , Mapping [str , Any ]]],
1197
+ namespaces : list [str ],
1193
1198
ack : bool ,
1194
1199
opts : CodecOptions ,
1195
1200
ctx : _ClientBulkWriteContext ,
@@ -1203,7 +1208,7 @@ def _client_batched_op_msg(
1203
1208
buf .write (b"\x00 \x00 \x00 \x00 \xdd \x07 \x00 \x00 " )
1204
1209
1205
1210
to_send_ops , to_send_ns , length = _client_batched_op_msg_impl (
1206
- command , operations , ack , opts , ctx , buf
1211
+ command , operations , namespaces , ack , opts , ctx , buf
1207
1212
)
1208
1213
1209
1214
# Header - request id and message length
@@ -1219,6 +1224,7 @@ def _client_batched_op_msg(
1219
1224
def _client_do_batched_op_msg (
1220
1225
command : MutableMapping [str , Any ],
1221
1226
operations : list [tuple [str , Mapping [str , Any ]]],
1227
+ namespaces : list [str ],
1222
1228
opts : CodecOptions ,
1223
1229
ctx : _ClientBulkWriteContext ,
1224
1230
) -> tuple [int , bytes , list [Mapping [str , Any ]], list [Mapping [str , Any ]]]:
@@ -1231,8 +1237,8 @@ def _client_do_batched_op_msg(
1231
1237
else :
1232
1238
ack = True
1233
1239
if ctx .conn .compression_context :
1234
- return _client_batched_op_msg_compressed (command , operations , ack , opts , ctx )
1235
- return _client_batched_op_msg (command , operations , ack , opts , ctx )
1240
+ return _client_batched_op_msg_compressed (command , operations , namespaces , ack , opts , ctx )
1241
+ return _client_batched_op_msg (command , operations , namespaces , ack , opts , ctx )
1236
1242
1237
1243
1238
1244
# End OP_MSG -----------------------------------------------------
0 commit comments