Skip to content

Commit 9f83a75

Browse files
committed
Merge branch 'master' of github.com:mongodb/mongo-python-driver
2 parents 4b6887a + 63d957c commit 9f83a75

File tree

9 files changed

+37
-17
lines changed

9 files changed

+37
-17
lines changed

pymongo/asynchronous/collection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,9 @@ def __init__(
231231
from pymongo.asynchronous.database import AsyncDatabase
232232

233233
if not isinstance(database, AsyncDatabase):
234-
raise TypeError(f"AsyncCollection requires an AsyncDatabase but {type(database)} given")
234+
# This is for compatibility with mocked and subclassed types, such as in Motor.
235+
if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__):
236+
raise TypeError(f"AsyncDatabase required but given {type(database).__name__}")
235237

236238
if not name or ".." in name:
237239
raise InvalidName("collection names cannot be empty")

pymongo/asynchronous/database.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,9 @@ def __init__(
125125
raise TypeError("name must be an instance of str")
126126

127127
if not isinstance(client, AsyncMongoClient):
128-
raise TypeError(f"AsyncMongoClient required but given {type(client)}")
128+
# This is for compatibility with mocked and subclassed types, such as in Motor.
129+
if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__):
130+
raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}")
129131

130132
if name != "$external":
131133
_check_name(name)

pymongo/asynchronous/encryption.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,13 @@ def __init__(
597597
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
598598

599599
if not isinstance(key_vault_client, AsyncMongoClient):
600-
raise TypeError(f"AsyncMongoClient required but given {type(key_vault_client)}")
600+
# This is for compatibility with mocked and subclassed types, such as in Motor.
601+
if not any(
602+
cls.__name__ == "AsyncMongoClient" for cls in type(key_vault_client).__mro__
603+
):
604+
raise TypeError(
605+
f"AsyncMongoClient required but given {type(key_vault_client).__name__}"
606+
)
601607

602608
self._kms_providers = kms_providers
603609
self._key_vault_namespace = key_vault_namespace
@@ -685,9 +691,9 @@ async def create_encrypted_collection(
685691
686692
"""
687693
if not isinstance(database, AsyncDatabase):
688-
raise TypeError(
689-
f"create_encrypted_collection() requires an AsyncDatabase but {type(database)} given"
690-
)
694+
# This is for compatibility with mocked and subclassed types, such as in Motor.
695+
if not any(cls.__name__ == "AsyncDatabase" for cls in type(database).__mro__):
696+
raise TypeError(f"AsyncDatabase required but given {type(database).__name__}")
691697

692698
encrypted_fields = deepcopy(encrypted_fields)
693699
for i, field in enumerate(encrypted_fields["fields"]):

pymongo/asynchronous/mongo_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2446,7 +2446,9 @@ def __init__(
24462446
self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession]
24472447
):
24482448
if not isinstance(client, AsyncMongoClient):
2449-
raise TypeError(f"AsyncMongoClient required but given {type(client)}")
2449+
# This is for compatibility with mocked and subclassed types, such as in Motor.
2450+
if not any(cls.__name__ == "AsyncMongoClient" for cls in type(client).__mro__):
2451+
raise TypeError(f"AsyncMongoClient required but given {type(client).__name__}")
24502452

24512453
self.client = client
24522454
self.server_address = server.description.address

pymongo/synchronous/collection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,9 @@ def __init__(
234234
from pymongo.synchronous.database import Database
235235

236236
if not isinstance(database, Database):
237-
raise TypeError(f"Collection requires a Database but {type(database)} given")
237+
# This is for compatibility with mocked and subclassed types, such as in Motor.
238+
if not any(cls.__name__ == "Database" for cls in type(database).__mro__):
239+
raise TypeError(f"Database required but given {type(database).__name__}")
238240

239241
if not name or ".." in name:
240242
raise InvalidName("collection names cannot be empty")

pymongo/synchronous/database.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,9 @@ def __init__(
125125
raise TypeError("name must be an instance of str")
126126

127127
if not isinstance(client, MongoClient):
128-
raise TypeError(f"MongoClient required but given {type(client)}")
128+
# This is for compatibility with mocked and subclassed types, such as in Motor.
129+
if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__):
130+
raise TypeError(f"MongoClient required but given {type(client).__name__}")
129131

130132
if name != "$external":
131133
_check_name(name)

pymongo/synchronous/encryption.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -595,7 +595,9 @@ def __init__(
595595
raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions")
596596

597597
if not isinstance(key_vault_client, MongoClient):
598-
raise TypeError(f"MongoClient required but given {type(key_vault_client)}")
598+
# This is for compatibility with mocked and subclassed types, such as in Motor.
599+
if not any(cls.__name__ == "MongoClient" for cls in type(key_vault_client).__mro__):
600+
raise TypeError(f"MongoClient required but given {type(key_vault_client).__name__}")
599601

600602
self._kms_providers = kms_providers
601603
self._key_vault_namespace = key_vault_namespace
@@ -683,9 +685,9 @@ def create_encrypted_collection(
683685
684686
"""
685687
if not isinstance(database, Database):
686-
raise TypeError(
687-
f"create_encrypted_collection() requires a Database but {type(database)} given"
688-
)
688+
# This is for compatibility with mocked and subclassed types, such as in Motor.
689+
if not any(cls.__name__ == "Database" for cls in type(database).__mro__):
690+
raise TypeError(f"Database required but given {type(database).__name__}")
689691

690692
encrypted_fields = deepcopy(encrypted_fields)
691693
for i, field in enumerate(encrypted_fields["fields"]):

pymongo/synchronous/mongo_client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2434,7 +2434,9 @@ class _MongoClientErrorHandler:
24342434

24352435
def __init__(self, client: MongoClient, server: Server, session: Optional[ClientSession]):
24362436
if not isinstance(client, MongoClient):
2437-
raise TypeError(f"MongoClient required but given {type(client)}")
2437+
# This is for compatibility with mocked and subclassed types, such as in Motor.
2438+
if not any(cls.__name__ == "MongoClient" for cls in type(client).__mro__):
2439+
raise TypeError(f"MongoClient required but given {type(client).__name__}")
24382440

24392441
self.client = client
24402442
self.server_address = server.description.address

test/unified_format.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -580,7 +580,7 @@ def _create_entity(self, entity_spec, uri=None):
580580
return
581581
elif entity_type == "database":
582582
client = self[spec["client"]]
583-
if not isinstance(client, MongoClient):
583+
if type(client).__name__ != "MongoClient":
584584
self.test.fail(
585585
"Expected entity {} to be of type MongoClient, got {}".format(
586586
spec["client"], type(client)
@@ -602,7 +602,7 @@ def _create_entity(self, entity_spec, uri=None):
602602
return
603603
elif entity_type == "session":
604604
client = self[spec["client"]]
605-
if not isinstance(client, MongoClient):
605+
if type(client).__name__ != "MongoClient":
606606
self.test.fail(
607607
"Expected entity {} to be of type MongoClient, got {}".format(
608608
spec["client"], type(client)
@@ -667,7 +667,7 @@ def create_entities_from_spec(self, entity_spec, uri=None):
667667

668668
def get_listener_for_client(self, client_name: str) -> EventListenerUtil:
669669
client = self[client_name]
670-
if not isinstance(client, MongoClient):
670+
if type(client).__name__ != "MongoClient":
671671
self.test.fail(
672672
f"Expected entity {client_name} to be of type MongoClient, got {type(client)}"
673673
)

0 commit comments

Comments
 (0)