Skip to content

Commit 4158326

Browse files
committed
PYTHON-4700 - Convert CSFLE tests to async
1 parent 044d92c commit 4158326

File tree

3 files changed

+249
-248
lines changed

3 files changed

+249
-248
lines changed

test/asynchronous/test_encryption.py

Lines changed: 124 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import uuid
3030
import warnings
3131
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context
32+
from test.asynchronous.utils_spec_runner import AsyncSpecRunner
3233
from threading import Thread
3334
from typing import Any, Dict, Mapping
3435

@@ -611,132 +612,130 @@ async def test_with_statement(self):
611612
KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}
612613

613614

614-
if _IS_SYNC:
615-
# TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700)
616-
class TestSpec(SpecRunner):
617-
@classmethod
618-
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
619-
def setUpClass(cls):
620-
super().setUpClass()
621-
622-
def parse_auto_encrypt_opts(self, opts):
623-
"""Parse clientOptions.autoEncryptOpts."""
624-
opts = camel_to_snake_args(opts)
625-
kms_providers = opts["kms_providers"]
626-
if "aws" in kms_providers:
627-
kms_providers["aws"] = AWS_CREDS
628-
if not any(AWS_CREDS.values()):
629-
self.skipTest("AWS environment credentials are not set")
630-
if "awsTemporary" in kms_providers:
631-
kms_providers["aws"] = AWS_TEMP_CREDS
632-
del kms_providers["awsTemporary"]
633-
if not any(AWS_TEMP_CREDS.values()):
634-
self.skipTest("AWS Temp environment credentials are not set")
635-
if "awsTemporaryNoSessionToken" in kms_providers:
636-
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
637-
del kms_providers["awsTemporaryNoSessionToken"]
638-
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
639-
self.skipTest("AWS Temp environment credentials are not set")
640-
if "azure" in kms_providers:
641-
kms_providers["azure"] = AZURE_CREDS
642-
if not any(AZURE_CREDS.values()):
643-
self.skipTest("Azure environment credentials are not set")
644-
if "gcp" in kms_providers:
645-
kms_providers["gcp"] = GCP_CREDS
646-
if not any(AZURE_CREDS.values()):
647-
self.skipTest("GCP environment credentials are not set")
648-
if "kmip" in kms_providers:
649-
kms_providers["kmip"] = KMIP_CREDS
650-
opts["kms_tls_options"] = KMS_TLS_OPTS
651-
if "key_vault_namespace" not in opts:
652-
opts["key_vault_namespace"] = "keyvault.datakeys"
653-
if "extra_options" in opts:
654-
opts.update(camel_to_snake_args(opts.pop("extra_options")))
655-
656-
opts = dict(opts)
657-
return AutoEncryptionOpts(**opts)
658-
659-
def parse_client_options(self, opts):
660-
"""Override clientOptions parsing to support autoEncryptOpts."""
661-
encrypt_opts = opts.pop("autoEncryptOpts", None)
662-
if encrypt_opts:
663-
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
664-
665-
return super().parse_client_options(opts)
666-
667-
def get_object_name(self, op):
668-
"""Default object is collection."""
669-
return op.get("object", "collection")
670-
671-
def maybe_skip_scenario(self, test):
672-
super().maybe_skip_scenario(test)
673-
desc = test["description"].lower()
674-
if (
675-
"timeoutms applied to listcollections to get collection schema" in desc
676-
and sys.platform in ("win32", "darwin")
677-
):
678-
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
679-
if "type=symbol" in desc:
680-
self.skipTest("PyMongo does not support the symbol type")
681-
682-
def setup_scenario(self, scenario_def):
683-
"""Override a test's setup."""
684-
key_vault_data = scenario_def["key_vault_data"]
685-
encrypted_fields = scenario_def["encrypted_fields"]
686-
json_schema = scenario_def["json_schema"]
687-
data = scenario_def["data"]
688-
coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)[
689-
"datakeys"
690-
]
691-
coll.delete_many({})
692-
if key_vault_data:
693-
coll.insert_many(key_vault_data)
694-
695-
db_name = self.get_scenario_db_name(scenario_def)
696-
coll_name = self.get_scenario_coll_name(scenario_def)
697-
db = async_client_context.client.get_database(db_name, codec_options=OPTS)
698-
coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
699-
wc = WriteConcern(w="majority")
700-
kwargs: Dict[str, Any] = {}
701-
if json_schema:
702-
kwargs["validator"] = {"$jsonSchema": json_schema}
703-
kwargs["codec_options"] = OPTS
704-
if not data:
705-
kwargs["write_concern"] = wc
706-
if encrypted_fields:
707-
kwargs["encryptedFields"] = encrypted_fields
708-
db.create_collection(coll_name, **kwargs)
709-
coll = db[coll_name]
710-
if data:
711-
# Load data.
712-
coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
713-
714-
def allowable_errors(self, op):
715-
"""Override expected error classes."""
716-
errors = super().allowable_errors(op)
717-
# An updateOne test expects encryption to error when no $ operator
718-
# appears but pymongo raises a client side ValueError in this case.
719-
if op["name"] == "updateOne":
720-
errors += (ValueError,)
721-
return errors
722-
723-
def create_test(scenario_def, test, name):
724-
@async_client_context.require_test_commands
725-
def run_scenario(self):
726-
self.run_scenario(scenario_def, test)
727-
728-
return run_scenario
729-
730-
test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
731-
test_creator.create_tests()
732-
733-
if _HAVE_PYMONGOCRYPT:
734-
globals().update(
735-
generate_test_classes(
736-
os.path.join(SPEC_PATH, "unified"),
737-
module=__name__,
738-
)
615+
class AsyncTestSpec(AsyncSpecRunner):
616+
@classmethod
617+
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
618+
async def _setup_class(cls):
619+
await super()._setup_class()
620+
621+
def parse_auto_encrypt_opts(self, opts):
622+
"""Parse clientOptions.autoEncryptOpts."""
623+
opts = camel_to_snake_args(opts)
624+
kms_providers = opts["kms_providers"]
625+
if "aws" in kms_providers:
626+
kms_providers["aws"] = AWS_CREDS
627+
if not any(AWS_CREDS.values()):
628+
self.skipTest("AWS environment credentials are not set")
629+
if "awsTemporary" in kms_providers:
630+
kms_providers["aws"] = AWS_TEMP_CREDS
631+
del kms_providers["awsTemporary"]
632+
if not any(AWS_TEMP_CREDS.values()):
633+
self.skipTest("AWS Temp environment credentials are not set")
634+
if "awsTemporaryNoSessionToken" in kms_providers:
635+
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
636+
del kms_providers["awsTemporaryNoSessionToken"]
637+
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
638+
self.skipTest("AWS Temp environment credentials are not set")
639+
if "azure" in kms_providers:
640+
kms_providers["azure"] = AZURE_CREDS
641+
if not any(AZURE_CREDS.values()):
642+
self.skipTest("Azure environment credentials are not set")
643+
if "gcp" in kms_providers:
644+
kms_providers["gcp"] = GCP_CREDS
645+
if not any(AZURE_CREDS.values()):
646+
self.skipTest("GCP environment credentials are not set")
647+
if "kmip" in kms_providers:
648+
kms_providers["kmip"] = KMIP_CREDS
649+
opts["kms_tls_options"] = KMS_TLS_OPTS
650+
if "key_vault_namespace" not in opts:
651+
opts["key_vault_namespace"] = "keyvault.datakeys"
652+
if "extra_options" in opts:
653+
opts.update(camel_to_snake_args(opts.pop("extra_options")))
654+
655+
opts = dict(opts)
656+
return AutoEncryptionOpts(**opts)
657+
658+
def parse_client_options(self, opts):
659+
"""Override clientOptions parsing to support autoEncryptOpts."""
660+
encrypt_opts = opts.pop("autoEncryptOpts", None)
661+
if encrypt_opts:
662+
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
663+
664+
return super().parse_client_options(opts)
665+
666+
def get_object_name(self, op):
667+
"""Default object is collection."""
668+
return op.get("object", "collection")
669+
670+
def maybe_skip_scenario(self, test):
671+
super().maybe_skip_scenario(test)
672+
desc = test["description"].lower()
673+
if (
674+
"timeoutms applied to listcollections to get collection schema" in desc
675+
and sys.platform in ("win32", "darwin")
676+
):
677+
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
678+
if "type=symbol" in desc:
679+
self.skipTest("PyMongo does not support the symbol type")
680+
681+
async def setup_scenario(self, scenario_def):
682+
"""Override a test's setup."""
683+
key_vault_data = scenario_def["key_vault_data"]
684+
encrypted_fields = scenario_def["encrypted_fields"]
685+
json_schema = scenario_def["json_schema"]
686+
data = scenario_def["data"]
687+
coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
688+
await coll.delete_many({})
689+
if key_vault_data:
690+
await coll.insert_many(key_vault_data)
691+
692+
db_name = self.get_scenario_db_name(scenario_def)
693+
coll_name = self.get_scenario_coll_name(scenario_def)
694+
db = async_client_context.client.get_database(db_name, codec_options=OPTS)
695+
coll = await db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
696+
wc = WriteConcern(w="majority")
697+
kwargs: Dict[str, Any] = {}
698+
if json_schema:
699+
kwargs["validator"] = {"$jsonSchema": json_schema}
700+
kwargs["codec_options"] = OPTS
701+
if not data:
702+
kwargs["write_concern"] = wc
703+
if encrypted_fields:
704+
kwargs["encryptedFields"] = encrypted_fields
705+
await db.create_collection(coll_name, **kwargs)
706+
coll = db[coll_name]
707+
if data:
708+
# Load data.
709+
await coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
710+
711+
def allowable_errors(self, op):
712+
"""Override expected error classes."""
713+
errors = super().allowable_errors(op)
714+
# An updateOne test expects encryption to error when no $ operator
715+
# appears but pymongo raises a client side ValueError in this case.
716+
if op["name"] == "updateOne":
717+
errors += (ValueError,)
718+
return errors
719+
720+
721+
def create_test(scenario_def, test, name):
722+
@async_client_context.require_test_commands
723+
def run_scenario(self):
724+
self.run_scenario(scenario_def, test)
725+
726+
return run_scenario
727+
728+
729+
test_creator = SpecTestCreator(create_test, AsyncTestSpec, os.path.join(SPEC_PATH, "legacy"))
730+
test_creator.create_tests()
731+
732+
if _HAVE_PYMONGOCRYPT:
733+
globals().update(
734+
generate_test_classes(
735+
os.path.join(SPEC_PATH, "unified"),
736+
module=__name__,
739737
)
738+
)
740739

741740
# Prose Tests
742741
ALL_KMS_PROVIDERS = {

0 commit comments

Comments
 (0)