From cabcee5d3d48ec77890dd6ad9873d7b0894270c0 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 10 Sep 2024 15:14:07 -0400 Subject: [PATCH 1/5] PYTHON-4590 - Make type guards more compatible --- pymongo/asynchronous/collection.py | 4 ++-- pymongo/asynchronous/database.py | 4 ++-- pymongo/asynchronous/encryption.py | 15 +++++++++------ pymongo/asynchronous/mongo_client.py | 4 ++-- pymongo/synchronous/collection.py | 4 ++-- pymongo/synchronous/database.py | 4 ++-- pymongo/synchronous/encryption.py | 13 +++++++------ pymongo/synchronous/mongo_client.py | 4 ++-- test/unified_format.py | 6 +++--- 9 files changed, 31 insertions(+), 27 deletions(-) diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 6d8dfaf89a..5cf9d1c1e0 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -230,8 +230,8 @@ def __init__( raise TypeError("name must be an instance of str") from pymongo.asynchronous.database import AsyncDatabase - if not isinstance(database, AsyncDatabase): - raise TypeError(f"AsyncCollection requires an AsyncDatabase but {type(database)} given") + if not isinstance(database, AsyncDatabase) and type(database).__name__ != "AsyncDatabase": + raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") if not name or ".." in name: raise InvalidName("collection names cannot be empty") diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index d5eec0134d..cfd75cce12 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -124,8 +124,8 @@ def __init__( if not isinstance(name, str): raise TypeError("name must be an instance of str") - if not isinstance(client, AsyncMongoClient): - raise TypeError(f"AsyncMongoClient required but given {type(client)}") + if not isinstance(client, AsyncMongoClient) and type(client).__name__ != "AsyncMongoClient": + raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") if name != "$external": _check_name(name) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index c9e3cadd6e..2ea605fd1c 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -596,8 +596,13 @@ def __init__( if not isinstance(codec_options, CodecOptions): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") - if not isinstance(key_vault_client, AsyncMongoClient): - raise TypeError(f"AsyncMongoClient required but given {type(key_vault_client)}") + if ( + not isinstance(key_vault_client, AsyncMongoClient) + and type(key_vault_client).__name__ != "AsyncMongoClient" + ): + raise TypeError( + f"AsyncMongoClient required but given {type(key_vault_client).__name__}" + ) self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace @@ -684,10 +689,8 @@ async def create_encrypted_collection( https://mongodb.com/docs/manual/reference/command/create """ - if not isinstance(database, AsyncDatabase): - raise TypeError( - f"create_encrypted_collection() requires an AsyncDatabase but {type(database)} given" - ) + if not isinstance(database, AsyncDatabase) and type(database).__name__ != "AsyncDatabase": + raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) for i, field in enumerate(encrypted_fields["fields"]): diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index a84fbf2e59..f5f38a9e2b 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2445,8 +2445,8 @@ class _MongoClientErrorHandler: def __init__( self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] ): - if not isinstance(client, AsyncMongoClient): - raise TypeError(f"AsyncMongoClient required but given {type(client)}") + if not isinstance(client, AsyncMongoClient) and type(client).__name__ != "AsyncMongoClient": + raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") self.client = client self.server_address = server.description.address diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 93e24432e5..39831d8207 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -233,8 +233,8 @@ def __init__( raise TypeError("name must be an instance of str") from pymongo.synchronous.database import Database - if not isinstance(database, Database): - raise TypeError(f"Collection requires a Database but {type(database)} given") + if not isinstance(database, Database) and type(database).__name__ != "Database": + raise TypeError(f"Database required but given {type(database).__name__}") if not name or ".." in name: raise InvalidName("collection names cannot be empty") diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 1cd8ee643b..0b57974f69 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -124,8 +124,8 @@ def __init__( if not isinstance(name, str): raise TypeError("name must be an instance of str") - if not isinstance(client, MongoClient): - raise TypeError(f"MongoClient required but given {type(client)}") + if not isinstance(client, MongoClient) and type(client).__name__ != "MongoClient": + raise TypeError(f"MongoClient required but given {type(client).__name__}") if name != "$external": _check_name(name) diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 3849cf3f2b..c79a2900a6 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -594,8 +594,11 @@ def __init__( if not isinstance(codec_options, CodecOptions): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") - if not isinstance(key_vault_client, MongoClient): - raise TypeError(f"MongoClient required but given {type(key_vault_client)}") + if ( + not isinstance(key_vault_client, MongoClient) + and type(key_vault_client).__name__ != "MongoClient" + ): + raise TypeError(f"MongoClient required but given {type(key_vault_client).__name__}") self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace @@ -682,10 +685,8 @@ def create_encrypted_collection( https://mongodb.com/docs/manual/reference/command/create """ - if not isinstance(database, Database): - raise TypeError( - f"create_encrypted_collection() requires a Database but {type(database)} given" - ) + if not isinstance(database, Database) and type(database).__name__ != "Database": + raise TypeError(f"Database required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) for i, field in enumerate(encrypted_fields["fields"]): diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index cec78463b3..1c9e1eb3f4 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2433,8 +2433,8 @@ class _MongoClientErrorHandler: ) def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): - if not isinstance(client, MongoClient): - raise TypeError(f"MongoClient required but given {type(client)}") + if not isinstance(client, MongoClient) and type(client).__name__ != "MongoClient": + raise TypeError(f"MongoClient required but given {type(client).__name__}") self.client = client self.server_address = server.description.address diff --git a/test/unified_format.py b/test/unified_format.py index 63cd23af88..78fc638787 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -580,7 +580,7 @@ def _create_entity(self, entity_spec, uri=None): return elif entity_type == "database": client = self[spec["client"]] - if not isinstance(client, MongoClient): + if type(client).__name__ != "MongoClient": self.test.fail( "Expected entity {} to be of type MongoClient, got {}".format( spec["client"], type(client) @@ -602,7 +602,7 @@ def _create_entity(self, entity_spec, uri=None): return elif entity_type == "session": client = self[spec["client"]] - if not isinstance(client, MongoClient): + if type(client).__name__ != "MongoClient": self.test.fail( "Expected entity {} to be of type MongoClient, got {}".format( spec["client"], type(client) @@ -667,7 +667,7 @@ def create_entities_from_spec(self, entity_spec, uri=None): def get_listener_for_client(self, client_name: str) -> EventListenerUtil: client = self[client_name] - if not isinstance(client, MongoClient): + if type(client).__name__ != "MongoClient": self.test.fail( f"Expected entity {client_name} to be of type MongoClient, got {type(client)}" ) From 4550031b97684c30ebbf487e7bfe1e14b3100961 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 10 Sep 2024 16:10:37 -0400 Subject: [PATCH 2/5] Fixes --- pymongo/asynchronous/collection.py | 5 +++-- pymongo/asynchronous/database.py | 5 +++-- pymongo/asynchronous/encryption.py | 17 ++++++++--------- pymongo/asynchronous/mongo_client.py | 5 +++-- pymongo/synchronous/collection.py | 5 +++-- pymongo/synchronous/database.py | 5 +++-- pymongo/synchronous/encryption.py | 13 ++++++------- pymongo/synchronous/mongo_client.py | 5 +++-- 8 files changed, 32 insertions(+), 28 deletions(-) diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 5cf9d1c1e0..a0b727dc7a 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -230,8 +230,9 @@ def __init__( raise TypeError("name must be an instance of str") from pymongo.asynchronous.database import AsyncDatabase - if not isinstance(database, AsyncDatabase) and type(database).__name__ != "AsyncDatabase": - raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") + if not isinstance(database, AsyncDatabase): + if not any(cls.__name__ == "AsyncDatabase" for cls in database.__mro__): + raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") if not name or ".." in name: raise InvalidName("collection names cannot be empty") diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index cfd75cce12..fb042972be 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -124,8 +124,9 @@ def __init__( if not isinstance(name, str): raise TypeError("name must be an instance of str") - if not isinstance(client, AsyncMongoClient) and type(client).__name__ != "AsyncMongoClient": - raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") + if not isinstance(client, AsyncMongoClient): + if not any(cls.__name__ == "AsyncMongoClient" for cls in client.__mro__): + raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") if name != "$external": _check_name(name) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 2ea605fd1c..b03af1b8a1 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -596,13 +596,11 @@ def __init__( if not isinstance(codec_options, CodecOptions): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") - if ( - not isinstance(key_vault_client, AsyncMongoClient) - and type(key_vault_client).__name__ != "AsyncMongoClient" - ): - raise TypeError( - f"AsyncMongoClient required but given {type(key_vault_client).__name__}" - ) + if not isinstance(key_vault_client, AsyncMongoClient): + if not any(cls.__name__ == "AsyncMongoClient" for cls in key_vault_client.__mro__): + raise TypeError( + f"AsyncMongoClient required but given {type(key_vault_client).__name__}" + ) self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace @@ -689,8 +687,9 @@ async def create_encrypted_collection( https://mongodb.com/docs/manual/reference/command/create """ - if not isinstance(database, AsyncDatabase) and type(database).__name__ != "AsyncDatabase": - raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") + if not isinstance(database, AsyncDatabase): + if not any(cls.__name__ == "AsyncDatabase" for cls in database.__mro__): + raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) for i, field in enumerate(encrypted_fields["fields"]): diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index f5f38a9e2b..20cc65d9d7 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2445,8 +2445,9 @@ class _MongoClientErrorHandler: def __init__( self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] ): - if not isinstance(client, AsyncMongoClient) and type(client).__name__ != "AsyncMongoClient": - raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") + if not isinstance(client, AsyncMongoClient): + if not any(cls.__name__ == "AsyncMongoClient" for cls in client.__mro__): + raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") self.client = client self.server_address = server.description.address diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 39831d8207..ff02c65af5 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -233,8 +233,9 @@ def __init__( raise TypeError("name must be an instance of str") from pymongo.synchronous.database import Database - if not isinstance(database, Database) and type(database).__name__ != "Database": - raise TypeError(f"Database required but given {type(database).__name__}") + if not isinstance(database, Database): + if not any(cls.__name__ == "Database" for cls in database.__mro__): + raise TypeError(f"Database required but given {type(database).__name__}") if not name or ".." in name: raise InvalidName("collection names cannot be empty") diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 0b57974f69..5f499fff61 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -124,8 +124,9 @@ def __init__( if not isinstance(name, str): raise TypeError("name must be an instance of str") - if not isinstance(client, MongoClient) and type(client).__name__ != "MongoClient": - raise TypeError(f"MongoClient required but given {type(client).__name__}") + if not isinstance(client, MongoClient): + if not any(cls.__name__ == "MongoClient" for cls in client.__mro__): + raise TypeError(f"MongoClient required but given {type(client).__name__}") if name != "$external": _check_name(name) diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index c79a2900a6..8c6411feb9 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -594,11 +594,9 @@ def __init__( if not isinstance(codec_options, CodecOptions): raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") - if ( - not isinstance(key_vault_client, MongoClient) - and type(key_vault_client).__name__ != "MongoClient" - ): - raise TypeError(f"MongoClient required but given {type(key_vault_client).__name__}") + if not isinstance(key_vault_client, MongoClient): + if not any(cls.__name__ == "MongoClient" for cls in key_vault_client.__mro__): + raise TypeError(f"MongoClient required but given {type(key_vault_client).__name__}") self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace @@ -685,8 +683,9 @@ def create_encrypted_collection( https://mongodb.com/docs/manual/reference/command/create """ - if not isinstance(database, Database) and type(database).__name__ != "Database": - raise TypeError(f"Database required but given {type(database).__name__}") + if not isinstance(database, Database): + if not any(cls.__name__ == "Database" for cls in database.__mro__): + raise TypeError(f"Database required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) for i, field in enumerate(encrypted_fields["fields"]): diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 1c9e1eb3f4..ac697405d1 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2433,8 +2433,9 @@ class _MongoClientErrorHandler: ) def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): - if not isinstance(client, MongoClient) and type(client).__name__ != "MongoClient": - raise TypeError(f"MongoClient required but given {type(client).__name__}") + if not isinstance(client, MongoClient): + if not any(cls.__name__ == "MongoClient" for cls in client.__mro__): + raise TypeError(f"MongoClient required but given {type(client).__name__}") self.client = client self.server_address = server.description.address From ebb1a7f9250ba6836261d3ebf6999816d5830521 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 11 Sep 2024 08:58:24 -0400 Subject: [PATCH 3/5] Fix __mro__ --- pymongo/asynchronous/collection.py | 2 +- pymongo/asynchronous/database.py | 2 +- pymongo/asynchronous/encryption.py | 6 ++++-- pymongo/asynchronous/mongo_client.py | 2 +- pymongo/synchronous/collection.py | 2 +- pymongo/synchronous/database.py | 2 +- pymongo/synchronous/encryption.py | 4 ++-- pymongo/synchronous/mongo_client.py | 2 +- 8 files changed, 12 insertions(+), 10 deletions(-) diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index a0b727dc7a..6e798f38d7 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -231,7 +231,7 @@ def __init__( from pymongo.asynchronous.database import AsyncDatabase if not isinstance(database, AsyncDatabase): - if not any(cls.__name__ == "AsyncDatabase" for cls in database.__mro__): + if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__): raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") if not name or ".." in name: diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index fb042972be..36e598dc35 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -125,7 +125,7 @@ def __init__( raise TypeError("name must be an instance of str") if not isinstance(client, AsyncMongoClient): - if not any(cls.__name__ == "AsyncMongoClient" for cls in client.__mro__): + if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__): raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") if name != "$external": diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index b03af1b8a1..0a9a3a9dbd 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -597,7 +597,9 @@ def __init__( raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, AsyncMongoClient): - if not any(cls.__name__ == "AsyncMongoClient" for cls in key_vault_client.__mro__): + if not any( + cls.__name__ == "AsyncMongoClient" for cls in type(key_vault_client).__mro__ + ): raise TypeError( f"AsyncMongoClient required but given {type(key_vault_client).__name__}" ) @@ -688,7 +690,7 @@ async def create_encrypted_collection( """ if not isinstance(database, AsyncDatabase): - if not any(cls.__name__ == "AsyncDatabase" for cls in database.__mro__): + if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__): raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 20cc65d9d7..6850663a43 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2446,7 +2446,7 @@ def __init__( self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] ): if not isinstance(client, AsyncMongoClient): - if not any(cls.__name__ == "AsyncMongoClient" for cls in client.__mro__): + if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__): raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") self.client = client diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index ff02c65af5..e9f56b1176 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -234,7 +234,7 @@ def __init__( from pymongo.synchronous.database import Database if not isinstance(database, Database): - if not any(cls.__name__ == "Database" for cls in database.__mro__): + if not any(cls.__name__ == "Database" for cls in type(database).__mro__): raise TypeError(f"Database required but given {type(database).__name__}") if not name or ".." in name: diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 5f499fff61..5b5cc77c62 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -125,7 +125,7 @@ def __init__( raise TypeError("name must be an instance of str") if not isinstance(client, MongoClient): - if not any(cls.__name__ == "MongoClient" for cls in client.__mro__): + if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): raise TypeError(f"MongoClient required but given {type(client).__name__}") if name != "$external": diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 8c6411feb9..02d8775d89 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -595,7 +595,7 @@ def __init__( raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, MongoClient): - if not any(cls.__name__ == "MongoClient" for cls in key_vault_client.__mro__): + if not any(cls.__name__ == "MongoClient" for cls in type(key_vault_client).__mro__): raise TypeError(f"MongoClient required but given {type(key_vault_client).__name__}") self._kms_providers = kms_providers @@ -684,7 +684,7 @@ def create_encrypted_collection( """ if not isinstance(database, Database): - if not any(cls.__name__ == "Database" for cls in database.__mro__): + if not any(cls.__name__ == "Database" for cls in type(database).__mro__): raise TypeError(f"Database required but given {type(database).__name__}") encrypted_fields = deepcopy(encrypted_fields) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index ac697405d1..16e6d5821a 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2434,7 +2434,7 @@ class _MongoClientErrorHandler: def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): if not isinstance(client, MongoClient): - if not any(cls.__name__ == "MongoClient" for cls in client.__mro__): + if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): raise TypeError(f"MongoClient required but given {type(client).__name__}") self.client = client From f644f2f540516ca07d2cea4c08a3a8a7792dc710 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 11 Sep 2024 10:18:48 -0400 Subject: [PATCH 4/5] Add comment --- pymongo/asynchronous/collection.py | 1 + pymongo/asynchronous/database.py | 1 + pymongo/asynchronous/encryption.py | 2 ++ pymongo/asynchronous/mongo_client.py | 1 + pymongo/synchronous/collection.py | 1 + pymongo/synchronous/database.py | 1 + pymongo/synchronous/encryption.py | 2 ++ pymongo/synchronous/mongo_client.py | 1 + 8 files changed, 10 insertions(+) diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 6e798f38d7..627d877e1b 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -231,6 +231,7 @@ def __init__( from pymongo.asynchronous.database import AsyncDatabase if not isinstance(database, AsyncDatabase): + # This is for compatibility with mocked and subclassed types, such as in Motor if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__): raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index 36e598dc35..548f8da364 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -125,6 +125,7 @@ def __init__( raise TypeError("name must be an instance of str") if not isinstance(client, AsyncMongoClient): + # This is for compatibility with mocked and subclassed types, such as in Motor if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__): raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 0a9a3a9dbd..67c544b68b 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -597,6 +597,7 @@ def __init__( raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, AsyncMongoClient): + # This is for compatibility with mocked and subclassed types, such as in Motor if not any( cls.__name__ == "AsyncMongoClient" for cls in type(key_vault_client).__mro__ ): @@ -690,6 +691,7 @@ async def create_encrypted_collection( """ if not isinstance(database, AsyncDatabase): + # This is for compatibility with mocked and subclassed types, such as in Motor if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__): raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 6850663a43..a3f5c2675e 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2446,6 +2446,7 @@ def __init__( self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] ): if not isinstance(client, AsyncMongoClient): + # This is for compatibility with mocked and subclassed types, such as in Motor if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__): raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index e9f56b1176..1ea343b02c 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -234,6 +234,7 @@ def __init__( from pymongo.synchronous.database import Database if not isinstance(database, Database): + # This is for compatibility with mocked and subclassed types, such as in Motor if not any(cls.__name__ == "Database" for cls in type(database).__mro__): raise TypeError(f"Database required but given {type(database).__name__}") diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 5b5cc77c62..a79aadddcd 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -125,6 +125,7 @@ def __init__( raise TypeError("name must be an instance of str") if not isinstance(client, MongoClient): + # This is for compatibility with mocked and subclassed types, such as in Motor if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): raise TypeError(f"MongoClient required but given {type(client).__name__}") diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 02d8775d89..7789da0942 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -595,6 +595,7 @@ def __init__( raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, MongoClient): + # This is for compatibility with mocked and subclassed types, such as in Motor if not any(cls.__name__ == "MongoClient" for cls in type(key_vault_client).__mro__): raise TypeError(f"MongoClient required but given {type(key_vault_client).__name__}") @@ -684,6 +685,7 @@ def create_encrypted_collection( """ if not isinstance(database, Database): + # This is for compatibility with mocked and subclassed types, such as in Motor if not any(cls.__name__ == "Database" for cls in type(database).__mro__): raise TypeError(f"Database required but given {type(database).__name__}") diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 16e6d5821a..bb94c028fc 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2434,6 +2434,7 @@ class _MongoClientErrorHandler: def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): if not isinstance(client, MongoClient): + # This is for compatibility with mocked and subclassed types, such as in Motor if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): raise TypeError(f"MongoClient required but given {type(client).__name__}") From bb3fe5d367f7203b1f4418430c825e88fd1f4363 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 11 Sep 2024 10:33:36 -0400 Subject: [PATCH 5/5] Formatting --- pymongo/asynchronous/collection.py | 2 +- pymongo/asynchronous/database.py | 2 +- pymongo/asynchronous/encryption.py | 4 ++-- pymongo/asynchronous/mongo_client.py | 2 +- pymongo/synchronous/collection.py | 2 +- pymongo/synchronous/database.py | 2 +- pymongo/synchronous/encryption.py | 4 ++-- pymongo/synchronous/mongo_client.py | 2 +- 8 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index 627d877e1b..1ec74aad02 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -231,7 +231,7 @@ def __init__( from pymongo.asynchronous.database import AsyncDatabase if not isinstance(database, AsyncDatabase): - # This is for compatibility with mocked and subclassed types, such as in Motor + # This is for compatibility with mocked and subclassed types, such as in Motor. if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__): raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index 548f8da364..06c0eca2c1 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -125,7 +125,7 @@ def __init__( raise TypeError("name must be an instance of str") if not isinstance(client, AsyncMongoClient): - # This is for compatibility with mocked and subclassed types, such as in Motor + # This is for compatibility with mocked and subclassed types, such as in Motor. if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__): raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 67c544b68b..9b00c13e10 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -597,7 +597,7 @@ def __init__( raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, AsyncMongoClient): - # This is for compatibility with mocked and subclassed types, such as in Motor + # This is for compatibility with mocked and subclassed types, such as in Motor. if not any( cls.__name__ == "AsyncMongoClient" for cls in type(key_vault_client).__mro__ ): @@ -691,7 +691,7 @@ async def create_encrypted_collection( """ if not isinstance(database, AsyncDatabase): - # This is for compatibility with mocked and subclassed types, such as in Motor + # This is for compatibility with mocked and subclassed types, such as in Motor. if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__): raise TypeError(f"AsyncDatabase required but given {type(database).__name__}") diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index a3f5c2675e..9dba97d12a 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2446,7 +2446,7 @@ def __init__( self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] ): if not isinstance(client, AsyncMongoClient): - # This is for compatibility with mocked and subclassed types, such as in Motor + # This is for compatibility with mocked and subclassed types, such as in Motor. if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__): raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}") diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 1ea343b02c..7a41aef31f 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -234,7 +234,7 @@ def __init__( from pymongo.synchronous.database import Database if not isinstance(database, Database): - # This is for compatibility with mocked and subclassed types, such as in Motor + # This is for compatibility with mocked and subclassed types, such as in Motor. if not any(cls.__name__ == "Database" for cls in type(database).__mro__): raise TypeError(f"Database required but given {type(database).__name__}") diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index a79aadddcd..c57a59e09a 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -125,7 +125,7 @@ def __init__( raise TypeError("name must be an instance of str") if not isinstance(client, MongoClient): - # This is for compatibility with mocked and subclassed types, such as in Motor + # This is for compatibility with mocked and subclassed types, such as in Motor. if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): raise TypeError(f"MongoClient required but given {type(client).__name__}") diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 7789da0942..efef6df9e8 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -595,7 +595,7 @@ def __init__( raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") if not isinstance(key_vault_client, MongoClient): - # This is for compatibility with mocked and subclassed types, such as in Motor + # This is for compatibility with mocked and subclassed types, such as in Motor. if not any(cls.__name__ == "MongoClient" for cls in type(key_vault_client).__mro__): raise TypeError(f"MongoClient required but given {type(key_vault_client).__name__}") @@ -685,7 +685,7 @@ def create_encrypted_collection( """ if not isinstance(database, Database): - # This is for compatibility with mocked and subclassed types, such as in Motor + # This is for compatibility with mocked and subclassed types, such as in Motor. if not any(cls.__name__ == "Database" for cls in type(database).__mro__): raise TypeError(f"Database required but given {type(database).__name__}") diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index bb94c028fc..21fa57b5d8 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -2434,7 +2434,7 @@ class _MongoClientErrorHandler: def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]): if not isinstance(client, MongoClient): - # This is for compatibility with mocked and subclassed types, such as in Motor + # This is for compatibility with mocked and subclassed types, such as in Motor. if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__): raise TypeError(f"MongoClient required but given {type(client).__name__}")