diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index e5a54c0904..6d8dfaf89a 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -228,6 +228,10 @@ def __init__( ) if not isinstance(name, str): raise TypeError("name must be an instance of str") + from pymongo.asynchronous.database import AsyncDatabase + + if not isinstance(database, AsyncDatabase): + raise TypeError(f"AsyncCollection requires an AsyncDatabase but {type(database)} given") if not name or ".." in name: raise InvalidName("collection names cannot be empty") diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index b61d581839..d5eec0134d 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -119,9 +119,14 @@ def __init__( read_concern or client.read_concern, ) + from pymongo.asynchronous.mongo_client import AsyncMongoClient + if not isinstance(name, str): raise TypeError("name must be an instance of str") + if not isinstance(client, AsyncMongoClient): + raise TypeError(f"AsyncMongoClient required but given {type(client)}") + if name != "$external": _check_name(name) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 3fb00c6ca9..c4cb886df7 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -194,9 +194,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: # Wrap I/O errors in PyMongo exceptions. _raise_connection_failure((host, port), error) - async def collection_info( - self, database: AsyncDatabase[Mapping[str, Any]], filter: bytes - ) -> Optional[bytes]: + async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]: """Get the collection info for a namespace. The returned collection info is passed to libmongocrypt which reads @@ -598,6 +596,9 @@ def __init__( if not isinstance(codec_options, CodecOptions): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") + if not isinstance(key_vault_client, AsyncMongoClient): + raise TypeError(f"AsyncMongoClient required but given {type(key_vault_client)}") + self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace self._key_vault_client = key_vault_client @@ -683,6 +684,11 @@ async def create_encrypted_collection( https://mongodb.com/docs/manual/reference/command/create """ + if not isinstance(database, AsyncDatabase): + raise TypeError( + f"create_encrypted_collection() requires an AsyncDatabase but {type(database)} given" + ) + encrypted_fields = deepcopy(encrypted_fields) for i, field in enumerate(encrypted_fields["fields"]): if isinstance(field, dict) and field.get("keyId") is None: diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 05e4e80f1d..2af773c440 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2419,6 +2419,9 @@ class _MongoClientErrorHandler: def __init__( self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] ): + if not isinstance(client, AsyncMongoClient): + raise TypeError(f"AsyncMongoClient required but given {type(client)}") + self.client = client self.server_address = server.description.address self.session = session diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 54db3a56b3..93e24432e5 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -231,6 +231,10 @@ def __init__( ) if not isinstance(name, str): raise TypeError("name must be an instance of str") + from pymongo.synchronous.database import Database + + if not isinstance(database, Database): + raise TypeError(f"Collection requires a Database but {type(database)} given") if not name or ".." in name: raise InvalidName("collection names cannot be empty") diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 93a9985281..1cd8ee643b 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -119,9 +119,14 @@ def __init__( read_concern or client.read_concern, ) + from pymongo.synchronous.mongo_client import MongoClient + if not isinstance(name, str): raise TypeError("name must be an instance of str") + if not isinstance(client, MongoClient): + raise TypeError(f"MongoClient required but given {type(client)}") + if name != "$external": _check_name(name) diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index e06ddad93d..2efa995978 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -194,9 +194,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: # Wrap I/O errors in PyMongo exceptions. _raise_connection_failure((host, port), error) - def collection_info( - self, database: Database[Mapping[str, Any]], filter: bytes - ) -> Optional[bytes]: + def collection_info(self, database: str, filter: bytes) -> Optional[bytes]: """Get the collection info for a namespace. The returned collection info is passed to libmongocrypt which reads @@ -596,6 +594,9 @@ def __init__( if not isinstance(codec_options, CodecOptions): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") + if not isinstance(key_vault_client, MongoClient): + raise TypeError(f"MongoClient required but given {type(key_vault_client)}") + self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace self._key_vault_client = key_vault_client @@ -681,6 +682,11 @@ def create_encrypted_collection( https://mongodb.com/docs/manual/reference/command/create """ + if not isinstance(database, Database): + raise TypeError( + f"create_encrypted_collection() requires a Database but {type(database)} given" + ) + encrypted_fields = deepcopy(encrypted_fields) for i, field in enumerate(encrypted_fields["fields"]): if isinstance(field, dict) and field.get("keyId") is None: diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 77e029a7c9..6c5f68b7eb 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2406,6 +2406,9 @@ class _MongoClientErrorHandler: ) def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): + if not isinstance(client, MongoClient): + raise TypeError(f"MongoClient required but given {type(client)}") + self.client = client self.server_address = server.description.address self.session = session diff --git a/test/helpers.py b/test/helpers.py index d136e5b8d2..b38b2e2980 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -35,23 +35,13 @@ HAVE_IPADDRESS = True except ImportError: HAVE_IPADDRESS = False -from contextlib import contextmanager from functools import wraps -from test.version import Version from typing import Any, Callable, Dict, Generator, no_type_check from unittest import SkipTest -from urllib.parse import quote_plus -import pymongo -import pymongo.errors from bson.son import SON from pymongo import common, message -from pymongo.common import partition_node -from pymongo.hello import HelloCompat -from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.synchronous.database import Database -from pymongo.synchronous.mongo_client import MongoClient from pymongo.uri_parser import parse_uri if HAVE_SSL: