Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .evergreen/remove-unimplemented-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,6 @@ PYMONGO=$(dirname "$(cd "$(dirname "$0")" || exit; pwd)")

rm $PYMONGO/test/transactions/legacy/errors-client.json # PYTHON-1894
rm $PYMONGO/test/connection_monitoring/wait-queue-fairness.json # PYTHON-1873
rm $PYMONGO/test/client-side-encryption/spec/unified/fle2v2-BypassQueryAnalysis.json # PYTHON-5143
rm $PYMONGO/test/client-side-encryption/spec/unified/fle2v2-EncryptedFields-vs-EncryptedFieldsMap.json # PYTHON-5143
rm $PYMONGO/test/client-side-encryption/spec/unified/localSchema.json # PYTHON-5143
rm $PYMONGO/test/client-side-encryption/spec/unified/maxWireVersion.json # PYTHON-5143
rm $PYMONGO/test/unified-test-format/valid-pass/poc-queryable-encryption.json # PYTHON-5143
rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-application-error.json # PYTHON-4918
rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-checkout-error.json # PYTHON-4918
rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-min-pool-size-error.json # PYTHON-4918
Expand Down
21 changes: 20 additions & 1 deletion test/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import unittest
import warnings
from inspect import iscoroutinefunction
from pathlib import Path

from pymongo._asyncio_task import create_task

Expand Down Expand Up @@ -69,7 +70,11 @@
db_user = os.environ.get("DB_USER", "user")
db_pwd = os.environ.get("DB_PASSWORD", "password")

CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates")
HERE = Path(__file__).absolute()
if _IS_SYNC:
CERT_PATH = str(HERE.parent / "certificates")
else:
CERT_PATH = str(HERE.parent.parent / "certificates")
CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem"))
CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem"))

Expand Down Expand Up @@ -115,6 +120,20 @@
"privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""),
}
KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")}
AWS_TEMP_CREDS = {
"accessKeyId": os.environ.get("CSFLE_AWS_TEMP_ACCESS_KEY_ID", ""),
"secretAccessKey": os.environ.get("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY", ""),
"sessionToken": os.environ.get("CSFLE_AWS_TEMP_SESSION_TOKEN", ""),
}

ALL_KMS_PROVIDERS = dict(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these need to be here where they'll be duplicated by synchro? Or can they go in a shared file in test?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you suggesting we move the rest of the constants as well?

Copy link
Contributor

@NoahStapp NoahStapp Aug 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any constant (or function, value, anything really) that doesn't change based on async/sync should live outside of the synchro'd directories, yeah.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'll refactor a bit

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, let's see how the tests do

aws=AWS_CREDS,
azure=AZURE_CREDS,
gcp=GCP_CREDS,
local=dict(key=LOCAL_MASTER_KEY),
kmip=KMIP_CREDS,
)
DEFAULT_KMS_TLS = dict(kmip=dict(tlsCAFile=CA_PEM, tlsCertificateKeyFile=CLIENT_PEM))

# Ensure Evergreen metadata doesn't result in truncation
os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000")
Expand Down
52 changes: 22 additions & 30 deletions test/asynchronous/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,21 @@
from test import (
unittest,
)
from test.asynchronous.test_bulk import AsyncBulkTestBase
from test.asynchronous.unified_format import generate_test_classes
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
from test.helpers import (
from test.asynchronous.helpers import (
ALL_KMS_PROVIDERS,
AWS_CREDS,
AWS_TEMP_CREDS,
AZURE_CREDS,
CA_PEM,
CLIENT_PEM,
DEFAULT_KMS_TLS,
GCP_CREDS,
KMIP_CREDS,
LOCAL_MASTER_KEY,
)
from test.asynchronous.test_bulk import AsyncBulkTestBase
from test.asynchronous.unified_format import generate_test_classes
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
from test.utils_shared import (
AllowListEventListener,
OvertCommandListener,
Expand Down Expand Up @@ -204,7 +207,7 @@ async def test_init_kms_tls_options(self):
opts = AutoEncryptionOpts(
{},
"k.d",
kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}},
kms_tls_options=DEFAULT_KMS_TLS,
)
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
ctx = _kms_ssl_contexts["kmip"]
Expand Down Expand Up @@ -616,17 +619,10 @@ async def test_with_statement(self):


# Spec tests
AWS_TEMP_CREDS = {
"accessKeyId": os.environ.get("CSFLE_AWS_TEMP_ACCESS_KEY_ID", ""),
"secretAccessKey": os.environ.get("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY", ""),
"sessionToken": os.environ.get("CSFLE_AWS_TEMP_SESSION_TOKEN", ""),
}

AWS_TEMP_NO_SESSION_CREDS = {
"accessKeyId": os.environ.get("CSFLE_AWS_TEMP_ACCESS_KEY_ID", ""),
"secretAccessKey": os.environ.get("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY", ""),
}
KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}


class AsyncTestSpec(AsyncSpecRunner):
Expand Down Expand Up @@ -663,7 +659,7 @@ def parse_auto_encrypt_opts(self, opts):
self.skipTest("GCP environment credentials are not set")
if "kmip" in kms_providers:
kms_providers["kmip"] = KMIP_CREDS
opts["kms_tls_options"] = KMS_TLS_OPTS
opts["kms_tls_options"] = DEFAULT_KMS_TLS
if "key_vault_namespace" not in opts:
opts["key_vault_namespace"] = "keyvault.datakeys"
if "extra_options" in opts:
Expand Down Expand Up @@ -757,14 +753,6 @@ async def run_scenario(self):
)

# Prose Tests
ALL_KMS_PROVIDERS = {
"aws": AWS_CREDS,
"azure": AZURE_CREDS,
"gcp": GCP_CREDS,
"kmip": KMIP_CREDS,
"local": {"key": LOCAL_MASTER_KEY},
}

LOCAL_KEY_ID = Binary(base64.b64decode(b"LOCALAAAAAAAAAAAAAAAAA=="), UUID_SUBTYPE)
AWS_KEY_ID = Binary(base64.b64decode(b"AWSAAAAAAAAAAAAAAAAAAA=="), UUID_SUBTYPE)
AZURE_KEY_ID = Binary(base64.b64decode(b"AZUREAAAAAAAAAAAAAAAAA=="), UUID_SUBTYPE)
Expand Down Expand Up @@ -851,13 +839,17 @@ async def asyncSetUp(self):
self.KMS_PROVIDERS,
"keyvault.datakeys",
schema_map=schemas,
kms_tls_options=KMS_TLS_OPTS,
kms_tls_options=DEFAULT_KMS_TLS,
)
self.client_encrypted = await self.async_rs_or_single_client(
auto_encryption_opts=opts, uuidRepresentation="standard"
)
self.client_encryption = self.create_client_encryption(
self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS
self.KMS_PROVIDERS,
"keyvault.datakeys",
self.client,
OPTS,
kms_tls_options=DEFAULT_KMS_TLS,
)
self.listener.reset()

Expand Down Expand Up @@ -1066,7 +1058,7 @@ async def _test_corpus(self, opts):
"keyvault.datakeys",
async_client_context.client,
OPTS,
kms_tls_options=KMS_TLS_OPTS,
kms_tls_options=DEFAULT_KMS_TLS,
)

corpus = self.fix_up_curpus(json_data("corpus", "corpus.json"))
Expand Down Expand Up @@ -1158,7 +1150,7 @@ async def _test_corpus(self, opts):

async def test_corpus(self):
opts = AutoEncryptionOpts(
self.kms_providers(), "keyvault.datakeys", kms_tls_options=KMS_TLS_OPTS
self.kms_providers(), "keyvault.datakeys", kms_tls_options=DEFAULT_KMS_TLS
)
await self._test_corpus(opts)

Expand All @@ -1169,7 +1161,7 @@ async def test_corpus_local_schema(self):
self.kms_providers(),
"keyvault.datakeys",
schema_map=schemas,
kms_tls_options=KMS_TLS_OPTS,
kms_tls_options=DEFAULT_KMS_TLS,
)
await self._test_corpus(opts)

Expand Down Expand Up @@ -1300,7 +1292,7 @@ async def asyncSetUp(self):
key_vault_namespace="keyvault.datakeys",
key_vault_client=async_client_context.client,
codec_options=OPTS,
kms_tls_options=KMS_TLS_OPTS,
kms_tls_options=DEFAULT_KMS_TLS,
)

kms_providers_invalid = copy.deepcopy(kms_providers)
Expand All @@ -1312,7 +1304,7 @@ async def asyncSetUp(self):
key_vault_namespace="keyvault.datakeys",
key_vault_client=async_client_context.client,
codec_options=OPTS,
kms_tls_options=KMS_TLS_OPTS,
kms_tls_options=DEFAULT_KMS_TLS,
)
self._kmip_host_error = None
self._invalid_host_error = None
Expand Down Expand Up @@ -2752,7 +2744,7 @@ async def run_test(self, src_provider, dst_provider):
key_vault_client=self.client,
key_vault_namespace="keyvault.datakeys",
kms_providers=ALL_KMS_PROVIDERS,
kms_tls_options=KMS_TLS_OPTS,
kms_tls_options=DEFAULT_KMS_TLS,
codec_options=OPTS,
)

Expand All @@ -2772,7 +2764,7 @@ async def run_test(self, src_provider, dst_provider):
key_vault_client=client2,
key_vault_namespace="keyvault.datakeys",
kms_providers=ALL_KMS_PROVIDERS,
kms_tls_options=KMS_TLS_OPTS,
kms_tls_options=DEFAULT_KMS_TLS,
codec_options=OPTS,
)

Expand Down
41 changes: 38 additions & 3 deletions test/asynchronous/unified_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from test.asynchronous.utils import async_get_pool, flaky
from test.asynchronous.utils_spec_runner import SpecRunnerTask
from test.helpers import ALL_KMS_PROVIDERS, DEFAULT_KMS_TLS
from test.unified_format_shared import (
KMS_TLS_OPTS,
PLACEHOLDER_MAP,
Expand All @@ -61,6 +62,8 @@
from test.version import Version
from typing import Any, Dict, List, Mapping, Optional

import pytest

import pymongo
from bson import SON, json_util
from bson.codec_options import DEFAULT_CODEC_OPTIONS
Expand All @@ -76,7 +79,7 @@
from pymongo.asynchronous.encryption import AsyncClientEncryption
from pymongo.asynchronous.helpers import anext
from pymongo.driver_info import DriverInfo
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
from pymongo.errors import (
AutoReconnect,
BulkWriteError,
Expand Down Expand Up @@ -259,6 +262,23 @@ async def _create_entity(self, entity_spec, uri=None):
kwargs: dict = {}
observe_events = spec.get("observeEvents", [])

if "autoEncryptOpts" in spec:
auto_encrypt_opts = spec["autoEncryptOpts"].copy()
auto_encrypt_kwargs: dict = dict(kms_tls_options=DEFAULT_KMS_TLS)
kms_providers = ALL_KMS_PROVIDERS.copy()
key_vault_namespace = auto_encrypt_opts.pop("keyVaultNamespace")
for provider_name, provider_value in auto_encrypt_opts.pop("kmsProviders").items():
kms_providers[provider_name].update(provider_value)
extra_opts = auto_encrypt_opts.pop("extraOptions", {})
for key, value in extra_opts.items():
auto_encrypt_kwargs[camel_to_snake(key)] = value
for key, value in auto_encrypt_opts.items():
auto_encrypt_kwargs[camel_to_snake(key)] = value
auto_encryption_opts = AutoEncryptionOpts(
kms_providers, key_vault_namespace, **auto_encrypt_kwargs
)
kwargs["auto_encryption_opts"] = auto_encryption_opts

# The unified tests use topologyOpeningEvent, we use topologyOpenedEvent
for i in range(len(observe_events)):
if "topologyOpeningEvent" == observe_events[i]:
Expand Down Expand Up @@ -430,7 +450,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
a class attribute ``TEST_SPEC``.
"""

SCHEMA_VERSION = Version.from_string("1.22")
SCHEMA_VERSION = Version.from_string("1.23")
RUN_ON_LOAD_BALANCER = True
TEST_SPEC: Any
TEST_PATH = "" # This gets filled in by generate_test_classes
Expand Down Expand Up @@ -462,6 +482,13 @@ async def insert_initial_data(self, initial_data):
wc = WriteConcern(w="majority")
else:
wc = WriteConcern(w=1)

# Remove any encryption collections associated with the collection.
collections = await db.list_collection_names()
for collection in collections:
if collection in [f"enxcol_.{coll_name}.esc", f"enxcol_.{coll_name}.ecoc"]:
await db.drop_collection(collection)

if documents:
if opts:
await db.create_collection(coll_name, **opts)
Expand All @@ -488,6 +515,7 @@ def tearDownClass(cls) -> None:
async def asyncSetUp(self):
# super call creates internal client cls.client
await super().asyncSetUp()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unintended whitespace?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

# process file-level runOnRequirements
run_on_spec = self.TEST_SPEC.get("runOnRequirements", [])
if not await self.should_run_on(run_on_spec):
Expand Down Expand Up @@ -1516,7 +1544,14 @@ class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore
TEST_SPEC = test_spec
EXPECTED_FAILURES = expected_failures

return SpecTestBase
base = SpecTestBase

# Add "encryption" marker if the "csfle" runOnRequirement is set.
for req in test_spec.get("runOnRequirements", []):
if req.get("csfle", False):
base = pytest.mark.encryption(base)

return base

for dirpath, _, filenames in os.walk(test_path):
dirname = os.path.split(dirpath)[-1]
Expand Down
Loading
Loading