Skip to content

Commit 92e6150

Browse files
authored
PYTHON-3493 Bulk Write InsertOne Should Be Parameter Of Collection Type (#1106)
1 parent 133c55d commit 92e6150

17 files changed

+144
-38
lines changed

doc/examples/type_hints.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,26 @@ These methods automatically add an "_id" field.
113113
>>> assert result is not None
114114
>>> assert result["year"] == 1993
115115
>>> # This will raise a type-checking error, despite being present, because it is added by PyMongo.
116+
>>> assert result["_id"] # type:ignore[typeddict-item]
117+
118+
This same typing scheme works for all of the insert methods (:meth:`~pymongo.collection.Collection.insert_one`,
119+
:meth:`~pymongo.collection.Collection.insert_many`, and :meth:`~pymongo.collection.Collection.bulk_write`).
120+
For `bulk_write` both :class:`~pymongo.operations.InsertOne` and :class:`~pymongo.operations.ReplaceOne` operators are generic.
121+
122+
.. doctest::
123+
:pyversion: >= 3.8
124+
125+
>>> from typing import TypedDict
126+
>>> from pymongo import MongoClient
127+
>>> from pymongo.operations import InsertOne
128+
>>> from pymongo.collection import Collection
129+
>>> client: MongoClient = MongoClient()
130+
>>> collection: Collection[Movie] = client.test.test
131+
>>> inserted = collection.bulk_write([InsertOne(Movie(name="Jurassic Park", year=1993))])
132+
>>> result = collection.find_one({"name": "Jurassic Park"})
133+
>>> assert result is not None
134+
>>> assert result["year"] == 1993
135+
>>> # This will raise a type-checking error, despite being present, because it is added by PyMongo.
116136
>>> assert result["_id"] # type:ignore[typeddict-item]
117137

118138
Modeling Document Types with TypedDict

mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ ignore_missing_imports = True
3333
ignore_missing_imports = True
3434

3535
[mypy-test.test_mypy]
36-
warn_unused_ignores = false
36+
warn_unused_ignores = True
3737

3838
[mypy-winkerberos.*]
3939
ignore_missing_imports = True

pymongo/collection.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,14 @@
7777
_FIND_AND_MODIFY_DOC_FIELDS = {"value": 1}
7878

7979

80-
_WriteOp = Union[InsertOne, DeleteOne, DeleteMany, ReplaceOne, UpdateOne, UpdateMany]
80+
_WriteOp = Union[
81+
InsertOne[_DocumentType],
82+
DeleteOne,
83+
DeleteMany,
84+
ReplaceOne[_DocumentType],
85+
UpdateOne,
86+
UpdateMany,
87+
]
8188
# Hint supports index name, "myIndex", or list of index pairs: [('x', 1), ('y', -1)]
8289
_IndexList = Sequence[Tuple[str, Union[int, str, Mapping[str, Any]]]]
8390
_IndexKeyHint = Union[str, _IndexList]
@@ -436,7 +443,7 @@ def with_options(
436443
@_csot.apply
437444
def bulk_write(
438445
self,
439-
requests: Sequence[_WriteOp],
446+
requests: Sequence[_WriteOp[_DocumentType]],
440447
ordered: bool = True,
441448
bypass_document_validation: bool = False,
442449
session: Optional["ClientSession"] = None,

pymongo/encryption.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import enum
1919
import socket
2020
import weakref
21-
from typing import Any, Mapping, Optional, Sequence
21+
from typing import Any, Generic, Mapping, Optional, Sequence
2222

2323
try:
2424
from pymongocrypt.auto_encrypter import AutoEncrypter
@@ -55,6 +55,7 @@
5555
from pymongo.read_concern import ReadConcern
5656
from pymongo.results import BulkWriteResult, DeleteResult
5757
from pymongo.ssl_support import get_ssl_context
58+
from pymongo.typings import _DocumentType
5859
from pymongo.uri_parser import parse_host
5960
from pymongo.write_concern import WriteConcern
6061

@@ -430,7 +431,7 @@ class QueryType(str, enum.Enum):
430431
"""Used to encrypt a value for an equality query."""
431432

432433

433-
class ClientEncryption(object):
434+
class ClientEncryption(Generic[_DocumentType]):
434435
"""Explicit client-side field level encryption."""
435436

436437
def __init__(

pymongo/operations.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,22 @@
1313
# limitations under the License.
1414

1515
"""Operation class definitions."""
16-
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
16+
from typing import Any, Dict, Generic, List, Mapping, Optional, Sequence, Tuple, Union
1717

18+
from bson.raw_bson import RawBSONDocument
1819
from pymongo import helpers
1920
from pymongo.collation import validate_collation_or_none
2021
from pymongo.common import validate_boolean, validate_is_mapping, validate_list
2122
from pymongo.helpers import _gen_index_name, _index_document, _index_list
22-
from pymongo.typings import _CollationIn, _DocumentIn, _Pipeline
23+
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
2324

2425

25-
class InsertOne(object):
26+
class InsertOne(Generic[_DocumentType]):
2627
"""Represents an insert_one operation."""
2728

2829
__slots__ = ("_doc",)
2930

30-
def __init__(self, document: _DocumentIn) -> None:
31+
def __init__(self, document: Union[_DocumentType, RawBSONDocument]) -> None:
3132
"""Create an InsertOne instance.
3233
3334
For use with :meth:`~pymongo.collection.Collection.bulk_write`.
@@ -170,15 +171,15 @@ def __ne__(self, other: Any) -> bool:
170171
return not self == other
171172

172173

173-
class ReplaceOne(object):
174+
class ReplaceOne(Generic[_DocumentType]):
174175
"""Represents a replace_one operation."""
175176

176177
__slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint")
177178

178179
def __init__(
179180
self,
180181
filter: Mapping[str, Any],
181-
replacement: Mapping[str, Any],
182+
replacement: Union[_DocumentType, RawBSONDocument],
182183
upsert: bool = False,
183184
collation: Optional[_CollationIn] = None,
184185
hint: Optional[_IndexKeyHint] = None,

pymongo/typings.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,10 @@
3737
_Pipeline = Sequence[Mapping[str, Any]]
3838
_DocumentOut = _DocumentIn
3939
_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any])
40+
41+
42+
def strip_optional(elem):
43+
"""This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T
44+
while inside a list comprehension."""
45+
assert elem is not None
46+
return elem

test/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1090,7 +1090,7 @@ def print_thread_stacks(pid: int) -> None:
10901090
class IntegrationTest(PyMongoTestCase):
10911091
"""Base class for TestCases that need a connection to MongoDB to pass."""
10921092

1093-
client: MongoClient
1093+
client: MongoClient[dict]
10941094
db: Database
10951095
credentials: Dict[str, str]
10961096

test/mockupdb/test_cluster_time.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def callback(client):
6060
self.cluster_time_conversation(callback, [{"ok": 1}] * 2)
6161

6262
def test_bulk(self):
63-
def callback(client):
63+
def callback(client: MongoClient[dict]) -> None:
6464
client.db.collection.bulk_write(
6565
[InsertOne({}), InsertOne({}), UpdateOne({}, {"$inc": {"x": 1}}), DeleteMany({})]
6666
)

test/mockupdb/test_op_msg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,22 +137,22 @@
137137
# Legacy methods
138138
Operation(
139139
"bulk_write_insert",
140-
lambda coll: coll.bulk_write([InsertOne({}), InsertOne({})]),
140+
lambda coll: coll.bulk_write([InsertOne[dict]({}), InsertOne[dict]({})]),
141141
request=OpMsg({"insert": "coll"}, flags=0),
142142
reply={"ok": 1, "n": 2},
143143
),
144144
Operation(
145145
"bulk_write_insert-w0",
146146
lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write(
147-
[InsertOne({}), InsertOne({})]
147+
[InsertOne[dict]({}), InsertOne[dict]({})]
148148
),
149149
request=OpMsg({"insert": "coll"}, flags=0),
150150
reply={"ok": 1, "n": 2},
151151
),
152152
Operation(
153153
"bulk_write_insert-w0-unordered",
154154
lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write(
155-
[InsertOne({}), InsertOne({})], ordered=False
155+
[InsertOne[dict]({}), InsertOne[dict]({})], ordered=False
156156
),
157157
request=OpMsg({"insert": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]),
158158
reply=None,

test/test_bulk.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ def test_upsert(self):
296296
def test_numerous_inserts(self):
297297
# Ensure we don't exceed server's maxWriteBatchSize size limit.
298298
n_docs = client_context.max_write_batch_size + 100
299-
requests = [InsertOne({}) for _ in range(n_docs)]
299+
requests = [InsertOne[dict]({}) for _ in range(n_docs)]
300300
result = self.coll.bulk_write(requests, ordered=False)
301301
self.assertEqual(n_docs, result.inserted_count)
302302
self.assertEqual(n_docs, self.coll.count_documents({}))
@@ -347,7 +347,7 @@ def test_bulk_write_no_results(self):
347347

348348
def test_bulk_write_invalid_arguments(self):
349349
# The requests argument must be a list.
350-
generator = (InsertOne({}) for _ in range(10))
350+
generator = (InsertOne[dict]({}) for _ in range(10))
351351
with self.assertRaises(TypeError):
352352
self.coll.bulk_write(generator) # type: ignore[arg-type]
353353

0 commit comments

Comments
 (0)