Skip to content

Commit 9556f8c

Browse files
committed
PYTHON-5143 Support auto encryption in unified tests
1 parent 37975cb commit 9556f8c

File tree

12 files changed

+1331
-55
lines changed

12 files changed

+1331
-55
lines changed

.evergreen/remove-unimplemented-tests.sh

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,6 @@ PYMONGO=$(dirname "$(cd "$(dirname "$0")" || exit; pwd)")
33

44
rm $PYMONGO/test/transactions/legacy/errors-client.json # PYTHON-1894
55
rm $PYMONGO/test/connection_monitoring/wait-queue-fairness.json # PYTHON-1873
6-
rm $PYMONGO/test/client-side-encryption/spec/unified/fle2v2-BypassQueryAnalysis.json # PYTHON-5143
7-
rm $PYMONGO/test/client-side-encryption/spec/unified/fle2v2-EncryptedFields-vs-EncryptedFieldsMap.json # PYTHON-5143
8-
rm $PYMONGO/test/client-side-encryption/spec/unified/localSchema.json # PYTHON-5143
9-
rm $PYMONGO/test/client-side-encryption/spec/unified/maxWireVersion.json # PYTHON-5143
10-
rm $PYMONGO/test/unified-test-format/valid-pass/poc-queryable-encryption.json # PYTHON-5143
116
rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-application-error.json # PYTHON-4918
127
rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-checkout-error.json # PYTHON-4918
138
rm $PYMONGO/test/discovery_and_monitoring/unified/pool-clear-min-pool-size-error.json # PYTHON-4918

test/asynchronous/helpers.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,15 @@
116116
}
117117
KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")}
118118

119+
ALL_KMS_PROVIDERS = dict(
120+
aws=AWS_CREDS,
121+
azure=AZURE_CREDS,
122+
gcp=GCP_CREDS,
123+
local=dict(key=LOCAL_MASTER_KEY),
124+
kmip=KMIP_CREDS,
125+
)
126+
DEFAULT_KMS_TLS = dict(kmip=dict(tlsCAFile=CA_PEM, tlsCertificateKeyFile=CLIENT_PEM))
127+
119128
# Ensure Evergreen metadata doesn't result in truncation
120129
os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000")
121130

test/asynchronous/test_encryption.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -54,18 +54,20 @@
5454
from test import (
5555
unittest,
5656
)
57-
from test.asynchronous.test_bulk import AsyncBulkTestBase
58-
from test.asynchronous.unified_format import generate_test_classes
59-
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
60-
from test.helpers import (
57+
from test.asynchronous.helpers import (
58+
ALL_KMS_PROVIDERS,
6159
AWS_CREDS,
6260
AZURE_CREDS,
6361
CA_PEM,
6462
CLIENT_PEM,
63+
DEFAULT_KMS_TLS,
6564
GCP_CREDS,
6665
KMIP_CREDS,
6766
LOCAL_MASTER_KEY,
6867
)
68+
from test.asynchronous.test_bulk import AsyncBulkTestBase
69+
from test.asynchronous.unified_format import generate_test_classes
70+
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
6971
from test.utils_shared import (
7072
AllowListEventListener,
7173
OvertCommandListener,
@@ -204,7 +206,7 @@ async def test_init_kms_tls_options(self):
204206
opts = AutoEncryptionOpts(
205207
{},
206208
"k.d",
207-
kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}},
209+
kms_tls_options=DEFAULT_KMS_TLS,
208210
)
209211
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
210212
ctx = _kms_ssl_contexts["kmip"]
@@ -626,7 +628,6 @@ async def test_with_statement(self):
626628
"accessKeyId": os.environ.get("CSFLE_AWS_TEMP_ACCESS_KEY_ID", ""),
627629
"secretAccessKey": os.environ.get("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY", ""),
628630
}
629-
KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}
630631

631632

632633
class AsyncTestSpec(AsyncSpecRunner):
@@ -663,7 +664,7 @@ def parse_auto_encrypt_opts(self, opts):
663664
self.skipTest("GCP environment credentials are not set")
664665
if "kmip" in kms_providers:
665666
kms_providers["kmip"] = KMIP_CREDS
666-
opts["kms_tls_options"] = KMS_TLS_OPTS
667+
opts["kms_tls_options"] = DEFAULT_KMS_TLS
667668
if "key_vault_namespace" not in opts:
668669
opts["key_vault_namespace"] = "keyvault.datakeys"
669670
if "extra_options" in opts:
@@ -757,14 +758,6 @@ async def run_scenario(self):
757758
)
758759

759760
# Prose Tests
760-
ALL_KMS_PROVIDERS = {
761-
"aws": AWS_CREDS,
762-
"azure": AZURE_CREDS,
763-
"gcp": GCP_CREDS,
764-
"kmip": KMIP_CREDS,
765-
"local": {"key": LOCAL_MASTER_KEY},
766-
}
767-
768761
LOCAL_KEY_ID = Binary(base64.b64decode(b"LOCALAAAAAAAAAAAAAAAAA=="), UUID_SUBTYPE)
769762
AWS_KEY_ID = Binary(base64.b64decode(b"AWSAAAAAAAAAAAAAAAAAAA=="), UUID_SUBTYPE)
770763
AZURE_KEY_ID = Binary(base64.b64decode(b"AZUREAAAAAAAAAAAAAAAAA=="), UUID_SUBTYPE)
@@ -851,13 +844,17 @@ async def asyncSetUp(self):
851844
self.KMS_PROVIDERS,
852845
"keyvault.datakeys",
853846
schema_map=schemas,
854-
kms_tls_options=KMS_TLS_OPTS,
847+
kms_tls_options=DEFAULT_KMS_TLS,
855848
)
856849
self.client_encrypted = await self.async_rs_or_single_client(
857850
auto_encryption_opts=opts, uuidRepresentation="standard"
858851
)
859852
self.client_encryption = self.create_client_encryption(
860-
self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS
853+
self.KMS_PROVIDERS,
854+
"keyvault.datakeys",
855+
self.client,
856+
OPTS,
857+
kms_tls_options=DEFAULT_KMS_TLS,
861858
)
862859
self.listener.reset()
863860

@@ -1066,7 +1063,7 @@ async def _test_corpus(self, opts):
10661063
"keyvault.datakeys",
10671064
async_client_context.client,
10681065
OPTS,
1069-
kms_tls_options=KMS_TLS_OPTS,
1066+
kms_tls_options=DEFAULT_KMS_TLS,
10701067
)
10711068

10721069
corpus = self.fix_up_curpus(json_data("corpus", "corpus.json"))
@@ -1158,7 +1155,7 @@ async def _test_corpus(self, opts):
11581155

11591156
async def test_corpus(self):
11601157
opts = AutoEncryptionOpts(
1161-
self.kms_providers(), "keyvault.datakeys", kms_tls_options=KMS_TLS_OPTS
1158+
self.kms_providers(), "keyvault.datakeys", kms_tls_options=DEFAULT_KMS_TLS
11621159
)
11631160
await self._test_corpus(opts)
11641161

@@ -1169,7 +1166,7 @@ async def test_corpus_local_schema(self):
11691166
self.kms_providers(),
11701167
"keyvault.datakeys",
11711168
schema_map=schemas,
1172-
kms_tls_options=KMS_TLS_OPTS,
1169+
kms_tls_options=DEFAULT_KMS_TLS,
11731170
)
11741171
await self._test_corpus(opts)
11751172

@@ -1300,7 +1297,7 @@ async def asyncSetUp(self):
13001297
key_vault_namespace="keyvault.datakeys",
13011298
key_vault_client=async_client_context.client,
13021299
codec_options=OPTS,
1303-
kms_tls_options=KMS_TLS_OPTS,
1300+
kms_tls_options=DEFAULT_KMS_TLS,
13041301
)
13051302

13061303
kms_providers_invalid = copy.deepcopy(kms_providers)
@@ -1312,7 +1309,7 @@ async def asyncSetUp(self):
13121309
key_vault_namespace="keyvault.datakeys",
13131310
key_vault_client=async_client_context.client,
13141311
codec_options=OPTS,
1315-
kms_tls_options=KMS_TLS_OPTS,
1312+
kms_tls_options=DEFAULT_KMS_TLS,
13161313
)
13171314
self._kmip_host_error = None
13181315
self._invalid_host_error = None
@@ -2752,7 +2749,7 @@ async def run_test(self, src_provider, dst_provider):
27522749
key_vault_client=self.client,
27532750
key_vault_namespace="keyvault.datakeys",
27542751
kms_providers=ALL_KMS_PROVIDERS,
2755-
kms_tls_options=KMS_TLS_OPTS,
2752+
kms_tls_options=DEFAULT_KMS_TLS,
27562753
codec_options=OPTS,
27572754
)
27582755

@@ -2772,7 +2769,7 @@ async def run_test(self, src_provider, dst_provider):
27722769
key_vault_client=client2,
27732770
key_vault_namespace="keyvault.datakeys",
27742771
kms_providers=ALL_KMS_PROVIDERS,
2775-
kms_tls_options=KMS_TLS_OPTS,
2772+
kms_tls_options=DEFAULT_KMS_TLS,
27762773
codec_options=OPTS,
27772774
)
27782775

test/asynchronous/unified_format.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
)
3838
from test.asynchronous.utils import async_get_pool, flaky
3939
from test.asynchronous.utils_spec_runner import SpecRunnerTask
40+
from test.helpers import ALL_KMS_PROVIDERS, DEFAULT_KMS_TLS
4041
from test.unified_format_shared import (
4142
KMS_TLS_OPTS,
4243
PLACEHOLDER_MAP,
@@ -61,6 +62,8 @@
6162
from test.version import Version
6263
from typing import Any, Dict, List, Mapping, Optional
6364

65+
import pytest
66+
6467
import pymongo
6568
from bson import SON, json_util
6669
from bson.codec_options import DEFAULT_CODEC_OPTIONS
@@ -76,7 +79,7 @@
7679
from pymongo.asynchronous.encryption import AsyncClientEncryption
7780
from pymongo.asynchronous.helpers import anext
7881
from pymongo.driver_info import DriverInfo
79-
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT
82+
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts
8083
from pymongo.errors import (
8184
AutoReconnect,
8285
BulkWriteError,
@@ -259,6 +262,23 @@ async def _create_entity(self, entity_spec, uri=None):
259262
kwargs: dict = {}
260263
observe_events = spec.get("observeEvents", [])
261264

265+
if "autoEncryptOpts" in spec:
266+
auto_encrypt_opts = spec["autoEncryptOpts"]
267+
auto_encrypt_kwargs: dict = dict(kms_tls_options=DEFAULT_KMS_TLS)
268+
kms_providers = ALL_KMS_PROVIDERS.copy()
269+
key_vault_namespace = auto_encrypt_opts.pop("keyVaultNamespace")
270+
for provider_name, provider_value in auto_encrypt_opts.pop("kmsProviders").items():
271+
kms_providers[provider_name].update(provider_value)
272+
extra_opts = auto_encrypt_opts.pop("extraOptions", {})
273+
for key, value in extra_opts.items():
274+
auto_encrypt_kwargs[camel_to_snake(key)] = value
275+
for key, value in auto_encrypt_opts.items():
276+
auto_encrypt_kwargs[camel_to_snake(key)] = value
277+
auto_encryption_opts = AutoEncryptionOpts(
278+
kms_providers, key_vault_namespace, **auto_encrypt_kwargs
279+
)
280+
kwargs["auto_encryption_opts"] = auto_encryption_opts
281+
262282
# The unified tests use topologyOpeningEvent, we use topologyOpenedEvent
263283
for i in range(len(observe_events)):
264284
if "topologyOpeningEvent" == observe_events[i]:
@@ -430,7 +450,7 @@ class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
430450
a class attribute ``TEST_SPEC``.
431451
"""
432452

433-
SCHEMA_VERSION = Version.from_string("1.22")
453+
SCHEMA_VERSION = Version.from_string("1.23")
434454
RUN_ON_LOAD_BALANCER = True
435455
TEST_SPEC: Any
436456
TEST_PATH = "" # This gets filled in by generate_test_classes
@@ -1516,7 +1536,14 @@ class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore
15161536
TEST_SPEC = test_spec
15171537
EXPECTED_FAILURES = expected_failures
15181538

1519-
return SpecTestBase
1539+
base = SpecTestBase
1540+
1541+
# Add "encryption" marker if the "csfle" runOnRequirement is set.
1542+
for req in test_spec.get("runOnRequirements", []):
1543+
if req.get("csfle", False):
1544+
base = pytest.mark.encryption(base)
1545+
1546+
return base
15201547

15211548
for dirpath, _, filenames in os.walk(test_path):
15221549
dirname = os.path.split(dirpath)[-1]

0 commit comments

Comments
 (0)