26
26
from typing import (
27
27
TYPE_CHECKING ,
28
28
Any ,
29
- Generator ,
29
+ Callable ,
30
+ Iterable ,
30
31
Iterator ,
31
32
Mapping ,
32
33
Optional ,
@@ -111,9 +112,6 @@ def __init__(
111
112
self .uses_hint_update = False
112
113
self .uses_hint_delete = False
113
114
self .uses_sort = False
114
- self .is_retryable = True
115
- self .retrying = False
116
- self .started_retryable_write = False
117
115
# Extra state so that we know where to pick up on a retry attempt.
118
116
self .current_run = None
119
117
self .next_run = None
@@ -129,13 +127,32 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
129
127
self .is_encrypted = False
130
128
return _BulkWriteContext
131
129
132
- def add_insert (self , document : _DocumentOut ) -> None :
130
+ @property
131
+ def is_retryable (self ) -> bool :
132
+ if self .current_run :
133
+ return self .current_run .is_retryable
134
+ return True
135
+
136
+ @property
137
+ def retrying (self ) -> bool :
138
+ if self .current_run :
139
+ return self .current_run .retrying
140
+ return False
141
+
142
+ @property
143
+ def started_retryable_write (self ) -> bool :
144
+ if self .current_run :
145
+ return self .current_run .started_retryable_write
146
+ return False
147
+
148
+ def add_insert (self , document : _DocumentOut ) -> bool :
133
149
"""Add an insert document to the list of ops."""
134
150
validate_is_document_type ("document" , document )
135
151
# Generate ObjectId client side.
136
152
if not (isinstance (document , RawBSONDocument ) or "_id" in document ):
137
153
document ["_id" ] = ObjectId ()
138
154
self .ops .append ((_INSERT , document ))
155
+ return True
139
156
140
157
def add_update (
141
158
self ,
@@ -147,7 +164,7 @@ def add_update(
147
164
array_filters : Optional [list [Mapping [str , Any ]]] = None ,
148
165
hint : Union [str , dict [str , Any ], None ] = None ,
149
166
sort : Optional [Mapping [str , Any ]] = None ,
150
- ) -> None :
167
+ ) -> bool :
151
168
"""Create an update document and add it to the list of ops."""
152
169
validate_ok_for_update (update )
153
170
cmd : dict [str , Any ] = {"q" : selector , "u" : update , "multi" : multi }
@@ -165,10 +182,12 @@ def add_update(
165
182
if sort is not None :
166
183
self .uses_sort = True
167
184
cmd ["sort" ] = sort
185
+
186
+ self .ops .append ((_UPDATE , cmd ))
168
187
if multi :
169
188
# A bulk_write containing an update_many is not retryable.
170
- self . is_retryable = False
171
- self . ops . append (( _UPDATE , cmd ))
189
+ return False
190
+ return True
172
191
173
192
def add_replace (
174
193
self ,
@@ -178,7 +197,7 @@ def add_replace(
178
197
collation : Optional [Mapping [str , Any ]] = None ,
179
198
hint : Union [str , dict [str , Any ], None ] = None ,
180
199
sort : Optional [Mapping [str , Any ]] = None ,
181
- ) -> None :
200
+ ) -> bool :
182
201
"""Create a replace document and add it to the list of ops."""
183
202
validate_ok_for_replace (replacement )
184
203
cmd : dict [str , Any ] = {"q" : selector , "u" : replacement }
@@ -194,14 +213,15 @@ def add_replace(
194
213
self .uses_sort = True
195
214
cmd ["sort" ] = sort
196
215
self .ops .append ((_UPDATE , cmd ))
216
+ return True
197
217
198
218
def add_delete (
199
219
self ,
200
220
selector : Mapping [str , Any ],
201
221
limit : int ,
202
222
collation : Optional [Mapping [str , Any ]] = None ,
203
223
hint : Union [str , dict [str , Any ], None ] = None ,
204
- ) -> None :
224
+ ) -> bool :
205
225
"""Create a delete document and add it to the list of ops."""
206
226
cmd : dict [str , Any ] = {"q" : selector , "limit" : limit }
207
227
if collation is not None :
@@ -210,44 +230,50 @@ def add_delete(
210
230
if hint is not None :
211
231
self .uses_hint_delete = True
212
232
cmd ["hint" ] = hint
233
+
234
+ self .ops .append ((_DELETE , cmd ))
213
235
if limit == _DELETE_ALL :
214
236
# A bulk_write containing a delete_many is not retryable.
215
- self . is_retryable = False
216
- self . ops . append (( _DELETE , cmd ))
237
+ return False
238
+ return True
217
239
218
- def gen_ordered (self , requests ) -> Iterator [Optional [_Run ]]:
240
+ def gen_ordered (
241
+ self ,
242
+ requests : Iterable [Any ],
243
+ process : Callable [[Union [_DocumentType , RawBSONDocument , _WriteOp ]], bool ],
244
+ ) -> Iterator [_Run ]:
219
245
"""Generate batches of operations, batched by type of
220
246
operation, in the order **provided**.
221
247
"""
222
248
run = None
223
249
for idx , request in enumerate (requests ):
224
- try :
225
- request ._add_to_bulk (self )
226
- except AttributeError :
227
- raise TypeError (f"{ request !r} is not a valid request" ) from None
250
+ retryable = process (request )
228
251
(op_type , operation ) = self .ops [idx ]
229
252
if run is None :
230
253
run = _Run (op_type )
231
254
elif run .op_type != op_type :
232
255
yield run
233
256
run = _Run (op_type )
234
257
run .add (idx , operation )
258
+ run .is_retryable = run .is_retryable and retryable
235
259
if run is None :
236
260
raise InvalidOperation ("No operations to execute" )
237
261
yield run
238
262
239
- def gen_unordered (self , requests ) -> Iterator [_Run ]:
263
+ def gen_unordered (
264
+ self ,
265
+ requests : Iterable [Any ],
266
+ process : Callable [[Union [_DocumentType , RawBSONDocument , _WriteOp ]], bool ],
267
+ ) -> Iterator [_Run ]:
240
268
"""Generate batches of operations, batched by type of
241
269
operation, in arbitrary order.
242
270
"""
243
271
operations = [_Run (_INSERT ), _Run (_UPDATE ), _Run (_DELETE )]
244
272
for idx , request in enumerate (requests ):
245
- try :
246
- request ._add_to_bulk (self )
247
- except AttributeError :
248
- raise TypeError (f"{ request !r} is not a valid request" ) from None
273
+ retryable = process (request )
249
274
(op_type , operation ) = self .ops [idx ]
250
275
operations [op_type ].add (idx , operation )
276
+ operations [op_type ].is_retryable = operations [op_type ].is_retryable and retryable
251
277
if (
252
278
len (operations [_INSERT ].ops ) == 0
253
279
and len (operations [_UPDATE ].ops ) == 0
@@ -488,8 +514,8 @@ async def _execute_command(
488
514
session : Optional [AsyncClientSession ],
489
515
conn : AsyncConnection ,
490
516
op_id : int ,
491
- retryable : bool ,
492
517
full_result : MutableMapping [str , Any ],
518
+ validate : bool ,
493
519
final_write_concern : Optional [WriteConcern ] = None ,
494
520
) -> None :
495
521
db_name = self .collection .database .name
@@ -507,7 +533,7 @@ async def _execute_command(
507
533
last_run = False
508
534
509
535
while run :
510
- if not self .retrying :
536
+ if not run .retrying :
511
537
self .next_run = next (generator , None )
512
538
if self .next_run is None :
513
539
last_run = True
@@ -541,20 +567,21 @@ async def _execute_command(
541
567
if session :
542
568
# Start a new retryable write unless one was already
543
569
# started for this command.
544
- if retryable and not self .started_retryable_write :
570
+ if run . is_retryable and not run .started_retryable_write :
545
571
session ._start_retryable_write ()
546
572
self .started_retryable_write = True
547
- session ._apply_to (cmd , retryable , ReadPreference .PRIMARY , conn )
573
+ session ._apply_to (cmd , run . is_retryable , ReadPreference .PRIMARY , conn )
548
574
conn .send_cluster_time (cmd , session , client )
549
575
conn .add_server_api (cmd )
550
576
# CSOT: apply timeout before encoding the command.
551
577
conn .apply_timeout (client , cmd )
552
578
ops = islice (run .ops , run .idx_offset , None )
553
579
554
580
# Run as many ops as possible in one command.
581
+ if validate :
582
+ await self .validate_batch (conn , write_concern )
555
583
if write_concern .acknowledged :
556
584
result , to_send = await self ._execute_batch (bwc , cmd , ops , client )
557
-
558
585
# Retryable writeConcernErrors halt the execution of this run.
559
586
wce = result .get ("writeConcernError" , {})
560
587
if wce .get ("code" , 0 ) in _RETRYABLE_ERROR_CODES :
@@ -567,8 +594,8 @@ async def _execute_command(
567
594
_merge_command (run , full_result , run .idx_offset , result )
568
595
569
596
# We're no longer in a retry once a command succeeds.
570
- self .retrying = False
571
- self .started_retryable_write = False
597
+ run .retrying = False
598
+ run .started_retryable_write = False
572
599
573
600
if self .ordered and "writeErrors" in result :
574
601
break
@@ -606,34 +633,33 @@ async def execute_command(
606
633
op_id = _randint ()
607
634
608
635
async def retryable_bulk (
609
- session : Optional [AsyncClientSession ], conn : AsyncConnection , retryable : bool
636
+ session : Optional [AsyncClientSession ],
637
+ conn : AsyncConnection ,
610
638
) -> None :
611
639
await self ._execute_command (
612
640
generator ,
613
641
write_concern ,
614
642
session ,
615
643
conn ,
616
644
op_id ,
617
- retryable ,
618
645
full_result ,
646
+ validate = False ,
619
647
)
620
648
621
649
client = self .collection .database .client
622
650
_ = await client ._retryable_write (
623
- self .is_retryable ,
624
651
retryable_bulk ,
625
652
session ,
626
653
operation ,
627
654
bulk = self , # type: ignore[arg-type]
628
655
operation_id = op_id ,
629
656
)
630
-
631
657
if full_result ["writeErrors" ] or full_result ["writeConcernErrors" ]:
632
658
_raise_bulk_write_error (full_result )
633
659
return full_result
634
660
635
661
async def execute_op_msg_no_results (
636
- self , conn : AsyncConnection , generator : Iterator [Any ]
662
+ self , conn : AsyncConnection , generator : Iterator [Any ], write_concern : WriteConcern
637
663
) -> None :
638
664
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
639
665
db_name = self .collection .database .name
@@ -667,6 +693,7 @@ async def execute_op_msg_no_results(
667
693
conn .add_server_api (cmd )
668
694
ops = islice (run .ops , run .idx_offset , None )
669
695
# Run as many ops as possible.
696
+ await self .validate_batch (conn , write_concern )
670
697
to_send = await self ._execute_batch_unack (bwc , cmd , ops , client )
671
698
run .idx_offset += len (to_send )
672
699
self .current_run = run = next (generator , None )
@@ -700,12 +727,15 @@ async def execute_command_no_results(
700
727
None ,
701
728
conn ,
702
729
op_id ,
703
- False ,
704
730
full_result ,
731
+ True ,
705
732
write_concern ,
706
733
)
707
- except OperationFailure :
708
- pass
734
+ except OperationFailure as exc :
735
+ if "Cannot set bypass_document_validation with unacknowledged write concern" in str (
736
+ exc
737
+ ):
738
+ raise exc
709
739
710
740
async def execute_no_results (
711
741
self ,
@@ -714,6 +744,11 @@ async def execute_no_results(
714
744
write_concern : WriteConcern ,
715
745
) -> None :
716
746
"""Execute all operations, returning no results (w=0)."""
747
+ if self .ordered :
748
+ return await self .execute_command_no_results (conn , generator , write_concern )
749
+ return await self .execute_op_msg_no_results (conn , generator , write_concern )
750
+
751
+ async def validate_batch (self , conn : AsyncConnection , write_concern : WriteConcern ) -> None :
717
752
if self .uses_collation :
718
753
raise ConfigurationError ("Collation is unsupported for unacknowledged writes." )
719
754
if self .uses_array_filters :
@@ -738,13 +773,10 @@ async def execute_no_results(
738
773
"Cannot set bypass_document_validation with unacknowledged write concern"
739
774
)
740
775
741
- if self .ordered :
742
- return await self .execute_command_no_results (conn , generator , write_concern )
743
- return await self .execute_op_msg_no_results (conn , generator )
744
-
745
776
async def execute (
746
777
self ,
747
- generator : Generator [_WriteOp [_DocumentType ]],
778
+ generator : Iterable [Any ],
779
+ process : Callable [[Union [_DocumentType , RawBSONDocument , _WriteOp ]], bool ],
748
780
write_concern : WriteConcern ,
749
781
session : Optional [AsyncClientSession ],
750
782
operation : str ,
@@ -757,9 +789,9 @@ async def execute(
757
789
session = _validate_session_write_concern (session , write_concern )
758
790
759
791
if self .ordered :
760
- generator = self .gen_ordered (generator )
792
+ generator = self .gen_ordered (generator , process )
761
793
else :
762
- generator = self .gen_unordered (generator )
794
+ generator = self .gen_unordered (generator , process )
763
795
764
796
client = self .collection .database .client
765
797
if not write_concern .acknowledged :
0 commit comments