Skip to content

PYTHON-5046 Support $lookup in CSFLE and QE #2210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions pymongo/asynchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
)
raise exc from final_err

async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
async def collection_info(self, database: str, filter: bytes) -> Optional[list[bytes]]:
"""Get the collection info for a namespace.

The returned collection info is passed to libmongocrypt which reads
Expand All @@ -256,8 +256,11 @@ async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]
async with await self.client_ref()[database].list_collections(
filter=RawBSONDocument(filter)
) as cursor:
lst = []
async for doc in cursor:
return _dict_to_bson(doc, False, _DATA_KEY_OPTS)
lst.append(_dict_to_bson(doc, False, _DATA_KEY_OPTS))
if len(lst) > 0:
return lst
return None

def spawn(self) -> None:
Expand Down Expand Up @@ -551,7 +554,7 @@ def _create_mongocrypt_options(**kwargs: Any) -> MongoCryptOptions:
# For compat with pymongocrypt <1.13, avoid setting the default key_expiration_ms.
if kwargs.get("key_expiration_ms") is None:
kwargs.pop("key_expiration_ms", None)
return MongoCryptOptions(**kwargs)
return MongoCryptOptions(**kwargs, enable_multiple_collinfo=True)


class AsyncClientEncryption(Generic[_DocumentType]):
Expand Down
9 changes: 6 additions & 3 deletions pymongo/synchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
)
raise exc from final_err

def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
def collection_info(self, database: str, filter: bytes) -> Optional[list[bytes]]:
"""Get the collection info for a namespace.

The returned collection info is passed to libmongocrypt which reads
Expand All @@ -253,8 +253,11 @@ def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
:return: The first document from the listCollections command response as BSON.
"""
with self.client_ref()[database].list_collections(filter=RawBSONDocument(filter)) as cursor:
lst = []
for doc in cursor:
return _dict_to_bson(doc, False, _DATA_KEY_OPTS)
lst.append(_dict_to_bson(doc, False, _DATA_KEY_OPTS))
if len(lst) > 0:
return lst
return None

def spawn(self) -> None:
Expand Down Expand Up @@ -548,7 +551,7 @@ def _create_mongocrypt_options(**kwargs: Any) -> MongoCryptOptions:
# For compat with pymongocrypt <1.13, avoid setting the default key_expiration_ms.
if kwargs.get("key_expiration_ms") is None:
kwargs.pop("key_expiration_ms", None)
return MongoCryptOptions(**kwargs)
return MongoCryptOptions(**kwargs, enable_multiple_collinfo=True)


class ClientEncryption(Generic[_DocumentType]):
Expand Down
315 changes: 314 additions & 1 deletion test/asynchronous/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@
is_greenthread_patched,
)

from bson import DatetimeMS, Decimal128, encode, json_util
from bson import BSON, DatetimeMS, Decimal128, encode, json_util
from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation
from bson.codec_options import CodecOptions
from bson.errors import BSONError
Expand All @@ -94,6 +94,7 @@
EncryptionError,
InvalidOperation,
OperationFailure,
PyMongoError,
ServerSelectionTimeoutError,
WriteError,
)
Expand Down Expand Up @@ -2419,6 +2420,318 @@ async def test_05_roundtrip_encrypted_unindexed(self):
self.assertEqual(decrypted, val)


# https://github.com/mongodb/specifications/blob/527e22d5090ec48bf1e144c45fc831de0f1935f6/source/client-side-encryption/tests/README.md#25-test-lookup
class TestLookupProse(AsyncEncryptionIntegrationTest):
@async_client_context.require_no_standalone
@async_client_context.require_version_min(7, 0, -1)
async def asyncSetUp(self):
await super().asyncSetUp()
self.encrypted_client = await self.async_rs_or_single_client(
auto_encryption_opts=AutoEncryptionOpts(
key_vault_namespace="db.keyvault",
kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
)
)
await self.encrypted_client.db.drop_collection("keyvault")

key_doc = json_data("etc", "data", "lookup", "key-doc.json")
key_vault = await create_key_vault(self.encrypted_client.db.keyvault, key_doc)
self.addCleanup(key_vault.drop)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

addCleanup -> addAsyncCleanup

Also, can we use the client_context client here to drop the entire database instead of just the keyvault? This will fix the "InvalidOperation: Cannot use AsyncMongoClient after close" errors (which were introduced now that we close these clients).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively we can just leave the data there and skip the cleanup altogether.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good catch, and yes! done!


await self.encrypted_client.db.drop_collection("csfle")
await self.encrypted_client.db.create_collection(
"csfle",
validator={"$jsonSchema": json_data("etc", "data", "lookup", "schema-csfle.json")},
)

await self.encrypted_client.db.drop_collection("csfle2")
await self.encrypted_client.db.create_collection(
"csfle2",
validator={"$jsonSchema": json_data("etc", "data", "lookup", "schema-csfle2.json")},
)

await self.encrypted_client.db.drop_collection("qe")
await self.encrypted_client.db.create_collection(
"qe", encryptedFields=json_data("etc", "data", "lookup", "schema-qe.json")
)

await self.encrypted_client.db.drop_collection("qe2")
await self.encrypted_client.db.create_collection(
"qe2", encryptedFields=json_data("etc", "data", "lookup", "schema-qe2.json")
)

await self.encrypted_client.db.drop_collection("no_schema")
await self.encrypted_client.db.create_collection("no_schema")

await self.encrypted_client.db.drop_collection("no_schema2")
await self.encrypted_client.db.create_collection("no_schema2")

self.unencrypted_client = await self.async_rs_or_single_client()

await self.encrypted_client.db.csfle.insert_one({"csfle": "csfle"})
doc = await self.unencrypted_client.db.csfle.find_one()
self.assertTrue(isinstance(doc["csfle"], Binary))
await self.encrypted_client.db.csfle2.insert_one({"csfle2": "csfle2"})
doc = await self.unencrypted_client.db.csfle2.find_one()
self.assertTrue(isinstance(doc["csfle2"], Binary))
await self.encrypted_client.db.qe.insert_one({"qe": "qe"})
doc = await self.unencrypted_client.db.qe.find_one()
self.assertTrue(isinstance(doc["qe"], Binary))
await self.encrypted_client.db.qe2.insert_one({"qe2": "qe2"})
doc = await self.unencrypted_client.db.qe2.find_one()
self.assertTrue(isinstance(doc["qe2"], Binary))
await self.encrypted_client.db.no_schema.insert_one({"no_schema": "no_schema"})
await self.encrypted_client.db.no_schema2.insert_one({"no_schema2": "no_schema2"})

@async_client_context.require_version_min(8, 1, -1)
async def test_1(self):
    """Aggregate on ``db.csfle`` with a ``$lookup`` from ``no_schema``.

    Uses a freshly created auto-encrypting client and checks that the
    first result equals the expected plaintext document.
    """
    encryption_opts = AutoEncryptionOpts(
        key_vault_namespace="db.keyvault",
        kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
    )
    self.encrypted_client = await self.async_rs_or_single_client(
        auto_encryption_opts=encryption_opts
    )

    lookup_stage = {
        "$lookup": {
            "from": "no_schema",
            "as": "matched",
            "pipeline": [
                {"$match": {"no_schema": "no_schema"}},
                {"$project": {"_id": 0}},
            ],
        }
    }
    pipeline = [
        {"$match": {"csfle": "csfle"}},
        lookup_stage,
        {"$project": {"_id": 0}},
    ]

    cursor = await self.encrypted_client.db.csfle.aggregate(pipeline)
    first = await anext(cursor)
    self.assertEqual(first, {"csfle": "csfle", "matched": [{"no_schema": "no_schema"}]})

@async_client_context.require_version_min(8, 1, -1)
async def test_2(self):
    """Aggregate on ``db.qe`` with a ``$lookup`` from ``no_schema``.

    ``__safeContent__`` and ``_id`` are projected away at both levels so the
    first result can be compared against a fixed expected document.
    """
    encryption_opts = AutoEncryptionOpts(
        key_vault_namespace="db.keyvault",
        kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
    )
    self.encrypted_client = await self.async_rs_or_single_client(
        auto_encryption_opts=encryption_opts
    )

    lookup_stage = {
        "$lookup": {
            "from": "no_schema",
            "as": "matched",
            "pipeline": [
                {"$match": {"no_schema": "no_schema"}},
                {"$project": {"_id": 0, "__safeContent__": 0}},
            ],
        }
    }
    pipeline = [
        {"$match": {"qe": "qe"}},
        lookup_stage,
        {"$project": {"_id": 0, "__safeContent__": 0}},
    ]

    cursor = await self.encrypted_client.db.qe.aggregate(pipeline)
    first = await anext(cursor)
    self.assertEqual(first, {"qe": "qe", "matched": [{"no_schema": "no_schema"}]})

@async_client_context.require_version_min(8, 1, -1)
async def test_3(self):
    """Aggregate on ``db.no_schema`` with a ``$lookup`` from ``csfle``.

    Checks that the first result equals the expected plaintext document.
    """
    encryption_opts = AutoEncryptionOpts(
        key_vault_namespace="db.keyvault",
        kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
    )
    self.encrypted_client = await self.async_rs_or_single_client(
        auto_encryption_opts=encryption_opts
    )

    lookup_stage = {
        "$lookup": {
            "from": "csfle",
            "as": "matched",
            "pipeline": [{"$match": {"csfle": "csfle"}}, {"$project": {"_id": 0}}],
        }
    }
    pipeline = [
        {"$match": {"no_schema": "no_schema"}},
        lookup_stage,
        {"$project": {"_id": 0}},
    ]

    cursor = await self.encrypted_client.db.no_schema.aggregate(pipeline)
    first = await anext(cursor)
    self.assertEqual(first, {"no_schema": "no_schema", "matched": [{"csfle": "csfle"}]})

@async_client_context.require_version_min(8, 1, -1)
async def test_4(self):
    """Aggregate on ``db.no_schema`` with a ``$lookup`` from ``qe``.

    The inner pipeline drops ``__safeContent__`` so the joined document can
    be compared against a fixed expected value.
    """
    encryption_opts = AutoEncryptionOpts(
        key_vault_namespace="db.keyvault",
        kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
    )
    self.encrypted_client = await self.async_rs_or_single_client(
        auto_encryption_opts=encryption_opts
    )

    lookup_stage = {
        "$lookup": {
            "from": "qe",
            "as": "matched",
            "pipeline": [
                {"$match": {"qe": "qe"}},
                {"$project": {"_id": 0, "__safeContent__": 0}},
            ],
        }
    }
    pipeline = [
        {"$match": {"no_schema": "no_schema"}},
        lookup_stage,
        {"$project": {"_id": 0}},
    ]

    cursor = await self.encrypted_client.db.no_schema.aggregate(pipeline)
    first = await anext(cursor)
    self.assertEqual(first, {"no_schema": "no_schema", "matched": [{"qe": "qe"}]})

@async_client_context.require_version_min(8, 1, -1)
async def test_5(self):
    """Aggregate on ``db.csfle`` with a ``$lookup`` from ``csfle2``.

    Checks that the first result equals the expected plaintext document.
    """
    encryption_opts = AutoEncryptionOpts(
        key_vault_namespace="db.keyvault",
        kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
    )
    self.encrypted_client = await self.async_rs_or_single_client(
        auto_encryption_opts=encryption_opts
    )

    lookup_stage = {
        "$lookup": {
            "from": "csfle2",
            "as": "matched",
            "pipeline": [
                {"$match": {"csfle2": "csfle2"}},
                {"$project": {"_id": 0}},
            ],
        }
    }
    pipeline = [
        {"$match": {"csfle": "csfle"}},
        lookup_stage,
        {"$project": {"_id": 0}},
    ]

    cursor = await self.encrypted_client.db.csfle.aggregate(pipeline)
    first = await anext(cursor)
    self.assertEqual(first, {"csfle": "csfle", "matched": [{"csfle2": "csfle2"}]})

@async_client_context.require_version_min(8, 1, -1)
async def test_6(self):
    """Aggregate on ``db.qe`` with a ``$lookup`` from ``qe2``.

    ``__safeContent__`` and ``_id`` are projected away at both levels so the
    first result can be compared against a fixed expected document.
    """
    encryption_opts = AutoEncryptionOpts(
        key_vault_namespace="db.keyvault",
        kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
    )
    self.encrypted_client = await self.async_rs_or_single_client(
        auto_encryption_opts=encryption_opts
    )

    lookup_stage = {
        "$lookup": {
            "from": "qe2",
            "as": "matched",
            "pipeline": [
                {"$match": {"qe2": "qe2"}},
                {"$project": {"_id": 0, "__safeContent__": 0}},
            ],
        }
    }
    pipeline = [
        {"$match": {"qe": "qe"}},
        lookup_stage,
        {"$project": {"_id": 0, "__safeContent__": 0}},
    ]

    cursor = await self.encrypted_client.db.qe.aggregate(pipeline)
    first = await anext(cursor)
    self.assertEqual(first, {"qe": "qe", "matched": [{"qe2": "qe2"}]})

@async_client_context.require_version_min(8, 1, -1)
async def test_7(self):
    """Aggregate on ``db.no_schema`` with a ``$lookup`` from ``no_schema2``.

    Checks that the first result equals the expected document.
    """
    encryption_opts = AutoEncryptionOpts(
        key_vault_namespace="db.keyvault",
        kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
    )
    self.encrypted_client = await self.async_rs_or_single_client(
        auto_encryption_opts=encryption_opts
    )

    lookup_stage = {
        "$lookup": {
            "from": "no_schema2",
            "as": "matched",
            "pipeline": [
                {"$match": {"no_schema2": "no_schema2"}},
                {"$project": {"_id": 0}},
            ],
        }
    }
    pipeline = [
        {"$match": {"no_schema": "no_schema"}},
        lookup_stage,
        {"$project": {"_id": 0}},
    ]

    cursor = await self.encrypted_client.db.no_schema.aggregate(pipeline)
    first = await anext(cursor)
    self.assertEqual(first, {"no_schema": "no_schema", "matched": [{"no_schema2": "no_schema2"}]})

@async_client_context.require_version_min(8, 1, -1)
async def test_8(self):
    """Aggregate on ``db.csfle`` with a ``$lookup`` from ``qe`` must fail.

    The aggregate is expected to raise a ``PyMongoError`` whose message
    contains ``"not supported"``.
    """
    self.encrypted_client = await self.async_rs_or_single_client(
        auto_encryption_opts=AutoEncryptionOpts(
            key_vault_namespace="db.keyvault",
            kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
        )
    )
    with self.assertRaises(PyMongoError) as exc:
        _ = await anext(
            await self.encrypted_client.db.csfle.aggregate(
                [
                    {"$match": {"csfle": "qe"}},
                    {
                        "$lookup": {
                            "from": "qe",
                            "as": "matched",
                            "pipeline": [{"$match": {"qe": "qe"}}, {"$project": {"_id": 0}}],
                        }
                    },
                    {"$project": {"_id": 0}},
                ]
            )
        )
    # Check the message after the ``with`` block so the assertion actually
    # runs, and inspect the raised error (``exc.exception``) rather than the
    # assertRaises context manager itself.
    self.assertIn("not supported", str(exc.exception))

@async_client_context.require_version_max(8, 1, -1)
async def test_9(self):
    """Aggregate with ``$lookup`` on an older server must fail.

    Only runs on servers at or below the version bound in the decorator.
    The aggregate is expected to raise a ``PyMongoError`` whose message
    contains ``"Upgrade"``.
    """
    self.encrypted_client = await self.async_rs_or_single_client(
        auto_encryption_opts=AutoEncryptionOpts(
            key_vault_namespace="db.keyvault",
            kms_providers={"local": {"key": LOCAL_MASTER_KEY}},
        )
    )
    with self.assertRaises(PyMongoError) as exc:
        _ = await anext(
            await self.encrypted_client.db.csfle.aggregate(
                [
                    {"$match": {"csfle": "csfle"}},
                    {
                        "$lookup": {
                            "from": "no_schema",
                            "as": "matched",
                            "pipeline": [
                                {"$match": {"no_schema": "no_schema"}},
                                {"$project": {"_id": 0}},
                            ],
                        }
                    },
                    {"$project": {"_id": 0}},
                ]
            )
        )
    # Check the message after the ``with`` block so the assertion actually
    # runs, and inspect the raised error (``exc.exception``) rather than the
    # assertRaises context manager itself.
    self.assertIn("Upgrade", str(exc.exception))


# https://github.com/mongodb/specifications/blob/072601/source/client-side-encryption/tests/README.md#rewrap
class TestRewrapWithSeparateClientEncryption(AsyncEncryptionIntegrationTest):
MASTER_KEYS: Mapping[str, Mapping[str, Any]] = {
Expand Down
Loading
Loading