diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index a0b727dc7a..1ec74aad02 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -231,7 +231,8 @@ def __init__( from pymongo.asynchronous.database import AsyncDatabase if not isinstance(database, AsyncDatabase): - if not any(cls.__name__ == "AsyncDatabase" for cls in database.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__): raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") if not name or ".." in name: diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index fb042972be..06c0eca2c1 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -125,7 +125,8 @@ def __init__( raise TypeError("name must be an instance of str") if not isinstance(client, AsyncMongoClient): - if not any(cls.__name__ == "AsyncMongoClient" for cls in client.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__): raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") if name != "$external": diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index b03af1b8a1..9b00c13e10 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -597,7 +597,10 @@ def __init__( raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, AsyncMongoClient): - if not any(cls.__name__ == "AsyncMongoClient" for cls in key_vault_client.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any( + cls.__name__ == "AsyncMongoClient" for cls in type(key_vault_client).__mro__ + ): raise TypeError( f"AsyncMongoClient required but given {type(key_vault_client).__name__}" ) @@ -688,7 +691,8 @@ async def create_encrypted_collection( """ if not isinstance(database, AsyncDatabase): - if not any(cls.__name__ == "AsyncDatabase" for cls in database.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__): raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 20cc65d9d7..9dba97d12a 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2446,7 +2446,8 @@ def __init__( self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] ): if not isinstance(client, AsyncMongoClient): - if not any(cls.__name__ == "AsyncMongoClient" for cls in client.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__): raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") self.client = client diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index ff02c65af5..7a41aef31f 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -234,7 +234,8 @@ def __init__( from pymongo.synchronous.database import Database if not isinstance(database, Database): - if not any(cls.__name__ == "Database" for cls in database.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "Database" for cls in type(database).__mro__): raise TypeError(f"Database required but given {type(database).__name__}") if not name or ".." in name: diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 5f499fff61..c57a59e09a 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -125,7 +125,8 @@ def __init__( raise TypeError("name must be an instance of str") if not isinstance(client, MongoClient): - if not any(cls.__name__ == "MongoClient" for cls in client.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): raise TypeError(f"MongoClient required but given {type(client).__name__}") if name != "$external": diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 8c6411feb9..efef6df9e8 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -595,7 +595,8 @@ def __init__( raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, MongoClient): - if not any(cls.__name__ == "MongoClient" for cls in key_vault_client.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "MongoClient" for cls in type(key_vault_client).__mro__): raise TypeError(f"MongoClient required but given {type(key_vault_client).__name__}") self._kms_providers = kms_providers @@ -684,7 +685,8 @@ def create_encrypted_collection( """ if not isinstance(database, Database): - if not any(cls.__name__ == "Database" for cls in database.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "Database" for cls in type(database).__mro__): raise TypeError(f"Database required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index ac697405d1..21fa57b5d8 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2434,7 +2434,8 @@ class _MongoClientErrorHandler: def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): if not isinstance(client, MongoClient): - if not any(cls.__name__ == "MongoClient" for cls in client.__mro__): + # This is for compatibility with mocked and subclassed types, such as in Motor. + if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): raise TypeError(f"MongoClient required but given {type(client).__name__}") self.client = client