Skip to content

Commit 98eee8f

Browse files
committed
PYTHON-5257 - Turn on mypy disallow_any_generics
1 parent 740f38f commit 98eee8f

36 files changed

+211
-190
lines changed

bson/json_util.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ def _encode_binary(data: bytes, subtype: int, json_options: JSONOptions) -> Any:
844844
return {"$binary": {"base64": base64.b64encode(data).decode(), "subType": "%02x" % subtype}}
845845

846846

847-
def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict:
847+
def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
848848
if (
849849
json_options.datetime_representation == DatetimeRepresentation.ISO8601
850850
and 0 <= int(obj) <= _MAX_UTC_MS
@@ -855,7 +855,7 @@ def _encode_datetimems(obj: Any, json_options: JSONOptions) -> dict:
855855
return {"$date": {"$numberLong": str(int(obj))}}
856856

857857

858-
def _encode_code(obj: Code, json_options: JSONOptions) -> dict:
858+
def _encode_code(obj: Code, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
859859
if obj.scope is None:
860860
return {"$code": str(obj)}
861861
else:
@@ -873,7 +873,7 @@ def _encode_noop(obj: Any, dummy0: Any) -> Any:
873873
return obj
874874

875875

876-
def _encode_regex(obj: Any, json_options: JSONOptions) -> dict:
876+
def _encode_regex(obj: Any, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
877877
flags = ""
878878
if obj.flags & re.IGNORECASE:
879879
flags += "i"
@@ -918,7 +918,7 @@ def _encode_float(obj: float, json_options: JSONOptions) -> Any:
918918
return obj
919919

920920

921-
def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict:
921+
def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
922922
if json_options.datetime_representation == DatetimeRepresentation.ISO8601:
923923
if not obj.tzinfo:
924924
obj = obj.replace(tzinfo=utc)
@@ -941,51 +941,51 @@ def _encode_datetime(obj: datetime.datetime, json_options: JSONOptions) -> dict:
941941
return {"$date": {"$numberLong": str(millis)}}
942942

943943

944-
def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict:
944+
def _encode_bytes(obj: bytes, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
945945
return _encode_binary(obj, 0, json_options)
946946

947947

948-
def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict:
948+
def _encode_binary_obj(obj: Binary, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
949949
return _encode_binary(obj, obj.subtype, json_options)
950950

951951

952-
def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict:
952+
def _encode_uuid(obj: uuid.UUID, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
953953
if json_options.strict_uuid:
954954
binval = Binary.from_uuid(obj, uuid_representation=json_options.uuid_representation)
955955
return _encode_binary(binval, binval.subtype, json_options)
956956
else:
957957
return {"$uuid": obj.hex}
958958

959959

960-
def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict:
960+
def _encode_objectid(obj: ObjectId, dummy0: Any) -> dict: # type: ignore[type-arg]
961961
return {"$oid": str(obj)}
962962

963963

964-
def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict:
964+
def _encode_timestamp(obj: Timestamp, dummy0: Any) -> dict: # type: ignore[type-arg]
965965
return {"$timestamp": {"t": obj.time, "i": obj.inc}}
966966

967967

968-
def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict:
968+
def _encode_decimal128(obj: Timestamp, dummy0: Any) -> dict: # type: ignore[type-arg]
969969
return {"$numberDecimal": str(obj)}
970970

971971

972-
def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict:
972+
def _encode_dbref(obj: DBRef, json_options: JSONOptions) -> dict: # type: ignore[type-arg]
973973
return _json_convert(obj.as_doc(), json_options=json_options)
974974

975975

976-
def _encode_minkey(dummy0: Any, dummy1: Any) -> dict:
976+
def _encode_minkey(dummy0: Any, dummy1: Any) -> dict: # type: ignore[type-arg]
977977
return {"$minKey": 1}
978978

979979

980-
def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict:
980+
def _encode_maxkey(dummy0: Any, dummy1: Any) -> dict: # type: ignore[type-arg]
981981
return {"$maxKey": 1}
982982

983983

984984
# Encoders for BSON types
985985
# Each encoder function's signature is:
986986
# - obj: a Python data type, e.g. a Python int for _encode_int
987987
# - json_options: a JSONOptions
988-
_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = {
988+
_ENCODERS: dict[Type, Callable[[Any, JSONOptions], Any]] = { # type: ignore[type-arg]
989989
bool: _encode_noop,
990990
bytes: _encode_bytes,
991991
datetime.datetime: _encode_datetime,
@@ -1056,7 +1056,7 @@ def _get_datetime_size(obj: datetime.datetime) -> int:
10561056
return 5 + len(str(obj.time()))
10571057

10581058

1059-
def _get_regex_size(obj: Regex) -> int:
1059+
def _get_regex_size(obj: Regex) -> int: # type: ignore[type-arg]
10601060
return 18 + len(obj.pattern)
10611061

10621062

bson/typings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,4 @@
2828
_DocumentOut = Union[MutableMapping[str, Any], "RawBSONDocument"]
2929
_DocumentType = TypeVar("_DocumentType", bound=Mapping[str, Any])
3030
_DocumentTypeArg = TypeVar("_DocumentTypeArg", bound=Mapping[str, Any])
31-
_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"]
31+
_ReadableBuffer = Union[bytes, memoryview, "mmap", "array"] # type: ignore[type-arg]

pymongo/_asyncio_lock.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class Lock(_ContextManagerMixin, _LoopBoundMixin):
9393
"""
9494

9595
def __init__(self) -> None:
96-
self._waiters: Optional[collections.deque] = None
96+
self._waiters: Optional[collections.deque[Any]] = None
9797
self._locked = False
9898

9999
def __repr__(self) -> str:
@@ -196,7 +196,7 @@ def __init__(self, lock: Optional[Lock] = None) -> None:
196196
self.acquire = lock.acquire
197197
self.release = lock.release
198198

199-
self._waiters: collections.deque = collections.deque()
199+
self._waiters: collections.deque[Any] = collections.deque()
200200

201201
def __repr__(self) -> str:
202202
res = super().__repr__()
@@ -260,7 +260,7 @@ async def wait(self) -> bool:
260260
self._notify(1)
261261
raise
262262

263-
async def wait_for(self, predicate: Any) -> Coroutine:
263+
async def wait_for(self, predicate: Any) -> Coroutine[Any, Any, Any]:
264264
"""Wait until a predicate becomes true.
265265
266266
The predicate should be a callable whose result will be

pymongo/_asyncio_task.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered
27-
class _Task(asyncio.Task):
27+
class _Task(asyncio.Task[Any]):
2828
def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None:
2929
super().__init__(coro, name=name)
3030
self._cancel_requests = 0
@@ -43,7 +43,7 @@ def cancelling(self) -> int:
4343
return self._cancel_requests
4444

4545

46-
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task:
46+
def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task[Any]:
4747
if sys.version_info >= (3, 11):
4848
return asyncio.create_task(coro, name=name)
4949
return _Task(coro, name=name)

pymongo/_csot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def clamp_remaining(max_timeout: float) -> float:
6868
return min(timeout, max_timeout)
6969

7070

71-
class _TimeoutContext(AbstractContextManager):
71+
class _TimeoutContext(AbstractContextManager[Any]):
7272
"""Internal timeout context manager.
7373
7474
Use :func:`pymongo.timeout` instead::

pymongo/asynchronous/aggregation.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ class _AggregationCommand:
4646

4747
def __init__(
4848
self,
49-
target: Union[AsyncDatabase, AsyncCollection],
50-
cursor_class: type[AsyncCommandCursor],
49+
target: Union[AsyncDatabase[Any], AsyncCollection[Any]],
50+
cursor_class: type[AsyncCommandCursor[Any]],
5151
pipeline: _Pipeline,
5252
options: MutableMapping[str, Any],
5353
explicit_session: bool,
@@ -111,12 +111,12 @@ def _cursor_namespace(self) -> str:
111111
"""The namespace in which the aggregate command is run."""
112112
raise NotImplementedError
113113

114-
def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> AsyncCollection:
114+
def _cursor_collection(self, cursor_doc: Mapping[str, Any]) -> AsyncCollection[Any]:
115115
"""The AsyncCollection used for the aggregate command cursor."""
116116
raise NotImplementedError
117117

118118
@property
119-
def _database(self) -> AsyncDatabase:
119+
def _database(self) -> AsyncDatabase[Any]:
120120
"""The database against which the aggregation command is run."""
121121
raise NotImplementedError
122122

@@ -205,7 +205,7 @@ async def get_cursor(
205205

206206

207207
class _CollectionAggregationCommand(_AggregationCommand):
208-
_target: AsyncCollection
208+
_target: AsyncCollection[Any]
209209

210210
@property
211211
def _aggregation_target(self) -> str:
@@ -215,12 +215,12 @@ def _aggregation_target(self) -> str:
215215
def _cursor_namespace(self) -> str:
216216
return self._target.full_name
217217

218-
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
218+
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection[Any]:
219219
"""The AsyncCollection used for the aggregate command cursor."""
220220
return self._target
221221

222222
@property
223-
def _database(self) -> AsyncDatabase:
223+
def _database(self) -> AsyncDatabase[Any]:
224224
return self._target.database
225225

226226

@@ -234,7 +234,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
234234

235235

236236
class _DatabaseAggregationCommand(_AggregationCommand):
237-
_target: AsyncDatabase
237+
_target: AsyncDatabase[Any]
238238

239239
@property
240240
def _aggregation_target(self) -> int:
@@ -245,10 +245,10 @@ def _cursor_namespace(self) -> str:
245245
return f"{self._target.name}.$cmd.aggregate"
246246

247247
@property
248-
def _database(self) -> AsyncDatabase:
248+
def _database(self) -> AsyncDatabase[Any]:
249249
return self._target
250250

251-
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection:
251+
def _cursor_collection(self, cursor: Mapping[str, Any]) -> AsyncCollection[Any]:
252252
"""The AsyncCollection used for the aggregate command cursor."""
253253
# AsyncCollection level aggregate may not always return the "ns" field
254254
# according to our MockupDB tests. Let's handle that case for db level

pymongo/asynchronous/collection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ async def _command(
581581
conn: AsyncConnection,
582582
command: MutableMapping[str, Any],
583583
read_preference: Optional[_ServerMode] = None,
584-
codec_options: Optional[CodecOptions[dict[str, Any]]] = None,
584+
codec_options: Optional[CodecOptions[Mapping[str, Any]]] = None,
585585
check: bool = True,
586586
allowable_errors: Optional[Sequence[Union[str, int]]] = None,
587587
read_concern: Optional[ReadConcern] = None,
@@ -2525,7 +2525,7 @@ async def _list_indexes(
25252525
session: Optional[AsyncClientSession] = None,
25262526
comment: Optional[Any] = None,
25272527
) -> AsyncCommandCursor[MutableMapping[str, Any]]:
2528-
codec_options: CodecOptions[dict[str, Any]] = CodecOptions(SON)
2528+
codec_options: CodecOptions[Mapping[str, Any]] = CodecOptions(SON)
25292529
coll = cast(
25302530
AsyncCollection[MutableMapping[str, Any]],
25312531
self.with_options(codec_options=codec_options, read_preference=ReadPreference.PRIMARY),

pymongo/asynchronous/cursor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -969,13 +969,15 @@ def _deepcopy(self, x: Iterable, memo: Optional[dict[int, Union[list, dict]]] =
969969

970970
@overload
971971
def _deepcopy(
972-
self, x: SupportsItems, memo: Optional[dict[int, Union[list, dict]]] = None
972+
self,
973+
x: SupportsItems, # type: ignore[type-arg]
974+
memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg]
973975
) -> dict: # type: ignore[type-arg]
974976
...
975977

976978
def _deepcopy(
977979
self,
978-
x: Union[Iterable, SupportsItems],
980+
x: Union[Iterable, SupportsItems], # type: ignore[type-arg]
979981
memo: Optional[dict[int, Union[list, dict]]] = None, # type: ignore[type-arg]
980982
) -> Union[list[Any], dict[str, Any]]:
981983
"""Deepcopy helper for the data dictionary or list.

pymongo/asynchronous/database.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ async def _command(
771771
self._name,
772772
command,
773773
read_preference,
774-
codec_options,
774+
codec_options, # type: ignore[arg-type]
775775
check,
776776
allowable_errors,
777777
write_concern=write_concern,

pymongo/asynchronous/mongo_client.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -161,10 +161,10 @@
161161
_IS_SYNC = False
162162

163163
_WriteOp = Union[
164-
InsertOne,
164+
InsertOne, # type: ignore[type-arg]
165165
DeleteOne,
166166
DeleteMany,
167-
ReplaceOne,
167+
ReplaceOne, # type: ignore[type-arg]
168168
UpdateOne,
169169
UpdateMany,
170170
]
@@ -176,7 +176,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
176176
# Define order to retrieve options from ClientOptions for __repr__.
177177
# No host/port; these are retrieved from TopologySettings.
178178
_constructor_args = ("document_class", "tz_aware", "connect")
179-
_clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
179+
_clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary() # type: ignore[type-arg]
180180

181181
def __init__(
182182
self,
@@ -847,7 +847,7 @@ def __init__(
847847

848848
self._default_database_name = dbase
849849
self._lock = _async_create_lock()
850-
self._kill_cursors_queue: list = []
850+
self._kill_cursors_queue: list = [] # type: ignore[type-arg]
851851

852852
self._encrypter: Optional[_Encrypter] = None
853853

@@ -1064,7 +1064,7 @@ def _after_fork(self) -> None:
10641064
# Reset the session pool to avoid duplicate sessions in the child process.
10651065
self._topology._session_pool.reset()
10661066

1067-
def _duplicate(self, **kwargs: Any) -> AsyncMongoClient:
1067+
def _duplicate(self, **kwargs: Any) -> AsyncMongoClient: # type: ignore[type-arg]
10681068
args = self._init_kwargs.copy()
10691069
args.update(kwargs)
10701070
return AsyncMongoClient(**args)
@@ -1548,7 +1548,7 @@ def get_database(
15481548
self, name, codec_options, read_preference, write_concern, read_concern
15491549
)
15501550

1551-
def _database_default_options(self, name: str) -> database.AsyncDatabase:
1551+
def _database_default_options(self, name: str) -> database.AsyncDatabase: # type: ignore[type-arg]
15521552
"""Get a AsyncDatabase instance with the default settings."""
15531553
return self.get_database(
15541554
name,
@@ -1887,7 +1887,7 @@ async def _conn_for_reads(
18871887
async def _run_operation(
18881888
self,
18891889
operation: Union[_Query, _GetMore],
1890-
unpack_res: Callable,
1890+
unpack_res: Callable, # type: ignore[type-arg]
18911891
address: Optional[_Address] = None,
18921892
) -> Response:
18931893
"""Run a _Query/_GetMore operation and return a Response.
@@ -2261,7 +2261,7 @@ def _return_server_session(
22612261
@contextlib.asynccontextmanager
22622262
async def _tmp_session(
22632263
self, session: Optional[client_session.AsyncClientSession], close: bool = True
2264-
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None, None]:
2264+
) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None]:
22652265
"""If provided session is None, lend a temporary session."""
22662266
if session is not None:
22672267
if not isinstance(session, client_session.AsyncClientSession):
@@ -2308,8 +2308,8 @@ async def server_info(
23082308
.. versionchanged:: 3.6
23092309
Added ``session`` parameter.
23102310
"""
2311-
return cast(
2312-
dict,
2311+
return cast( # type: ignore[redundant-cast]
2312+
dict[str, Any],
23132313
await self.admin.command(
23142314
"buildinfo", read_preference=ReadPreference.PRIMARY, session=session
23152315
),
@@ -2438,7 +2438,7 @@ async def drop_database(
24382438
@_csot.apply
24392439
async def bulk_write(
24402440
self,
2441-
models: Sequence[_WriteOp[_DocumentType]],
2441+
models: Sequence[_WriteOp],
24422442
session: Optional[AsyncClientSession] = None,
24432443
ordered: bool = True,
24442444
verbose_results: bool = False,
@@ -2631,7 +2631,10 @@ class _MongoClientErrorHandler:
26312631
)
26322632

26332633
def __init__(
2634-
self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession]
2634+
self,
2635+
client: AsyncMongoClient, # type: ignore[type-arg]
2636+
server: Server,
2637+
session: Optional[AsyncClientSession],
26352638
):
26362639
if not isinstance(client, AsyncMongoClient):
26372640
# This is for compatibility with mocked and subclassed types, such as in Motor.
@@ -2704,7 +2707,7 @@ class _ClientConnectionRetryable(Generic[T]):
27042707

27052708
def __init__(
27062709
self,
2707-
mongo_client: AsyncMongoClient,
2710+
mongo_client: AsyncMongoClient, # type: ignore[type-arg]
27082711
func: _WriteCall[T] | _ReadCall[T],
27092712
bulk: Optional[Union[_AsyncBulk, _AsyncClientBulk]],
27102713
operation: str,

0 commit comments

Comments
 (0)