Skip to content

Commit 0648bcf

Browse files
committed
wip
1 parent b933601 commit 0648bcf

File tree

14 files changed

+307
-252
lines changed

14 files changed

+307
-252
lines changed

pymongo/asynchronous/bulk.py

Lines changed: 76 additions & 44 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,8 @@
2626
from typing import (
2727
TYPE_CHECKING,
2828
Any,
29-
Generator,
29+
Callable,
30+
Iterable,
3031
Iterator,
3132
Mapping,
3233
Optional,
@@ -111,9 +112,6 @@ def __init__(
111112
self.uses_hint_update = False
112113
self.uses_hint_delete = False
113114
self.uses_sort = False
114-
self.is_retryable = True
115-
self.retrying = False
116-
self.started_retryable_write = False
117115
# Extra state so that we know where to pick up on a retry attempt.
118116
self.current_run = None
119117
self.next_run = None
@@ -129,13 +127,32 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
129127
self.is_encrypted = False
130128
return _BulkWriteContext
131129

132-
def add_insert(self, document: _DocumentOut) -> None:
130+
@property
131+
def is_retryable(self) -> bool:
132+
if self.current_run:
133+
return self.current_run.is_retryable
134+
return True
135+
136+
@property
137+
def retrying(self) -> bool:
138+
if self.current_run:
139+
return self.current_run.retrying
140+
return False
141+
142+
@property
143+
def started_retryable_write(self) -> bool:
144+
if self.current_run:
145+
return self.current_run.started_retryable_write
146+
return False
147+
148+
def add_insert(self, document: _DocumentOut) -> bool:
133149
"""Add an insert document to the list of ops."""
134150
validate_is_document_type("document", document)
135151
# Generate ObjectId client side.
136152
if not (isinstance(document, RawBSONDocument) or "_id" in document):
137153
document["_id"] = ObjectId()
138154
self.ops.append((_INSERT, document))
155+
return True
139156

140157
def add_update(
141158
self,
@@ -147,7 +164,7 @@ def add_update(
147164
array_filters: Optional[list[Mapping[str, Any]]] = None,
148165
hint: Union[str, dict[str, Any], None] = None,
149166
sort: Optional[Mapping[str, Any]] = None,
150-
) -> None:
167+
) -> bool:
151168
"""Create an update document and add it to the list of ops."""
152169
validate_ok_for_update(update)
153170
cmd: dict[str, Any] = {"q": selector, "u": update, "multi": multi}
@@ -165,10 +182,12 @@ def add_update(
165182
if sort is not None:
166183
self.uses_sort = True
167184
cmd["sort"] = sort
185+
186+
self.ops.append((_UPDATE, cmd))
168187
if multi:
169188
# A bulk_write containing an update_many is not retryable.
170-
self.is_retryable = False
171-
self.ops.append((_UPDATE, cmd))
189+
return False
190+
return True
172191

173192
def add_replace(
174193
self,
@@ -178,7 +197,7 @@ def add_replace(
178197
collation: Optional[Mapping[str, Any]] = None,
179198
hint: Union[str, dict[str, Any], None] = None,
180199
sort: Optional[Mapping[str, Any]] = None,
181-
) -> None:
200+
) -> bool:
182201
"""Create a replace document and add it to the list of ops."""
183202
validate_ok_for_replace(replacement)
184203
cmd: dict[str, Any] = {"q": selector, "u": replacement}
@@ -194,14 +213,15 @@ def add_replace(
194213
self.uses_sort = True
195214
cmd["sort"] = sort
196215
self.ops.append((_UPDATE, cmd))
216+
return True
197217

198218
def add_delete(
199219
self,
200220
selector: Mapping[str, Any],
201221
limit: int,
202222
collation: Optional[Mapping[str, Any]] = None,
203223
hint: Union[str, dict[str, Any], None] = None,
204-
) -> None:
224+
) -> bool:
205225
"""Create a delete document and add it to the list of ops."""
206226
cmd: dict[str, Any] = {"q": selector, "limit": limit}
207227
if collation is not None:
@@ -210,44 +230,50 @@ def add_delete(
210230
if hint is not None:
211231
self.uses_hint_delete = True
212232
cmd["hint"] = hint
233+
234+
self.ops.append((_DELETE, cmd))
213235
if limit == _DELETE_ALL:
214236
# A bulk_write containing a delete_many is not retryable.
215-
self.is_retryable = False
216-
self.ops.append((_DELETE, cmd))
237+
return False
238+
return True
217239

218-
def gen_ordered(self, requests) -> Iterator[Optional[_Run]]:
240+
def gen_ordered(
241+
self,
242+
requests: Iterable[Any],
243+
process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool],
244+
) -> Iterator[_Run]:
219245
"""Generate batches of operations, batched by type of
220246
operation, in the order **provided**.
221247
"""
222248
run = None
223249
for idx, request in enumerate(requests):
224-
try:
225-
request._add_to_bulk(self)
226-
except AttributeError:
227-
raise TypeError(f"{request!r} is not a valid request") from None
250+
retryable = process(request)
228251
(op_type, operation) = self.ops[idx]
229252
if run is None:
230253
run = _Run(op_type)
231254
elif run.op_type != op_type:
232255
yield run
233256
run = _Run(op_type)
234257
run.add(idx, operation)
258+
run.is_retryable = run.is_retryable and retryable
235259
if run is None:
236260
raise InvalidOperation("No operations to execute")
237261
yield run
238262

239-
def gen_unordered(self, requests) -> Iterator[_Run]:
263+
def gen_unordered(
264+
self,
265+
requests: Iterable[Any],
266+
process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool],
267+
) -> Iterator[_Run]:
240268
"""Generate batches of operations, batched by type of
241269
operation, in arbitrary order.
242270
"""
243271
operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
244272
for idx, request in enumerate(requests):
245-
try:
246-
request._add_to_bulk(self)
247-
except AttributeError:
248-
raise TypeError(f"{request!r} is not a valid request") from None
273+
retryable = process(request)
249274
(op_type, operation) = self.ops[idx]
250275
operations[op_type].add(idx, operation)
276+
operations[op_type].is_retryable = operations[op_type].is_retryable and retryable
251277
if (
252278
len(operations[_INSERT].ops) == 0
253279
and len(operations[_UPDATE].ops) == 0
@@ -488,8 +514,8 @@ async def _execute_command(
488514
session: Optional[AsyncClientSession],
489515
conn: AsyncConnection,
490516
op_id: int,
491-
retryable: bool,
492517
full_result: MutableMapping[str, Any],
518+
validate: bool,
493519
final_write_concern: Optional[WriteConcern] = None,
494520
) -> None:
495521
db_name = self.collection.database.name
@@ -507,7 +533,7 @@ async def _execute_command(
507533
last_run = False
508534

509535
while run:
510-
if not self.retrying:
536+
if not run.retrying:
511537
self.next_run = next(generator, None)
512538
if self.next_run is None:
513539
last_run = True
@@ -541,20 +567,21 @@ async def _execute_command(
541567
if session:
542568
# Start a new retryable write unless one was already
543569
# started for this command.
544-
if retryable and not self.started_retryable_write:
570+
if run.is_retryable and not run.started_retryable_write:
545571
session._start_retryable_write()
546572
self.started_retryable_write = True
547-
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
573+
session._apply_to(cmd, run.is_retryable, ReadPreference.PRIMARY, conn)
548574
conn.send_cluster_time(cmd, session, client)
549575
conn.add_server_api(cmd)
550576
# CSOT: apply timeout before encoding the command.
551577
conn.apply_timeout(client, cmd)
552578
ops = islice(run.ops, run.idx_offset, None)
553579

554580
# Run as many ops as possible in one command.
581+
if validate:
582+
await self.validate_batch(conn, write_concern)
555583
if write_concern.acknowledged:
556584
result, to_send = await self._execute_batch(bwc, cmd, ops, client)
557-
558585
# Retryable writeConcernErrors halt the execution of this run.
559586
wce = result.get("writeConcernError", {})
560587
if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
@@ -567,8 +594,8 @@ async def _execute_command(
567594
_merge_command(run, full_result, run.idx_offset, result)
568595

569596
# We're no longer in a retry once a command succeeds.
570-
self.retrying = False
571-
self.started_retryable_write = False
597+
run.retrying = False
598+
run.started_retryable_write = False
572599

573600
if self.ordered and "writeErrors" in result:
574601
break
@@ -606,34 +633,33 @@ async def execute_command(
606633
op_id = _randint()
607634

608635
async def retryable_bulk(
609-
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool
636+
session: Optional[AsyncClientSession],
637+
conn: AsyncConnection,
610638
) -> None:
611639
await self._execute_command(
612640
generator,
613641
write_concern,
614642
session,
615643
conn,
616644
op_id,
617-
retryable,
618645
full_result,
646+
validate=False,
619647
)
620648

621649
client = self.collection.database.client
622650
_ = await client._retryable_write(
623-
self.is_retryable,
624651
retryable_bulk,
625652
session,
626653
operation,
627654
bulk=self, # type: ignore[arg-type]
628655
operation_id=op_id,
629656
)
630-
631657
if full_result["writeErrors"] or full_result["writeConcernErrors"]:
632658
_raise_bulk_write_error(full_result)
633659
return full_result
634660

635661
async def execute_op_msg_no_results(
636-
self, conn: AsyncConnection, generator: Iterator[Any]
662+
self, conn: AsyncConnection, generator: Iterator[Any], write_concern: WriteConcern
637663
) -> None:
638664
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
639665
db_name = self.collection.database.name
@@ -667,6 +693,7 @@ async def execute_op_msg_no_results(
667693
conn.add_server_api(cmd)
668694
ops = islice(run.ops, run.idx_offset, None)
669695
# Run as many ops as possible.
696+
await self.validate_batch(conn, write_concern)
670697
to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
671698
run.idx_offset += len(to_send)
672699
self.current_run = run = next(generator, None)
@@ -700,12 +727,15 @@ async def execute_command_no_results(
700727
None,
701728
conn,
702729
op_id,
703-
False,
704730
full_result,
731+
True,
705732
write_concern,
706733
)
707-
except OperationFailure:
708-
pass
734+
except OperationFailure as exc:
735+
if "Cannot set bypass_document_validation with unacknowledged write concern" in str(
736+
exc
737+
):
738+
raise exc
709739

710740
async def execute_no_results(
711741
self,
@@ -714,6 +744,11 @@ async def execute_no_results(
714744
write_concern: WriteConcern,
715745
) -> None:
716746
"""Execute all operations, returning no results (w=0)."""
747+
if self.ordered:
748+
return await self.execute_command_no_results(conn, generator, write_concern)
749+
return await self.execute_op_msg_no_results(conn, generator, write_concern)
750+
751+
async def validate_batch(self, conn: AsyncConnection, write_concern: WriteConcern) -> None:
717752
if self.uses_collation:
718753
raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
719754
if self.uses_array_filters:
@@ -738,13 +773,10 @@ async def execute_no_results(
738773
"Cannot set bypass_document_validation with unacknowledged write concern"
739774
)
740775

741-
if self.ordered:
742-
return await self.execute_command_no_results(conn, generator, write_concern)
743-
return await self.execute_op_msg_no_results(conn, generator)
744-
745776
async def execute(
746777
self,
747-
generator: Generator[_WriteOp[_DocumentType]],
778+
generator: Iterable[Any],
779+
process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool],
748780
write_concern: WriteConcern,
749781
session: Optional[AsyncClientSession],
750782
operation: str,
@@ -757,9 +789,9 @@ async def execute(
757789
session = _validate_session_write_concern(session, write_concern)
758790

759791
if self.ordered:
760-
generator = self.gen_ordered(generator)
792+
generator = self.gen_ordered(generator, process)
761793
else:
762-
generator = self.gen_unordered(generator)
794+
generator = self.gen_unordered(generator, process)
763795

764796
client = self.collection.database.client
765797
if not write_concern.acknowledged:

pymongo/asynchronous/client_bulk.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,7 @@ def __init__(
116116
self.is_retryable = self.client.options.retry_writes
117117
self.retrying = False
118118
self.started_retryable_write = False
119+
self.current_run = None
119120

120121
@property
121122
def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]:
@@ -488,7 +489,6 @@ async def _execute_command(
488489
session: Optional[AsyncClientSession],
489490
conn: AsyncConnection,
490491
op_id: int,
491-
retryable: bool,
492492
full_result: MutableMapping[str, Any],
493493
final_write_concern: Optional[WriteConcern] = None,
494494
) -> None:
@@ -534,10 +534,10 @@ async def _execute_command(
534534
if session:
535535
# Start a new retryable write unless one was already
536536
# started for this command.
537-
if retryable and not self.started_retryable_write:
537+
if self.is_retryable and not self.started_retryable_write:
538538
session._start_retryable_write()
539539
self.started_retryable_write = True
540-
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
540+
session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn)
541541
conn.send_cluster_time(cmd, session, self.client)
542542
conn.add_server_api(cmd)
543543
# CSOT: apply timeout before encoding the command.
@@ -564,7 +564,7 @@ async def _execute_command(
564564

565565
# Synthesize the full bulk result without modifying the
566566
# current one because this write operation may be retried.
567-
if retryable and (retryable_top_level_error or retryable_network_error):
567+
if self.is_retryable and (retryable_top_level_error or retryable_network_error):
568568
full = copy.deepcopy(full_result)
569569
_merge_command(self.ops, self.idx_offset, full, result)
570570
_throw_client_bulk_write_exception(full, self.verbose_results)
@@ -583,7 +583,7 @@ async def _execute_command(
583583
_merge_command(self.ops, self.idx_offset, full_result, result)
584584
break
585585

586-
if retryable:
586+
if self.is_retryable:
587587
# Retryable writeConcernErrors halt the execution of this batch.
588588
wce = result.get("writeConcernError", {})
589589
if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
@@ -638,7 +638,6 @@ async def execute_command(
638638
async def retryable_bulk(
639639
session: Optional[AsyncClientSession],
640640
conn: AsyncConnection,
641-
retryable: bool,
642641
) -> None:
643642
if conn.max_wire_version < 25:
644643
raise InvalidOperation(
@@ -649,12 +648,10 @@ async def retryable_bulk(
649648
session,
650649
conn,
651650
op_id,
652-
retryable,
653651
full_result,
654652
)
655653

656654
await self.client._retryable_write(
657-
self.is_retryable,
658655
retryable_bulk,
659656
session,
660657
operation,

0 commit comments

Comments
 (0)