Skip to content

Commit b933601

Browse files
committed
bulk_write should be able to accept a generator
1 parent 894782e commit b933601

File tree

7 files changed

+100
-36
lines changed

7 files changed

+100
-36
lines changed

pymongo/asynchronous/bulk.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing import (
2727
TYPE_CHECKING,
2828
Any,
29+
Generator,
2930
Iterator,
3031
Mapping,
3132
Optional,
@@ -72,7 +73,7 @@
7273
from pymongo.write_concern import WriteConcern
7374

7475
if TYPE_CHECKING:
75-
from pymongo.asynchronous.collection import AsyncCollection
76+
from pymongo.asynchronous.collection import AsyncCollection, _WriteOp
7677
from pymongo.asynchronous.mongo_client import AsyncMongoClient
7778
from pymongo.asynchronous.pool import AsyncConnection
7879
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
@@ -214,28 +215,45 @@ def add_delete(
214215
self.is_retryable = False
215216
self.ops.append((_DELETE, cmd))
216217

217-
def gen_ordered(self) -> Iterator[Optional[_Run]]:
218+
def gen_ordered(self, requests) -> Iterator[Optional[_Run]]:
218219
"""Generate batches of operations, batched by type of
219220
operation, in the order **provided**.
220221
"""
221222
run = None
222-
for idx, (op_type, operation) in enumerate(self.ops):
223+
for idx, request in enumerate(requests):
224+
try:
225+
request._add_to_bulk(self)
226+
except AttributeError:
227+
raise TypeError(f"{request!r} is not a valid request") from None
228+
(op_type, operation) = self.ops[idx]
223229
if run is None:
224230
run = _Run(op_type)
225231
elif run.op_type != op_type:
226232
yield run
227233
run = _Run(op_type)
228234
run.add(idx, operation)
235+
if run is None:
236+
raise InvalidOperation("No operations to execute")
229237
yield run
230238

231-
def gen_unordered(self) -> Iterator[_Run]:
239+
def gen_unordered(self, requests) -> Iterator[_Run]:
232240
"""Generate batches of operations, batched by type of
233241
operation, in arbitrary order.
234242
"""
235243
operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
236-
for idx, (op_type, operation) in enumerate(self.ops):
244+
for idx, request in enumerate(requests):
245+
try:
246+
request._add_to_bulk(self)
247+
except AttributeError:
248+
raise TypeError(f"{request!r} is not a valid request") from None
249+
(op_type, operation) = self.ops[idx]
237250
operations[op_type].add(idx, operation)
238-
251+
if (
252+
len(operations[_INSERT].ops) == 0
253+
and len(operations[_UPDATE].ops) == 0
254+
and len(operations[_DELETE].ops) == 0
255+
):
256+
raise InvalidOperation("No operations to execute")
239257
for run in operations:
240258
if run.ops:
241259
yield run
@@ -726,23 +744,22 @@ async def execute_no_results(
726744

727745
async def execute(
728746
self,
747+
generator: Generator[_WriteOp[_DocumentType]],
729748
write_concern: WriteConcern,
730749
session: Optional[AsyncClientSession],
731750
operation: str,
732751
) -> Any:
733752
"""Execute operations."""
734-
if not self.ops:
735-
raise InvalidOperation("No operations to execute")
736753
if self.executed:
737754
raise InvalidOperation("Bulk operations can only be executed once.")
738755
self.executed = True
739756
write_concern = write_concern or self.collection.write_concern
740757
session = _validate_session_write_concern(session, write_concern)
741758

742759
if self.ordered:
743-
generator = self.gen_ordered()
760+
generator = self.gen_ordered(generator)
744761
else:
745-
generator = self.gen_unordered()
762+
generator = self.gen_unordered(generator)
746763

747764
client = self.collection.database.client
748765
if not write_concern.acknowledged:

pymongo/asynchronous/collection.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
AsyncContextManager,
2424
Callable,
2525
Coroutine,
26+
Generator,
2627
Generic,
2728
Iterable,
2829
Iterator,
@@ -699,7 +700,7 @@ async def _create(
699700
@_csot.apply
700701
async def bulk_write(
701702
self,
702-
requests: Sequence[_WriteOp[_DocumentType]],
703+
requests: Sequence[_WriteOp[_DocumentType]] | Generator[_WriteOp[_DocumentType]],
703704
ordered: bool = True,
704705
bypass_document_validation: Optional[bool] = None,
705706
session: Optional[AsyncClientSession] = None,
@@ -779,17 +780,12 @@ async def bulk_write(
779780
780781
.. versionadded:: 3.0
781782
"""
782-
common.validate_list("requests", requests)
783+
common.validate_list_or_generator("requests", requests)
783784

784785
blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let)
785-
for request in requests:
786-
try:
787-
request._add_to_bulk(blk)
788-
except AttributeError:
789-
raise TypeError(f"{request!r} is not a valid request") from None
790786

791787
write_concern = self._write_concern_for(session)
792-
bulk_api_result = await blk.execute(write_concern, session, _Op.INSERT)
788+
bulk_api_result = await blk.execute(requests, write_concern, session, _Op.INSERT)
793789
if bulk_api_result is not None:
794790
return BulkWriteResult(bulk_api_result, True)
795791
return BulkWriteResult({}, False)

pymongo/common.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
TYPE_CHECKING,
2525
Any,
2626
Callable,
27+
Generator,
2728
Iterator,
2829
Mapping,
2930
MutableMapping,
@@ -530,6 +531,13 @@ def validate_list(option: str, value: Any) -> list:
530531
return value
531532

532533

534+
def validate_list_or_generator(option: str, value: Any) -> Union[list, Generator]:
535+
"""Validates that 'value' is a list or generator."""
536+
if isinstance(value, Generator):
537+
return value
538+
return validate_list(option, value)
539+
540+
533541
def validate_list_or_none(option: Any, value: Any) -> Optional[list]:
534542
"""Validates that 'value' is a list or None."""
535543
if value is None:

pymongo/synchronous/bulk.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing import (
2727
TYPE_CHECKING,
2828
Any,
29+
Generator,
2930
Iterator,
3031
Mapping,
3132
Optional,
@@ -72,7 +73,7 @@
7273
from pymongo.write_concern import WriteConcern
7374

7475
if TYPE_CHECKING:
75-
from pymongo.synchronous.collection import Collection
76+
from pymongo.synchronous.collection import Collection, _WriteOp
7677
from pymongo.synchronous.mongo_client import MongoClient
7778
from pymongo.synchronous.pool import Connection
7879
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
@@ -214,28 +215,45 @@ def add_delete(
214215
self.is_retryable = False
215216
self.ops.append((_DELETE, cmd))
216217

217-
def gen_ordered(self) -> Iterator[Optional[_Run]]:
218+
def gen_ordered(self, requests) -> Iterator[Optional[_Run]]:
218219
"""Generate batches of operations, batched by type of
219220
operation, in the order **provided**.
220221
"""
221222
run = None
222-
for idx, (op_type, operation) in enumerate(self.ops):
223+
for idx, request in enumerate(requests):
224+
try:
225+
request._add_to_bulk(self)
226+
except AttributeError:
227+
raise TypeError(f"{request!r} is not a valid request") from None
228+
(op_type, operation) = self.ops[idx]
223229
if run is None:
224230
run = _Run(op_type)
225231
elif run.op_type != op_type:
226232
yield run
227233
run = _Run(op_type)
228234
run.add(idx, operation)
235+
if run is None:
236+
raise InvalidOperation("No operations to execute")
229237
yield run
230238

231-
def gen_unordered(self) -> Iterator[_Run]:
239+
def gen_unordered(self, requests) -> Iterator[_Run]:
232240
"""Generate batches of operations, batched by type of
233241
operation, in arbitrary order.
234242
"""
235243
operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
236-
for idx, (op_type, operation) in enumerate(self.ops):
244+
for idx, request in enumerate(requests):
245+
try:
246+
request._add_to_bulk(self)
247+
except AttributeError:
248+
raise TypeError(f"{request!r} is not a valid request") from None
249+
(op_type, operation) = self.ops[idx]
237250
operations[op_type].add(idx, operation)
238-
251+
if (
252+
len(operations[_INSERT].ops) == 0
253+
and len(operations[_UPDATE].ops) == 0
254+
and len(operations[_DELETE].ops) == 0
255+
):
256+
raise InvalidOperation("No operations to execute")
239257
for run in operations:
240258
if run.ops:
241259
yield run
@@ -724,23 +742,22 @@ def execute_no_results(
724742

725743
def execute(
726744
self,
745+
generator: Generator[_WriteOp[_DocumentType]],
727746
write_concern: WriteConcern,
728747
session: Optional[ClientSession],
729748
operation: str,
730749
) -> Any:
731750
"""Execute operations."""
732-
if not self.ops:
733-
raise InvalidOperation("No operations to execute")
734751
if self.executed:
735752
raise InvalidOperation("Bulk operations can only be executed once.")
736753
self.executed = True
737754
write_concern = write_concern or self.collection.write_concern
738755
session = _validate_session_write_concern(session, write_concern)
739756

740757
if self.ordered:
741-
generator = self.gen_ordered()
758+
generator = self.gen_ordered(generator)
742759
else:
743-
generator = self.gen_unordered()
760+
generator = self.gen_unordered(generator)
744761

745762
client = self.collection.database.client
746763
if not write_concern.acknowledged:

pymongo/synchronous/collection.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
Any,
2323
Callable,
2424
ContextManager,
25+
Generator,
2526
Generic,
2627
Iterable,
2728
Iterator,
@@ -698,7 +699,7 @@ def _create(
698699
@_csot.apply
699700
def bulk_write(
700701
self,
701-
requests: Sequence[_WriteOp[_DocumentType]],
702+
requests: Sequence[_WriteOp[_DocumentType]] | Generator[_WriteOp[_DocumentType]],
702703
ordered: bool = True,
703704
bypass_document_validation: Optional[bool] = None,
704705
session: Optional[ClientSession] = None,
@@ -778,17 +779,12 @@ def bulk_write(
778779
779780
.. versionadded:: 3.0
780781
"""
781-
common.validate_list("requests", requests)
782+
common.validate_list_or_generator("requests", requests)
782783

783784
blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let)
784-
for request in requests:
785-
try:
786-
request._add_to_bulk(blk)
787-
except AttributeError:
788-
raise TypeError(f"{request!r} is not a valid request") from None
789785

790786
write_concern = self._write_concern_for(session)
791-
bulk_api_result = blk.execute(write_concern, session, _Op.INSERT)
787+
bulk_api_result = blk.execute(requests, write_concern, session, _Op.INSERT)
792788
if bulk_api_result is not None:
793789
return BulkWriteResult(bulk_api_result, True)
794790
return BulkWriteResult({}, False)

test/asynchronous/test_bulk.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,21 @@ async def test_numerous_inserts(self):
299299
self.assertEqual(n_docs, result.inserted_count)
300300
self.assertEqual(n_docs, await self.coll.count_documents({}))
301301

302+
async def test_numerous_inserts_generator(self):
303+
# Ensure we don't exceed server's maxWriteBatchSize size limit.
304+
n_docs = await async_client_context.max_write_batch_size + 100
305+
requests = (InsertOne[dict]({}) for _ in range(n_docs))
306+
result = await self.coll.bulk_write(requests, ordered=False)
307+
self.assertEqual(n_docs, result.inserted_count)
308+
self.assertEqual(n_docs, await self.coll.count_documents({}))
309+
310+
# Same with ordered bulk.
311+
await self.coll.drop()
312+
requests = (InsertOne[dict]({}) for _ in range(n_docs))
313+
result = await self.coll.bulk_write(requests)
314+
self.assertEqual(n_docs, result.inserted_count)
315+
self.assertEqual(n_docs, await self.coll.count_documents({}))
316+
302317
async def test_bulk_max_message_size(self):
303318
await self.coll.delete_many({})
304319
self.addAsyncCleanup(self.coll.delete_many, {})

test/test_bulk.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,21 @@ def test_numerous_inserts(self):
299299
self.assertEqual(n_docs, result.inserted_count)
300300
self.assertEqual(n_docs, self.coll.count_documents({}))
301301

302+
def test_numerous_inserts_generator(self):
303+
# Ensure we don't exceed server's maxWriteBatchSize size limit.
304+
n_docs = client_context.max_write_batch_size + 100
305+
requests = (InsertOne[dict]({}) for _ in range(n_docs))
306+
result = self.coll.bulk_write(requests, ordered=False)
307+
self.assertEqual(n_docs, result.inserted_count)
308+
self.assertEqual(n_docs, self.coll.count_documents({}))
309+
310+
# Same with ordered bulk.
311+
self.coll.drop()
312+
requests = (InsertOne[dict]({}) for _ in range(n_docs))
313+
result = self.coll.bulk_write(requests)
314+
self.assertEqual(n_docs, result.inserted_count)
315+
self.assertEqual(n_docs, self.coll.count_documents({}))
316+
302317
def test_bulk_max_message_size(self):
303318
self.coll.delete_many({})
304319
self.addCleanup(self.coll.delete_many, {})

0 commit comments

Comments
 (0)