Skip to content

Commit 740f38f

Browse files
committed
WIP
1 parent bfaab82 commit 740f38f

22 files changed

+123
-118
lines changed

gridfs/asynchronous/grid_file.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _disallow_transactions(session: Optional[AsyncClientSession]) -> None:
7070
class AsyncGridFS:
7171
"""An instance of GridFS on top of a single Database."""
7272

73-
def __init__(self, database: AsyncDatabase, collection: str = "fs"):
73+
def __init__(self, database: AsyncDatabase[Any], collection: str = "fs"):
7474
"""Create a new instance of :class:`GridFS`.
7575
7676
Raises :class:`TypeError` if `database` is not an instance of
@@ -463,7 +463,7 @@ class AsyncGridFSBucket:
463463

464464
def __init__(
465465
self,
466-
db: AsyncDatabase,
466+
db: AsyncDatabase[Any],
467467
bucket_name: str = "fs",
468468
chunk_size_bytes: int = DEFAULT_CHUNK_SIZE,
469469
write_concern: Optional[WriteConcern] = None,
@@ -513,11 +513,11 @@ def __init__(
513513

514514
self._bucket_name = bucket_name
515515
self._collection = db[bucket_name]
516-
self._chunks: AsyncCollection = self._collection.chunks.with_options(
516+
self._chunks: AsyncCollection[Any] = self._collection.chunks.with_options(
517517
write_concern=write_concern, read_preference=read_preference
518518
)
519519

520-
self._files: AsyncCollection = self._collection.files.with_options(
520+
self._files: AsyncCollection[Any] = self._collection.files.with_options(
521521
write_concern=write_concern, read_preference=read_preference
522522
)
523523

@@ -1085,7 +1085,7 @@ class AsyncGridIn:
10851085

10861086
def __init__(
10871087
self,
1088-
root_collection: AsyncCollection,
1088+
root_collection: AsyncCollection[Any],
10891089
session: Optional[AsyncClientSession] = None,
10901090
**kwargs: Any,
10911091
) -> None:
@@ -1141,7 +1141,7 @@ def __init__(
11411141
"""
11421142
if not isinstance(root_collection, AsyncCollection):
11431143
raise TypeError(
1144-
f"root_collection must be an instance of AsyncCollection, not {type(root_collection)}"
1144+
f"root_collection must be an instance of AsyncCollection[Any], not {type(root_collection)}"
11451145
)
11461146

11471147
if not root_collection.write_concern.acknowledged:
@@ -1172,7 +1172,7 @@ def __init__(
11721172
object.__setattr__(self, "_buffered_docs_size", 0)
11731173

11741174
async def _create_index(
1175-
self, collection: AsyncCollection, index_key: Any, unique: bool
1175+
self, collection: AsyncCollection[Any], index_key: Any, unique: bool
11761176
) -> None:
11771177
doc = await collection.find_one(projection={"_id": 1}, session=self._session)
11781178
if doc is None:
@@ -1456,7 +1456,7 @@ class AsyncGridOut(GRIDOUT_BASE_CLASS): # type: ignore
14561456

14571457
def __init__(
14581458
self,
1459-
root_collection: AsyncCollection,
1459+
root_collection: AsyncCollection[Any],
14601460
file_id: Optional[int] = None,
14611461
file_document: Optional[Any] = None,
14621462
session: Optional[AsyncClientSession] = None,
@@ -1494,7 +1494,7 @@ def __init__(
14941494
"""
14951495
if not isinstance(root_collection, AsyncCollection):
14961496
raise TypeError(
1497-
f"root_collection must be an instance of AsyncCollection, not {type(root_collection)}"
1497+
f"root_collection must be an instance of AsyncCollection[Any], not {type(root_collection)}"
14981498
)
14991499
_disallow_transactions(session)
15001500

@@ -1829,7 +1829,7 @@ class _AsyncGridOutChunkIterator:
18291829
def __init__(
18301830
self,
18311831
grid_out: AsyncGridOut,
1832-
chunks: AsyncCollection,
1832+
chunks: AsyncCollection[Any],
18331833
session: Optional[AsyncClientSession],
18341834
next_chunk: Any,
18351835
) -> None:
@@ -1842,7 +1842,7 @@ def __init__(
18421842
self._num_chunks = math.ceil(float(self._length) / self._chunk_size)
18431843
self._cursor = None
18441844

1845-
_cursor: Optional[AsyncCursor]
1845+
_cursor: Optional[AsyncCursor[Any]]
18461846

18471847
def expected_chunk_length(self, chunk_n: int) -> int:
18481848
if chunk_n < self._num_chunks - 1:
@@ -1921,7 +1921,7 @@ async def close(self) -> None:
19211921

19221922
class AsyncGridOutIterator:
19231923
def __init__(
1924-
self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: AsyncClientSession
1924+
self, grid_out: AsyncGridOut, chunks: AsyncCollection[Any], session: AsyncClientSession
19251925
):
19261926
self._chunk_iter = _AsyncGridOutChunkIterator(grid_out, chunks, session, 0)
19271927

@@ -1935,14 +1935,14 @@ async def next(self) -> bytes:
19351935
__anext__ = next
19361936

19371937

1938-
class AsyncGridOutCursor(AsyncCursor):
1938+
class AsyncGridOutCursor(AsyncCursor): # type: ignore[type-arg]
19391939
"""A cursor / iterator for returning GridOut objects as the result
19401940
of an arbitrary query against the GridFS files collection.
19411941
"""
19421942

19431943
def __init__(
19441944
self,
1945-
collection: AsyncCollection,
1945+
collection: AsyncCollection[Any],
19461946
filter: Optional[Mapping[str, Any]] = None,
19471947
skip: int = 0,
19481948
limit: int = 0,

gridfs/synchronous/grid_file.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _disallow_transactions(session: Optional[ClientSession]) -> None:
7070
class GridFS:
7171
"""An instance of GridFS on top of a single Database."""
7272

73-
def __init__(self, database: Database, collection: str = "fs"):
73+
def __init__(self, database: Database[Any], collection: str = "fs"):
7474
"""Create a new instance of :class:`GridFS`.
7575
7676
Raises :class:`TypeError` if `database` is not an instance of
@@ -461,7 +461,7 @@ class GridFSBucket:
461461

462462
def __init__(
463463
self,
464-
db: Database,
464+
db: Database[Any],
465465
bucket_name: str = "fs",
466466
chunk_size_bytes: int = DEFAULT_CHUNK_SIZE,
467467
write_concern: Optional[WriteConcern] = None,
@@ -511,11 +511,11 @@ def __init__(
511511

512512
self._bucket_name = bucket_name
513513
self._collection = db[bucket_name]
514-
self._chunks: Collection = self._collection.chunks.with_options(
514+
self._chunks: Collection[Any] = self._collection.chunks.with_options(
515515
write_concern=write_concern, read_preference=read_preference
516516
)
517517

518-
self._files: Collection = self._collection.files.with_options(
518+
self._files: Collection[Any] = self._collection.files.with_options(
519519
write_concern=write_concern, read_preference=read_preference
520520
)
521521

@@ -1077,7 +1077,7 @@ class GridIn:
10771077

10781078
def __init__(
10791079
self,
1080-
root_collection: Collection,
1080+
root_collection: Collection[Any],
10811081
session: Optional[ClientSession] = None,
10821082
**kwargs: Any,
10831083
) -> None:
@@ -1133,7 +1133,7 @@ def __init__(
11331133
"""
11341134
if not isinstance(root_collection, Collection):
11351135
raise TypeError(
1136-
f"root_collection must be an instance of Collection, not {type(root_collection)}"
1136+
f"root_collection must be an instance of Collection[Any], not {type(root_collection)}"
11371137
)
11381138

11391139
if not root_collection.write_concern.acknowledged:
@@ -1163,7 +1163,7 @@ def __init__(
11631163
object.__setattr__(self, "_buffered_docs", [])
11641164
object.__setattr__(self, "_buffered_docs_size", 0)
11651165

1166-
def _create_index(self, collection: Collection, index_key: Any, unique: bool) -> None:
1166+
def _create_index(self, collection: Collection[Any], index_key: Any, unique: bool) -> None:
11671167
doc = collection.find_one(projection={"_id": 1}, session=self._session)
11681168
if doc is None:
11691169
try:
@@ -1444,7 +1444,7 @@ class GridOut(GRIDOUT_BASE_CLASS): # type: ignore
14441444

14451445
def __init__(
14461446
self,
1447-
root_collection: Collection,
1447+
root_collection: Collection[Any],
14481448
file_id: Optional[int] = None,
14491449
file_document: Optional[Any] = None,
14501450
session: Optional[ClientSession] = None,
@@ -1482,7 +1482,7 @@ def __init__(
14821482
"""
14831483
if not isinstance(root_collection, Collection):
14841484
raise TypeError(
1485-
f"root_collection must be an instance of Collection, not {type(root_collection)}"
1485+
f"root_collection must be an instance of Collection[Any], not {type(root_collection)}"
14861486
)
14871487
_disallow_transactions(session)
14881488

@@ -1817,7 +1817,7 @@ class GridOutChunkIterator:
18171817
def __init__(
18181818
self,
18191819
grid_out: GridOut,
1820-
chunks: Collection,
1820+
chunks: Collection[Any],
18211821
session: Optional[ClientSession],
18221822
next_chunk: Any,
18231823
) -> None:
@@ -1830,7 +1830,7 @@ def __init__(
18301830
self._num_chunks = math.ceil(float(self._length) / self._chunk_size)
18311831
self._cursor = None
18321832

1833-
_cursor: Optional[Cursor]
1833+
_cursor: Optional[Cursor[Any]]
18341834

18351835
def expected_chunk_length(self, chunk_n: int) -> int:
18361836
if chunk_n < self._num_chunks - 1:
@@ -1908,7 +1908,7 @@ def close(self) -> None:
19081908

19091909

19101910
class GridOutIterator:
1911-
def __init__(self, grid_out: GridOut, chunks: Collection, session: ClientSession):
1911+
def __init__(self, grid_out: GridOut, chunks: Collection[Any], session: ClientSession):
19121912
self._chunk_iter = GridOutChunkIterator(grid_out, chunks, session, 0)
19131913

19141914
def __iter__(self) -> GridOutIterator:
@@ -1921,14 +1921,14 @@ def next(self) -> bytes:
19211921
__next__ = next
19221922

19231923

1924-
class GridOutCursor(Cursor):
1924+
class GridOutCursor(Cursor): # type: ignore[type-arg]
19251925
"""A cursor / iterator for returning GridOut objects as the result
19261926
of an arbitrary query against the GridFS files collection.
19271927
"""
19281928

19291929
def __init__(
19301930
self,
1931-
collection: Collection,
1931+
collection: Collection[Any],
19321932
filter: Optional[Mapping[str, Any]] = None,
19331933
skip: int = 0,
19341934
limit: int = 0,

pymongo/asynchronous/auth_oidc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ async def _sasl_continue_jwt(
259259
) -> Mapping[str, Any]:
260260
self.access_token = None
261261
self.refresh_token = None
262-
start_payload: dict = bson.decode(start_resp["payload"])
262+
start_payload: dict[str, Any] = bson.decode(start_resp["payload"])
263263
if "issuer" in start_payload:
264264
self.idp_info = OIDCIdPInfo(**start_payload)
265265
access_token = await self._get_access_token()

pymongo/asynchronous/bulk.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ async def write_command(
248248
request_id: int,
249249
msg: bytes,
250250
docs: list[Mapping[str, Any]],
251-
client: AsyncMongoClient,
251+
client: AsyncMongoClient[Any],
252252
) -> dict[str, Any]:
253253
"""A proxy for SocketInfo.write_command that handles event publishing."""
254254
cmd[bwc.field] = docs
@@ -334,7 +334,7 @@ async def unack_write(
334334
msg: bytes,
335335
max_doc_size: int,
336336
docs: list[Mapping[str, Any]],
337-
client: AsyncMongoClient,
337+
client: AsyncMongoClient[Any],
338338
) -> Optional[Mapping[str, Any]]:
339339
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
340340
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):
@@ -419,7 +419,7 @@ async def _execute_batch_unack(
419419
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
420420
cmd: dict[str, Any],
421421
ops: list[Mapping[str, Any]],
422-
client: AsyncMongoClient,
422+
client: AsyncMongoClient[Any],
423423
) -> list[Mapping[str, Any]]:
424424
if self.is_encrypted:
425425
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)
@@ -446,7 +446,7 @@ async def _execute_batch(
446446
bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext],
447447
cmd: dict[str, Any],
448448
ops: list[Mapping[str, Any]],
449-
client: AsyncMongoClient,
449+
client: AsyncMongoClient[Any],
450450
) -> tuple[dict[str, Any], list[Mapping[str, Any]]]:
451451
if self.is_encrypted:
452452
_, batched_cmd, to_send = bwc.batch_command(cmd, ops)

pymongo/asynchronous/change_stream.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def _aggregation_command_class(self) -> Type[_AggregationCommand]:
164164
raise NotImplementedError
165165

166166
@property
167-
def _client(self) -> AsyncMongoClient:
167+
def _client(self) -> AsyncMongoClient: # type: ignore[type-arg]
168168
"""The client against which the aggregation commands for
169169
this AsyncChangeStream will be run.
170170
"""
@@ -206,7 +206,7 @@ def _command_options(self) -> dict[str, Any]:
206206
def _aggregation_pipeline(self) -> list[dict[str, Any]]:
207207
"""Return the full aggregation pipeline for this AsyncChangeStream."""
208208
options = self._change_stream_options()
209-
full_pipeline: list = [{"$changeStream": options}]
209+
full_pipeline: list[dict[str, Any]] = [{"$changeStream": options}]
210210
full_pipeline.extend(self._pipeline)
211211
return full_pipeline
212212

@@ -237,7 +237,7 @@ def _process_result(self, result: Mapping[str, Any], conn: AsyncConnection) -> N
237237

238238
async def _run_aggregation_cmd(
239239
self, session: Optional[AsyncClientSession], explicit_session: bool
240-
) -> AsyncCommandCursor:
240+
) -> AsyncCommandCursor: # type: ignore[type-arg]
241241
"""Run the full aggregation pipeline for this AsyncChangeStream and return
242242
the corresponding AsyncCommandCursor.
243243
"""
@@ -257,7 +257,7 @@ async def _run_aggregation_cmd(
257257
operation=_Op.AGGREGATE,
258258
)
259259

260-
async def _create_cursor(self) -> AsyncCommandCursor:
260+
async def _create_cursor(self) -> AsyncCommandCursor: # type: ignore[type-arg]
261261
async with self._client._tmp_session(self._session, close=False) as s:
262262
return await self._run_aggregation_cmd(
263263
session=s, explicit_session=self._session is not None

pymongo/asynchronous/client_bulk.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class _AsyncClientBulk:
8888

8989
def __init__(
9090
self,
91-
client: AsyncMongoClient,
91+
client: AsyncMongoClient[Any],
9292
write_concern: WriteConcern,
9393
ordered: bool = True,
9494
bypass_document_validation: Optional[bool] = None,
@@ -233,7 +233,7 @@ async def write_command(
233233
msg: Union[bytes, dict[str, Any]],
234234
op_docs: list[Mapping[str, Any]],
235235
ns_docs: list[Mapping[str, Any]],
236-
client: AsyncMongoClient,
236+
client: AsyncMongoClient[Any],
237237
) -> dict[str, Any]:
238238
"""A proxy for AsyncConnection.write_command that handles event publishing."""
239239
cmd["ops"] = op_docs
@@ -324,7 +324,7 @@ async def unack_write(
324324
msg: bytes,
325325
op_docs: list[Mapping[str, Any]],
326326
ns_docs: list[Mapping[str, Any]],
327-
client: AsyncMongoClient,
327+
client: AsyncMongoClient[Any],
328328
) -> Optional[Mapping[str, Any]]:
329329
"""A proxy for AsyncConnection.unack_write that handles event publishing."""
330330
if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG):

0 commit comments

Comments
 (0)