Skip to content

Commit 1fb56b0

Browse files
committed
add tests draft
1 parent c69c3bd commit 1fb56b0

File tree

4 files changed

+666
-8
lines changed

4 files changed

+666
-8
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
242242
)
243243
raise exc from final_err
244244

245-
async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
245+
async def collection_info(self, database: str, filter: bytes) -> Optional[list[bytes]]:
246246
"""Get the collection info for a namespace.
247247
248248
The returned collection info is passed to libmongocrypt which reads
@@ -256,8 +256,11 @@ async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]
256256
async with await self.client_ref()[database].list_collections(
257257
filter=RawBSONDocument(filter)
258258
) as cursor:
259+
lst = []
259260
async for doc in cursor:
260-
return _dict_to_bson(doc, False, _DATA_KEY_OPTS)
261+
lst.append(_dict_to_bson(doc, False, _DATA_KEY_OPTS))
262+
if len(lst) > 0:
263+
return lst
261264
return None
262265

263266
def spawn(self) -> None:
@@ -551,7 +554,7 @@ def _create_mongocrypt_options(**kwargs: Any) -> MongoCryptOptions:
551554
# For compat with pymongocrypt <1.13, avoid setting the default key_expiration_ms.
552555
if kwargs.get("key_expiration_ms") is None:
553556
kwargs.pop("key_expiration_ms", None)
554-
return MongoCryptOptions(**kwargs)
557+
return MongoCryptOptions(**kwargs, enable_multiple_collinfo=True)
555558

556559

557560
class AsyncClientEncryption(Generic[_DocumentType]):

pymongo/synchronous/encryption.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
241241
)
242242
raise exc from final_err
243243

244-
def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
244+
def collection_info(self, database: str, filter: bytes) -> Optional[list[bytes]]:
245245
"""Get the collection info for a namespace.
246246
247247
The returned collection info is passed to libmongocrypt which reads
@@ -253,8 +253,11 @@ def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
253253
:return: The first document from the listCollections command response as BSON.
254254
"""
255255
with self.client_ref()[database].list_collections(filter=RawBSONDocument(filter)) as cursor:
256+
lst = []
256257
for doc in cursor:
257-
return _dict_to_bson(doc, False, _DATA_KEY_OPTS)
258+
lst.append(_dict_to_bson(doc, False, _DATA_KEY_OPTS))
259+
if len(lst) > 0:
260+
return lst
258261
return None
259262

260263
def spawn(self) -> None:
@@ -548,7 +551,7 @@ def _create_mongocrypt_options(**kwargs: Any) -> MongoCryptOptions:
548551
# For compat with pymongocrypt <1.13, avoid setting the default key_expiration_ms.
549552
if kwargs.get("key_expiration_ms") is None:
550553
kwargs.pop("key_expiration_ms", None)
551-
return MongoCryptOptions(**kwargs)
554+
return MongoCryptOptions(**kwargs, enable_multiple_collinfo=True)
552555

553556

554557
class ClientEncryption(Generic[_DocumentType]):

0 commit comments

Comments
 (0)