From 41583268ca3d49024f66c7262a99243771ef7f17 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 6 Sep 2024 10:32:48 -0400 Subject: [PATCH 01/16] PYTHON-4700 - Convert CSFLE tests to async --- test/asynchronous/test_encryption.py | 249 +++++++++++++-------------- test/test_encryption.py | 247 +++++++++++++------------- tools/synchro.py | 1 + 3 files changed, 249 insertions(+), 248 deletions(-) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index eb431e1d50..8409720fb0 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -29,6 +29,7 @@ import uuid import warnings from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context +from test.asynchronous.utils_spec_runner import AsyncSpecRunner from threading import Thread from typing import Any, Dict, Mapping @@ -611,132 +612,130 @@ async def test_with_statement(self): KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}} -if _IS_SYNC: - # TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700) - class TestSpec(SpecRunner): - @classmethod - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def setUpClass(cls): - super().setUpClass() - - def parse_auto_encrypt_opts(self, opts): - """Parse clientOptions.autoEncryptOpts.""" - opts = camel_to_snake_args(opts) - kms_providers = opts["kms_providers"] - if "aws" in kms_providers: - kms_providers["aws"] = AWS_CREDS - if not any(AWS_CREDS.values()): - self.skipTest("AWS environment credentials are not set") - if "awsTemporary" in kms_providers: - kms_providers["aws"] = AWS_TEMP_CREDS - del kms_providers["awsTemporary"] - if not any(AWS_TEMP_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "awsTemporaryNoSessionToken" in kms_providers: - kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS - del kms_providers["awsTemporaryNoSessionToken"] - if not any(AWS_TEMP_NO_SESSION_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "azure" in kms_providers: - kms_providers["azure"] = AZURE_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("Azure environment credentials are not set") - if "gcp" in kms_providers: - kms_providers["gcp"] = GCP_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("GCP environment credentials are not set") - if "kmip" in kms_providers: - kms_providers["kmip"] = KMIP_CREDS - opts["kms_tls_options"] = KMS_TLS_OPTS - if "key_vault_namespace" not in opts: - opts["key_vault_namespace"] = "keyvault.datakeys" - if "extra_options" in opts: - opts.update(camel_to_snake_args(opts.pop("extra_options"))) - - opts = dict(opts) - return AutoEncryptionOpts(**opts) - - def parse_client_options(self, opts): - """Override clientOptions parsing to support autoEncryptOpts.""" - encrypt_opts = opts.pop("autoEncryptOpts", None) - if encrypt_opts: - opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) - - return super().parse_client_options(opts) - - def get_object_name(self, op): - """Default object is collection.""" - return op.get("object", "collection") - - def maybe_skip_scenario(self, test): - super().maybe_skip_scenario(test) - desc = test["description"].lower() - if ( - "timeoutms applied to listcollections to get collection schema" in desc - and sys.platform in ("win32", "darwin") - ): - self.skipTest("PYTHON-3706 flaky test on Windows/macOS") - if "type=symbol" in desc: - self.skipTest("PyMongo does not support the symbol type") - - def setup_scenario(self, scenario_def): - """Override a test's setup.""" - key_vault_data = scenario_def["key_vault_data"] - encrypted_fields = scenario_def["encrypted_fields"] - json_schema = scenario_def["json_schema"] - data = scenario_def["data"] - coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)[ - "datakeys" - ] - coll.delete_many({}) - if key_vault_data: - coll.insert_many(key_vault_data) - - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - db = async_client_context.client.get_database(db_name, codec_options=OPTS) - coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields) - wc = WriteConcern(w="majority") - kwargs: Dict[str, Any] = {} - if json_schema: - kwargs["validator"] = {"$jsonSchema": json_schema} - kwargs["codec_options"] = OPTS - if not data: - kwargs["write_concern"] = wc - if encrypted_fields: - kwargs["encryptedFields"] = encrypted_fields - db.create_collection(coll_name, **kwargs) - coll = db[coll_name] - if data: - # Load data. - coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) - - def allowable_errors(self, op): - """Override expected error classes.""" - errors = super().allowable_errors(op) - # An updateOne test expects encryption to error when no $ operator - # appears but pymongo raises a client side ValueError in this case. - if op["name"] == "updateOne": - errors += (ValueError,) - return errors - - def create_test(scenario_def, test, name): - @async_client_context.require_test_commands - def run_scenario(self): - self.run_scenario(scenario_def, test) - - return run_scenario - - test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy")) - test_creator.create_tests() - - if _HAVE_PYMONGOCRYPT: - globals().update( - generate_test_classes( - os.path.join(SPEC_PATH, "unified"), - module=__name__, - ) +class AsyncTestSpec(AsyncSpecRunner): + @classmethod + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") + async def _setup_class(cls): + await super()._setup_class() + + def parse_auto_encrypt_opts(self, opts): + """Parse clientOptions.autoEncryptOpts.""" + opts = camel_to_snake_args(opts) + kms_providers = opts["kms_providers"] + if "aws" in kms_providers: + kms_providers["aws"] = AWS_CREDS + if not any(AWS_CREDS.values()): + self.skipTest("AWS environment credentials are not set") + if "awsTemporary" in kms_providers: + kms_providers["aws"] = AWS_TEMP_CREDS + del kms_providers["awsTemporary"] + if not any(AWS_TEMP_CREDS.values()): + self.skipTest("AWS Temp environment credentials are not set") + if "awsTemporaryNoSessionToken" in kms_providers: + kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS + del kms_providers["awsTemporaryNoSessionToken"] + if not any(AWS_TEMP_NO_SESSION_CREDS.values()): + self.skipTest("AWS Temp environment credentials are not set") + if "azure" in kms_providers: + kms_providers["azure"] = AZURE_CREDS + if not any(AZURE_CREDS.values()): + self.skipTest("Azure environment credentials are not set") + if "gcp" in kms_providers: + kms_providers["gcp"] = GCP_CREDS + if not any(AZURE_CREDS.values()): + self.skipTest("GCP environment credentials are not set") + if "kmip" in kms_providers: + kms_providers["kmip"] = KMIP_CREDS + opts["kms_tls_options"] = KMS_TLS_OPTS + if "key_vault_namespace" not in opts: + opts["key_vault_namespace"] = "keyvault.datakeys" + if "extra_options" in opts: + opts.update(camel_to_snake_args(opts.pop("extra_options"))) + + opts = dict(opts) + return AutoEncryptionOpts(**opts) + + def parse_client_options(self, opts): + """Override clientOptions parsing to support autoEncryptOpts.""" + encrypt_opts = opts.pop("autoEncryptOpts", None) + if encrypt_opts: + opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) + + return super().parse_client_options(opts) + + def get_object_name(self, op): + """Default object is collection.""" + return op.get("object", "collection") + + def maybe_skip_scenario(self, test): + super().maybe_skip_scenario(test) + desc = test["description"].lower() + if ( + "timeoutms applied to listcollections to get collection schema" in desc + and sys.platform in ("win32", "darwin") + ): + self.skipTest("PYTHON-3706 flaky test on Windows/macOS") + if "type=symbol" in desc: + self.skipTest("PyMongo does not support the symbol type") + + async def setup_scenario(self, scenario_def): + """Override a test's setup.""" + key_vault_data = scenario_def["key_vault_data"] + encrypted_fields = scenario_def["encrypted_fields"] + json_schema = scenario_def["json_schema"] + data = scenario_def["data"] + coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"] + await coll.delete_many({}) + if key_vault_data: + await coll.insert_many(key_vault_data) + + db_name = self.get_scenario_db_name(scenario_def) + coll_name = self.get_scenario_coll_name(scenario_def) + db = async_client_context.client.get_database(db_name, codec_options=OPTS) + coll = await db.drop_collection(coll_name, encrypted_fields=encrypted_fields) + wc = WriteConcern(w="majority") + kwargs: Dict[str, Any] = {} + if json_schema: + kwargs["validator"] = {"$jsonSchema": json_schema} + kwargs["codec_options"] = OPTS + if not data: + kwargs["write_concern"] = wc + if encrypted_fields: + kwargs["encryptedFields"] = encrypted_fields + await db.create_collection(coll_name, **kwargs) + coll = db[coll_name] + if data: + # Load data. + await coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) + + def allowable_errors(self, op): + """Override expected error classes.""" + errors = super().allowable_errors(op) + # An updateOne test expects encryption to error when no $ operator + # appears but pymongo raises a client side ValueError in this case. + if op["name"] == "updateOne": + errors += (ValueError,) + return errors + + +def create_test(scenario_def, test, name): + @async_client_context.require_test_commands + def run_scenario(self): + self.run_scenario(scenario_def, test) + + return run_scenario + + +test_creator = SpecTestCreator(create_test, AsyncTestSpec, os.path.join(SPEC_PATH, "legacy")) +test_creator.create_tests() + +if _HAVE_PYMONGOCRYPT: + globals().update( + generate_test_classes( + os.path.join(SPEC_PATH, "unified"), + module=__name__, ) + ) # Prose Tests ALL_KMS_PROVIDERS = { diff --git a/test/test_encryption.py b/test/test_encryption.py index 568ebffc9e..d9c402f6f3 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -29,6 +29,7 @@ import uuid import warnings from test import IntegrationTest, PyMongoTestCase, client_context +from test.utils_spec_runner import SpecRunner from threading import Thread from typing import Any, Dict, Mapping @@ -609,130 +610,130 @@ def test_with_statement(self): KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}} -if _IS_SYNC: - # TODO: Add synchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700) - class TestSpec(SpecRunner): - @classmethod - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def setUpClass(cls): - super().setUpClass() - - def parse_auto_encrypt_opts(self, opts): - """Parse clientOptions.autoEncryptOpts.""" - opts = camel_to_snake_args(opts) - kms_providers = opts["kms_providers"] - if "aws" in kms_providers: - kms_providers["aws"] = AWS_CREDS - if not any(AWS_CREDS.values()): - self.skipTest("AWS environment credentials are not set") - if "awsTemporary" in kms_providers: - kms_providers["aws"] = AWS_TEMP_CREDS - del kms_providers["awsTemporary"] - if not any(AWS_TEMP_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "awsTemporaryNoSessionToken" in kms_providers: - kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS - del kms_providers["awsTemporaryNoSessionToken"] - if not any(AWS_TEMP_NO_SESSION_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "azure" in kms_providers: - kms_providers["azure"] = AZURE_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("Azure environment credentials are not set") - if "gcp" in kms_providers: - kms_providers["gcp"] = GCP_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("GCP environment credentials are not set") - if "kmip" in kms_providers: - kms_providers["kmip"] = KMIP_CREDS - opts["kms_tls_options"] = KMS_TLS_OPTS - if "key_vault_namespace" not in opts: - opts["key_vault_namespace"] = "keyvault.datakeys" - if "extra_options" in opts: - opts.update(camel_to_snake_args(opts.pop("extra_options"))) - - opts = dict(opts) - return AutoEncryptionOpts(**opts) - - def parse_client_options(self, opts): - """Override clientOptions parsing to support autoEncryptOpts.""" - encrypt_opts = opts.pop("autoEncryptOpts", None) - if encrypt_opts: - opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) - - return super().parse_client_options(opts) - - def get_object_name(self, op): - """Default object is collection.""" - return op.get("object", "collection") - - def maybe_skip_scenario(self, test): - super().maybe_skip_scenario(test) - desc = test["description"].lower() - if ( - "timeoutms applied to listcollections to get collection schema" in desc - and sys.platform in ("win32", "darwin") - ): - self.skipTest("PYTHON-3706 flaky test on Windows/macOS") - if "type=symbol" in desc: - self.skipTest("PyMongo does not support the symbol type") - - def setup_scenario(self, scenario_def): - """Override a test's setup.""" - key_vault_data = scenario_def["key_vault_data"] - encrypted_fields = scenario_def["encrypted_fields"] - json_schema = scenario_def["json_schema"] - data = scenario_def["data"] - coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"] - coll.delete_many({}) - if key_vault_data: - coll.insert_many(key_vault_data) - - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - db = client_context.client.get_database(db_name, codec_options=OPTS) - coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields) - wc = WriteConcern(w="majority") - kwargs: Dict[str, Any] = {} - if json_schema: - kwargs["validator"] = {"$jsonSchema": json_schema} - kwargs["codec_options"] = OPTS - if not data: - kwargs["write_concern"] = wc - if encrypted_fields: - kwargs["encryptedFields"] = encrypted_fields - db.create_collection(coll_name, **kwargs) - coll = db[coll_name] - if data: - # Load data. - coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) - - def allowable_errors(self, op): - """Override expected error classes.""" - errors = super().allowable_errors(op) - # An updateOne test expects encryption to error when no $ operator - # appears but pymongo raises a client side ValueError in this case. - if op["name"] == "updateOne": - errors += (ValueError,) - return errors - - def create_test(scenario_def, test, name): - @client_context.require_test_commands - def run_scenario(self): - self.run_scenario(scenario_def, test) - - return run_scenario - - test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy")) - test_creator.create_tests() - - if _HAVE_PYMONGOCRYPT: - globals().update( - generate_test_classes( - os.path.join(SPEC_PATH, "unified"), - module=__name__, - ) +class TestSpec(SpecRunner): + @classmethod + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") + def _setup_class(cls): + super()._setup_class() + + def parse_auto_encrypt_opts(self, opts): + """Parse clientOptions.autoEncryptOpts.""" + opts = camel_to_snake_args(opts) + kms_providers = opts["kms_providers"] + if "aws" in kms_providers: + kms_providers["aws"] = AWS_CREDS + if not any(AWS_CREDS.values()): + self.skipTest("AWS environment credentials are not set") + if "awsTemporary" in kms_providers: + kms_providers["aws"] = AWS_TEMP_CREDS + del kms_providers["awsTemporary"] + if not any(AWS_TEMP_CREDS.values()): + self.skipTest("AWS Temp environment credentials are not set") + if "awsTemporaryNoSessionToken" in kms_providers: + kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS + del kms_providers["awsTemporaryNoSessionToken"] + if not any(AWS_TEMP_NO_SESSION_CREDS.values()): + self.skipTest("AWS Temp environment credentials are not set") + if "azure" in kms_providers: + kms_providers["azure"] = AZURE_CREDS + if not any(AZURE_CREDS.values()): + self.skipTest("Azure environment credentials are not set") + if "gcp" in kms_providers: + kms_providers["gcp"] = GCP_CREDS + if not any(AZURE_CREDS.values()): + self.skipTest("GCP environment credentials are not set") + if "kmip" in kms_providers: + kms_providers["kmip"] = KMIP_CREDS + opts["kms_tls_options"] = KMS_TLS_OPTS + if "key_vault_namespace" not in opts: + opts["key_vault_namespace"] = "keyvault.datakeys" + if "extra_options" in opts: + opts.update(camel_to_snake_args(opts.pop("extra_options"))) + + opts = dict(opts) + return AutoEncryptionOpts(**opts) + + def parse_client_options(self, opts): + """Override clientOptions parsing to support autoEncryptOpts.""" + encrypt_opts = opts.pop("autoEncryptOpts", None) + if encrypt_opts: + opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) + + return super().parse_client_options(opts) + + def get_object_name(self, op): + """Default object is collection.""" + return op.get("object", "collection") + + def maybe_skip_scenario(self, test): + super().maybe_skip_scenario(test) + desc = test["description"].lower() + if ( + "timeoutms applied to listcollections to get collection schema" in desc + and sys.platform in ("win32", "darwin") + ): + self.skipTest("PYTHON-3706 flaky test on Windows/macOS") + if "type=symbol" in desc: + self.skipTest("PyMongo does not support the symbol type") + + def setup_scenario(self, scenario_def): + """Override a test's setup.""" + key_vault_data = scenario_def["key_vault_data"] + encrypted_fields = scenario_def["encrypted_fields"] + json_schema = scenario_def["json_schema"] + data = scenario_def["data"] + coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"] + coll.delete_many({}) + if key_vault_data: + coll.insert_many(key_vault_data) + + db_name = self.get_scenario_db_name(scenario_def) + coll_name = self.get_scenario_coll_name(scenario_def) + db = client_context.client.get_database(db_name, codec_options=OPTS) + coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields) + wc = WriteConcern(w="majority") + kwargs: Dict[str, Any] = {} + if json_schema: + kwargs["validator"] = {"$jsonSchema": json_schema} + kwargs["codec_options"] = OPTS + if not data: + kwargs["write_concern"] = wc + if encrypted_fields: + kwargs["encryptedFields"] = encrypted_fields + db.create_collection(coll_name, **kwargs) + coll = db[coll_name] + if data: + # Load data. + coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) + + def allowable_errors(self, op): + """Override expected error classes.""" + errors = super().allowable_errors(op) + # An updateOne test expects encryption to error when no $ operator + # appears but pymongo raises a client side ValueError in this case. + if op["name"] == "updateOne": + errors += (ValueError,) + return errors + + +def create_test(scenario_def, test, name): + @client_context.require_test_commands + def run_scenario(self): + self.run_scenario(scenario_def, test) + + return run_scenario + + +test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy")) +test_creator.create_tests() + +if _HAVE_PYMONGOCRYPT: + globals().update( + generate_test_classes( + os.path.join(SPEC_PATH, "unified"), + module=__name__, ) + ) # Prose Tests ALL_KMS_PROVIDERS = { diff --git a/tools/synchro.py b/tools/synchro.py index f38a83f128..2a06daea6c 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -103,6 +103,7 @@ "PyMongo|async": "PyMongo", "AsyncTestGridFile": "TestGridFile", "AsyncTestGridFileNoConnect": "TestGridFileNoConnect", + "AsyncTestSpec": "TestSpec", } docstring_replacements: dict[tuple[str, str], str] = { From acc8bdc681a36a94f7649a031d5872f11576cbc1 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 9 Sep 2024 10:52:59 -0400 Subject: [PATCH 02/16] Add async TestSpecCreator --- test/asynchronous/test_encryption.py | 12 +- test/asynchronous/utils_spec_runner.py | 172 ++++++++++++++++++++++++- test/test_connection_monitoring.py | 3 +- test/test_encryption.py | 4 +- test/test_retryable_reads.py | 3 +- test/test_retryable_writes.py | 3 +- test/utils.py | 147 --------------------- test/utils_spec_runner.py | 170 +++++++++++++++++++++++- tools/synchro.py | 1 + 9 files changed, 345 insertions(+), 170 deletions(-) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 8409720fb0..6a3d600fcb 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -29,7 +29,7 @@ import uuid import warnings from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context -from test.asynchronous.utils_spec_runner import AsyncSpecRunner +from test.asynchronous.utils_spec_runner import AsyncSpecRunner, AsyncSpecTestCreator from threading import Thread from typing import Any, Dict, Mapping @@ -58,14 +58,12 @@ from test.utils import ( AllowListEventListener, OvertCommandListener, - SpecTestCreator, TopologyEventListener, async_rs_or_single_client, async_wait_until, camel_to_snake_args, is_greenthread_patched, ) -from test.utils_spec_runner import SpecRunner from bson import DatetimeMS, Decimal128, encode, json_util from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation @@ -718,15 +716,15 @@ def allowable_errors(self, op): return errors -def create_test(scenario_def, test, name): +async def create_test(scenario_def, test, name): @async_client_context.require_test_commands - def run_scenario(self): - self.run_scenario(scenario_def, test) + async def run_scenario(self): + await self.run_scenario(scenario_def, test) return run_scenario -test_creator = SpecTestCreator(create_test, AsyncTestSpec, os.path.join(SPEC_PATH, "legacy")) +test_creator = AsyncSpecTestCreator(create_test, AsyncTestSpec, os.path.join(SPEC_PATH, "legacy")) test_creator.create_tests() if _HAVE_PYMONGOCRYPT: diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 71044d1530..a44f4b4a11 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -15,8 +15,12 @@ """Utilities for testing driver specs.""" from __future__ import annotations +import asyncio import functools +import os import threading +import unittest +from asyncio import iscoroutinefunction from collections import abc from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs from test.utils import ( @@ -24,6 +28,7 @@ CompareType, EventListener, OvertCommandListener, + ScenarioDict, ServerAndTopologyEventListener, async_rs_client, camel_to_snake, @@ -33,11 +38,12 @@ ) from typing import List -from bson import ObjectId, decode, encode +from bson import ObjectId, decode, encode, json_util from bson.binary import Binary from bson.int64 import Int64 from bson.son import SON from gridfs import GridFSBucket +from gridfs.asynchronous.grid_file import AsyncGridFSBucket from pymongo.asynchronous import client_session from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.cursor import AsyncCursor @@ -84,6 +90,161 @@ def run(self): self.stop() +class AsyncSpecTestCreator: + """Class to create test cases from specifications.""" + + def __init__(self, create_test, test_class, test_path): + """Create a TestCreator object. + + :Parameters: + - `create_test`: callback that returns a test case. The callback + must accept the following arguments - a dictionary containing the + entire test specification (the `scenario_def`), a dictionary + containing the specification for which the test case will be + generated (the `test_def`). + - `test_class`: the unittest.TestCase class in which to create the + test case. + - `test_path`: path to the directory containing the JSON files with + the test specifications. + """ + self._create_test = create_test + self._test_class = test_class + self.test_path = test_path + + def _ensure_min_max_server_version(self, scenario_def, method): + """Test modifier that enforces a version range for the server on a + test case. + """ + if "minServerVersion" in scenario_def: + min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) + if min_ver is not None: + method = async_client_context.require_version_min(*min_ver)(method) + + if "maxServerVersion" in scenario_def: + max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) + if max_ver is not None: + method = async_client_context.require_version_max(*max_ver)(method) + + if "serverless" in scenario_def: + serverless = scenario_def["serverless"] + if serverless == "require": + serverless_satisfied = async_client_context.serverless + elif serverless == "forbid": + serverless_satisfied = not async_client_context.serverless + else: # unset or "allow" + serverless_satisfied = True + method = unittest.skipUnless( + serverless_satisfied, "Serverless requirement not satisfied" + )(method) + + return method + + @staticmethod + async def valid_topology(run_on_req): + return await async_client_context.is_topology_type( + run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) + ) + + @staticmethod + def min_server_version(run_on_req): + version = run_on_req.get("minServerVersion") + if version: + min_ver = tuple(int(elt) for elt in version.split(".")) + return async_client_context.version >= min_ver + return True + + @staticmethod + def max_server_version(run_on_req): + version = run_on_req.get("maxServerVersion") + if version: + max_ver = tuple(int(elt) for elt in version.split(".")) + return async_client_context.version <= max_ver + return True + + @staticmethod + def valid_auth_enabled(run_on_req): + if "authEnabled" in run_on_req: + if run_on_req["authEnabled"]: + return async_client_context.auth_enabled + return not async_client_context.auth_enabled + return True + + @staticmethod + def serverless_ok(run_on_req): + serverless = run_on_req["serverless"] + if serverless == "require": + return async_client_context.serverless + elif serverless == "forbid": + return not async_client_context.serverless + else: # unset or "allow" + return True + + async def should_run_on(self, scenario_def): + run_on = scenario_def.get("runOn", []) + if not run_on: + # Always run these tests. + return True + + for req in run_on: + if ( + await self.valid_topology(req) + and self.min_server_version(req) + and self.max_server_version(req) + and self.valid_auth_enabled(req) + and self.serverless_ok(req) + ): + return True + return False + + def ensure_run_on(self, scenario_def, method): + """Test modifier that enforces a 'runOn' on a test case.""" + + async def predicate(): + return await self.should_run_on(scenario_def) + + return async_client_context._require(predicate, "runOn not satisfied", method) + + def tests(self, scenario_def): + """Allow CMAP spec test to override the location of test.""" + return scenario_def["tests"] + + async def _create_tests(self): + for dirpath, _, filenames in os.walk(self.test_path): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100 + # Use tz_aware=False to match how CodecOptions decodes + # dates. + opts = json_util.JSONOptions(tz_aware=False) + scenario_def = ScenarioDict( + json_util.loads(scenario_stream.read(), json_options=opts) + ) + + test_type = os.path.splitext(filename)[0] + + # Construct test from scenario. + for test_def in self.tests(scenario_def): + test_name = "test_{}_{}_{}".format( + dirname, + test_type.replace("-", "_").replace(".", "_"), + str(test_def["description"].replace(" ", "_").replace(".", "_")), + ) + + new_test = await self._create_test(scenario_def, test_def, test_name) + new_test = self._ensure_min_max_server_version(scenario_def, new_test) + new_test = self.ensure_run_on(scenario_def, new_test) + + new_test.__name__ = test_name + setattr(self._test_class, new_test.__name__, new_test) + + def create_tests(self): + if _IS_SYNC: + self._create_tests() + else: + asyncio.run(self._create_tests()) + + class AsyncSpecRunner(AsyncIntegrationTest): mongos_clients: List knobs: client_knobs @@ -283,7 +444,7 @@ async def run_operation(self, sessions, collection, operation): if object_name == "gridfsbucket": # Only create the GridFSBucket when we need it (for the gridfs # retryable reads tests). - obj = GridFSBucket(database, bucket_name=collection.name) + obj = AsyncGridFSBucket(database, bucket_name=collection.name) else: objects = { "client": database.client, @@ -311,7 +472,10 @@ async def run_operation(self, sessions, collection, operation): args.update(arguments) arguments = args - result = cmd(**dict(arguments)) + if not _IS_SYNC and iscoroutinefunction(cmd): + result = await cmd(**dict(arguments)) + else: + result = cmd(**dict(arguments)) # Cleanup open change stream cursors. if name == "watch": self.addAsyncCleanup(result.close) @@ -587,7 +751,7 @@ async def run_scenario(self, scenario_def, test): read_preference=ReadPreference.PRIMARY, read_concern=ReadConcern("local"), ) - actual_data = await (await outcome_coll.find(sort=[("_id", 1)])).to_list() + actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list() # The expected data needs to be the left hand side here otherwise # CompareType(Binary) doesn't work. diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 9ee3202e13..3f3c29c011 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -25,7 +25,6 @@ from test.pymongo_mocks import DummyMonitor from test.utils import ( CMAPListener, - SpecTestCreator, camel_to_snake, client_context, get_pool, @@ -35,7 +34,7 @@ single_client_noauth, wait_until, ) -from test.utils_spec_runner import SpecRunnerThread +from test.utils_spec_runner import SpecRunnerThread, SpecTestCreator from bson.objectid import ObjectId from bson.son import SON diff --git a/test/test_encryption.py b/test/test_encryption.py index d9c402f6f3..18d8639419 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -29,7 +29,7 @@ import uuid import warnings from test import IntegrationTest, PyMongoTestCase, client_context -from test.utils_spec_runner import SpecRunner +from test.utils_spec_runner import SpecRunner, SpecTestCreator from threading import Thread from typing import Any, Dict, Mapping @@ -58,14 +58,12 @@ from test.utils import ( AllowListEventListener, OvertCommandListener, - SpecTestCreator, TopologyEventListener, camel_to_snake_args, is_greenthread_patched, rs_or_single_client, wait_until, ) -from test.utils_spec_runner import SpecRunner from bson import DatetimeMS, Decimal128, encode, json_util from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index 9ea546ba9b..78d8670846 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -36,12 +36,11 @@ CMAPListener, EventListener, OvertCommandListener, - SpecTestCreator, rs_client, rs_or_single_client, set_fail_point, ) -from test.utils_spec_runner import SpecRunner +from test.utils_spec_runner import SpecRunner, SpecTestCreator from pymongo.monitoring import ( ConnectionCheckedOutEvent, diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 45a740e844..dc06b4c471 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -29,11 +29,10 @@ DeprecationFilter, EventListener, OvertCommandListener, - SpecTestCreator, rs_or_single_client, set_fail_point, ) -from test.utils_spec_runner import SpecRunner +from test.utils_spec_runner import SpecRunner, SpecTestCreator from test.version import Version from bson.codec_options import DEFAULT_CODEC_OPTIONS diff --git a/test/utils.py b/test/utils.py index fa198b1c64..21defb879d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -418,153 +418,6 @@ def call_count(self): return len(self._call_list) -class SpecTestCreator: - """Class to create test cases from specifications.""" - - def __init__(self, create_test, test_class, test_path): - """Create a TestCreator object. - - :Parameters: - - `create_test`: callback that returns a test case. The callback - must accept the following arguments - a dictionary containing the - entire test specification (the `scenario_def`), a dictionary - containing the specification for which the test case will be - generated (the `test_def`). - - `test_class`: the unittest.TestCase class in which to create the - test case. - - `test_path`: path to the directory containing the JSON files with - the test specifications. - """ - self._create_test = create_test - self._test_class = test_class - self.test_path = test_path - - def _ensure_min_max_server_version(self, scenario_def, method): - """Test modifier that enforces a version range for the server on a - test case. - """ - if "minServerVersion" in scenario_def: - min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) - if min_ver is not None: - method = client_context.require_version_min(*min_ver)(method) - - if "maxServerVersion" in scenario_def: - max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) - if max_ver is not None: - method = client_context.require_version_max(*max_ver)(method) - - if "serverless" in scenario_def: - serverless = scenario_def["serverless"] - if serverless == "require": - serverless_satisfied = client_context.serverless - elif serverless == "forbid": - serverless_satisfied = not client_context.serverless - else: # unset or "allow" - serverless_satisfied = True - method = unittest.skipUnless( - serverless_satisfied, "Serverless requirement not satisfied" - )(method) - - return method - - @staticmethod - def valid_topology(run_on_req): - return client_context.is_topology_type( - run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) - ) - - @staticmethod - def min_server_version(run_on_req): - version = run_on_req.get("minServerVersion") - if version: - min_ver = tuple(int(elt) for elt in version.split(".")) - return client_context.version >= min_ver - return True - - @staticmethod - def max_server_version(run_on_req): - version = run_on_req.get("maxServerVersion") - if version: - max_ver = tuple(int(elt) for elt in version.split(".")) - return client_context.version <= max_ver - return True - - @staticmethod - def valid_auth_enabled(run_on_req): - if "authEnabled" in run_on_req: - if run_on_req["authEnabled"]: - return client_context.auth_enabled - return not client_context.auth_enabled - return True - - @staticmethod - def serverless_ok(run_on_req): - serverless = run_on_req["serverless"] - if serverless == "require": - return client_context.serverless - elif serverless == "forbid": - return not client_context.serverless - else: # unset or "allow" - return True - - def should_run_on(self, scenario_def): - run_on = scenario_def.get("runOn", []) - if not run_on: - # Always run these tests. - return True - - for req in run_on: - if ( - self.valid_topology(req) - and self.min_server_version(req) - and self.max_server_version(req) - and self.valid_auth_enabled(req) - and self.serverless_ok(req) - ): - return True - return False - - def ensure_run_on(self, scenario_def, method): - """Test modifier that enforces a 'runOn' on a test case.""" - return client_context._require( - lambda: self.should_run_on(scenario_def), "runOn not satisfied", method - ) - - def tests(self, scenario_def): - """Allow CMAP spec test to override the location of test.""" - return scenario_def["tests"] - - def create_tests(self): - for dirpath, _, filenames in os.walk(self.test_path): - dirname = os.path.split(dirpath)[-1] - - for filename in filenames: - with open(os.path.join(dirpath, filename)) as scenario_stream: - # Use tz_aware=False to match how CodecOptions decodes - # dates. - opts = json_util.JSONOptions(tz_aware=False) - scenario_def = ScenarioDict( - json_util.loads(scenario_stream.read(), json_options=opts) - ) - - test_type = os.path.splitext(filename)[0] - - # Construct test from scenario. - for test_def in self.tests(scenario_def): - test_name = "test_{}_{}_{}".format( - dirname, - test_type.replace("-", "_").replace(".", "_"), - str(test_def["description"].replace(" ", "_").replace(".", "_")), - ) - - new_test = self._create_test(scenario_def, test_def, test_name) - new_test = self._ensure_min_max_server_version(scenario_def, new_test) - new_test = self.ensure_run_on(scenario_def, new_test) - - new_test.__name__ = test_name - setattr(self._test_class, new_test.__name__, new_test) - - def _connection_string(h): if h.startswith(("mongodb://", "mongodb+srv://")): return h diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 0b882a8bc3..1a782a6e7d 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -15,8 +15,12 @@ """Utilities for testing driver specs.""" from __future__ import annotations +import asyncio import functools +import os import threading +import unittest +from asyncio import iscoroutinefunction from collections import abc from test import IntegrationTest, client_context, client_knobs from test.utils import ( @@ -24,6 +28,7 @@ CompareType, EventListener, OvertCommandListener, + ScenarioDict, ServerAndTopologyEventListener, camel_to_snake, camel_to_snake_args, @@ -33,11 +38,12 @@ ) from typing import List -from bson import ObjectId, decode, encode +from bson import ObjectId, decode, encode, json_util from bson.binary import Binary from bson.int64 import Int64 from bson.son import SON from gridfs import GridFSBucket +from gridfs.synchronous.grid_file import GridFSBucket from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference @@ -84,6 +90,161 @@ def run(self): self.stop() +class SpecTestCreator: + """Class to create test cases from specifications.""" + + def __init__(self, create_test, test_class, test_path): + """Create a TestCreator object. + + :Parameters: + - `create_test`: callback that returns a test case. The callback + must accept the following arguments - a dictionary containing the + entire test specification (the `scenario_def`), a dictionary + containing the specification for which the test case will be + generated (the `test_def`). + - `test_class`: the unittest.TestCase class in which to create the + test case. + - `test_path`: path to the directory containing the JSON files with + the test specifications. + """ + self._create_test = create_test + self._test_class = test_class + self.test_path = test_path + + def _ensure_min_max_server_version(self, scenario_def, method): + """Test modifier that enforces a version range for the server on a + test case. + """ + if "minServerVersion" in scenario_def: + min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) + if min_ver is not None: + method = client_context.require_version_min(*min_ver)(method) + + if "maxServerVersion" in scenario_def: + max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) + if max_ver is not None: + method = client_context.require_version_max(*max_ver)(method) + + if "serverless" in scenario_def: + serverless = scenario_def["serverless"] + if serverless == "require": + serverless_satisfied = client_context.serverless + elif serverless == "forbid": + serverless_satisfied = not client_context.serverless + else: # unset or "allow" + serverless_satisfied = True + method = unittest.skipUnless( + serverless_satisfied, "Serverless requirement not satisfied" + )(method) + + return method + + @staticmethod + def valid_topology(run_on_req): + return client_context.is_topology_type( + run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) + ) + + @staticmethod + def min_server_version(run_on_req): + version = run_on_req.get("minServerVersion") + if version: + min_ver = tuple(int(elt) for elt in version.split(".")) + return client_context.version >= min_ver + return True + + @staticmethod + def max_server_version(run_on_req): + version = run_on_req.get("maxServerVersion") + if version: + max_ver = tuple(int(elt) for elt in version.split(".")) + return client_context.version <= max_ver + return True + + @staticmethod + def valid_auth_enabled(run_on_req): + if "authEnabled" in run_on_req: + if run_on_req["authEnabled"]: + return client_context.auth_enabled + return not client_context.auth_enabled + return True + + @staticmethod + def serverless_ok(run_on_req): + serverless = run_on_req["serverless"] + if serverless == "require": + return client_context.serverless + elif serverless == "forbid": + return not client_context.serverless + else: # unset or "allow" + return True + + def should_run_on(self, scenario_def): + run_on = scenario_def.get("runOn", []) + if not run_on: + # Always run these tests. + return True + + for req in run_on: + if ( + self.valid_topology(req) + and self.min_server_version(req) + and self.max_server_version(req) + and self.valid_auth_enabled(req) + and self.serverless_ok(req) + ): + return True + return False + + def ensure_run_on(self, scenario_def, method): + """Test modifier that enforces a 'runOn' on a test case.""" + + def predicate(): + return self.should_run_on(scenario_def) + + return client_context._require(predicate, "runOn not satisfied", method) + + def tests(self, scenario_def): + """Allow CMAP spec test to override the location of test.""" + return scenario_def["tests"] + + def _create_tests(self): + for dirpath, _, filenames in os.walk(self.test_path): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100 + # Use tz_aware=False to match how CodecOptions decodes + # dates. + opts = json_util.JSONOptions(tz_aware=False) + scenario_def = ScenarioDict( + json_util.loads(scenario_stream.read(), json_options=opts) + ) + + test_type = os.path.splitext(filename)[0] + + # Construct test from scenario. + for test_def in self.tests(scenario_def): + test_name = "test_{}_{}_{}".format( + dirname, + test_type.replace("-", "_").replace(".", "_"), + str(test_def["description"].replace(" ", "_").replace(".", "_")), + ) + + new_test = self._create_test(scenario_def, test_def, test_name) + new_test = self._ensure_min_max_server_version(scenario_def, new_test) + new_test = self.ensure_run_on(scenario_def, new_test) + + new_test.__name__ = test_name + setattr(self._test_class, new_test.__name__, new_test) + + def create_tests(self): + if _IS_SYNC: + self._create_tests() + else: + asyncio.run(self._create_tests()) + + class SpecRunner(IntegrationTest): mongos_clients: List knobs: client_knobs @@ -311,7 +472,10 @@ def run_operation(self, sessions, collection, operation): args.update(arguments) arguments = args - result = cmd(**dict(arguments)) + if not _IS_SYNC and iscoroutinefunction(cmd): + result = cmd(**dict(arguments)) + else: + result = cmd(**dict(arguments)) # Cleanup open change stream cursors. if name == "watch": self.addCleanup(result.close) @@ -582,7 +746,7 @@ def run_scenario(self, scenario_def, test): read_preference=ReadPreference.PRIMARY, read_concern=ReadConcern("local"), ) - actual_data = (outcome_coll.find(sort=[("_id", 1)])).to_list() + actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list() # The expected data needs to be the left hand side here otherwise # CompareType(Binary) doesn't work. diff --git a/tools/synchro.py b/tools/synchro.py index 2a06daea6c..6fed099ad7 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -104,6 +104,7 @@ "AsyncTestGridFile": "TestGridFile", "AsyncTestGridFileNoConnect": "TestGridFileNoConnect", "AsyncTestSpec": "TestSpec", + "AsyncSpecTestCreator": "SpecTestCreator", } docstring_replacements: dict[tuple[str, str], str] = { From e6d4b114d2a0fc82b6de02624b4b709b1c1ff0d7 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 9 Sep 2024 11:37:54 -0400 Subject: [PATCH 03/16] Fix imports --- test/test_server_selection_in_window.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 9dced595c9..6ab251f052 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -20,12 +20,12 @@ from test import IntegrationTest, client_context, unittest from test.utils import ( OvertCommandListener, - SpecTestCreator, get_pool, rs_client, wait_until, ) from test.utils_selection_tests import create_topology +from test.utils_spec_runner import SpecTestCreator from pymongo.common import clean_node from pymongo.operations import _Op From 192270273251c1730e7ae2828be24c75605b6c22 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 9 Sep 2024 12:04:39 -0400 Subject: [PATCH 04/16] Linting --- test/asynchronous/test_encryption.py | 2 +- test/test_encryption.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index d8f96c7e4b..c4b9855934 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -29,8 +29,8 @@ import uuid import warnings from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context -from test.asynchronous.utils_spec_runner import AsyncSpecRunner, AsyncSpecTestCreator from test.asynchronous.test_bulk import AsyncBulkTestBase +from test.asynchronous.utils_spec_runner import AsyncSpecRunner, AsyncSpecTestCreator from threading import Thread from typing import Any, Dict, Mapping diff --git a/test/test_encryption.py b/test/test_encryption.py index cf6d724f3d..475993e213 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -29,8 +29,8 @@ import uuid import warnings from test import IntegrationTest, PyMongoTestCase, client_context -from test.utils_spec_runner import SpecRunner, SpecTestCreator from test.test_bulk import BulkTestBase +from test.utils_spec_runner import SpecRunner, SpecTestCreator from threading import Thread from typing import Any, Dict, Mapping From 97f39dd226679b31ebce303f3ad9398dfd131445 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 9 Sep 2024 14:07:23 -0400 Subject: [PATCH 05/16] Predicate --- .evergreen/run-tests.sh | 4 ++-- hatch.toml | 2 +- test/asynchronous/utils_spec_runner.py | 2 +- test/utils_spec_runner.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 66df6b26ca..330d881fe7 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -254,9 +254,9 @@ if [ -z "$GREEN_FRAMEWORK" ]; then # Use --capture=tee-sys so pytest prints test output inline: # https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html if [ -z "$TEST_SUITES" ]; then - python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS + python -m pytest -v --capture=tee-sys --durations=5 $TEST_ARGS else - python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 -m $TEST_SUITES $TEST_ARGS + python -m pytest -v --capture=tee-sys --durations=5 -m $TEST_SUITES $TEST_ARGS fi else python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS diff --git a/hatch.toml b/hatch.toml index 8b1cf93e32..3821dcc822 100644 --- a/hatch.toml +++ b/hatch.toml @@ -39,7 +39,7 @@ run-manual = "pre-commit run --all-files --hook-stage manual" [envs.test] features = ["test"] [envs.test.scripts] -test = "pytest -v --durations=5 --maxfail=10 {args}" +test = "pytest -v --durations=5 {args}" test-eg = "bash ./.evergreen/run-tests.sh {args}" test-async = "pytest -v --durations=5 --maxfail=10 -m default_async {args}" test-mockupdb = ["pip install -U git+https://github.com/ajdavis/mongo-mockup-db@master", "test -m mockupdb"] diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index a44f4b4a11..6de9233747 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -202,7 +202,7 @@ def ensure_run_on(self, scenario_def, method): async def predicate(): return await self.should_run_on(scenario_def) - return async_client_context._require(predicate, "runOn not satisfied", method) + return async_client_context._require(lambda: predicate, "runOn not satisfied", method) def tests(self, scenario_def): """Allow CMAP spec test to override the location of test.""" diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 1a782a6e7d..fc7aa1f5ee 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -202,7 +202,7 @@ def ensure_run_on(self, scenario_def, method): def predicate(): return self.should_run_on(scenario_def) - return client_context._require(predicate, "runOn not satisfied", method) + return client_context._require(lambda: predicate, "runOn not satisfied", method) def tests(self, scenario_def): """Allow CMAP spec test to override the location of test.""" From 7a4df29b4b226fa2b5b7f333bc5ff69790a468f7 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 26 Sep 2024 16:14:55 -0400 Subject: [PATCH 06/16] Add test duration --- .evergreen/run-tests.sh | 4 ++-- hatch.toml | 2 +- test/asynchronous/utils_spec_runner.py | 5 +++++ test/utils_spec_runner.py | 5 +++++ 4 files changed, 13 insertions(+), 3 deletions(-) diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 330d881fe7..66df6b26ca 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -254,9 +254,9 @@ if [ -z "$GREEN_FRAMEWORK" ]; then # Use --capture=tee-sys so pytest prints test output inline: # https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html if [ -z "$TEST_SUITES" ]; then - python -m pytest -v --capture=tee-sys --durations=5 $TEST_ARGS + python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS else - python -m pytest -v --capture=tee-sys --durations=5 -m $TEST_SUITES $TEST_ARGS + python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 -m $TEST_SUITES $TEST_ARGS fi else python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS diff --git a/hatch.toml b/hatch.toml index 3821dcc822..8b1cf93e32 100644 --- a/hatch.toml +++ b/hatch.toml @@ -39,7 +39,7 @@ run-manual = "pre-commit run --all-files --hook-stage manual" [envs.test] features = ["test"] [envs.test.scripts] -test = "pytest -v --durations=5 {args}" +test = "pytest -v --durations=5 --maxfail=10 {args}" test-eg = "bash ./.evergreen/run-tests.sh {args}" test-async = "pytest -v --durations=5 --maxfail=10 -m default_async {args}" test-mockupdb = ["pip install -U git+https://github.com/ajdavis/mongo-mockup-db@master", "test -m mockupdb"] diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 6de9233747..35c7e4a43a 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -181,11 +181,13 @@ def serverless_ok(run_on_req): async def should_run_on(self, scenario_def): run_on = scenario_def.get("runOn", []) + print(f"RUN_ON: {run_on}") if not run_on: # Always run these tests. return True for req in run_on: + print(f"REQ: {req}") if ( await self.valid_topology(req) and self.min_server_version(req) @@ -193,7 +195,10 @@ async def should_run_on(self, scenario_def): and self.valid_auth_enabled(req) and self.serverless_ok(req) ): + print(f"REQ passes: {req}") return True + else: + print(f"REQ fails: {req}") return False def ensure_run_on(self, scenario_def, method): diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index fc7aa1f5ee..71eba9ee8a 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -181,11 +181,13 @@ def serverless_ok(run_on_req): def should_run_on(self, scenario_def): run_on = scenario_def.get("runOn", []) + print(f"RUN_ON: {run_on}") if not run_on: # Always run these tests. return True for req in run_on: + print(f"REQ: {req}") if ( self.valid_topology(req) and self.min_server_version(req) @@ -193,7 +195,10 @@ def should_run_on(self, scenario_def): and self.valid_auth_enabled(req) and self.serverless_ok(req) ): + print(f"REQ passes: {req}") return True + else: + print(f"REQ fails: {req}") return False def ensure_run_on(self, scenario_def, method): From 689b0b825df004dd5508420f7da14c3e517b85da Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 27 Sep 2024 09:42:38 -0400 Subject: [PATCH 07/16] Cleanup --- test/asynchronous/test_encryption.py | 2 +- test/test_encryption.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index a444ec301a..0dcea367c7 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -706,7 +706,7 @@ async def setup_scenario(self, scenario_def): db_name = self.get_scenario_db_name(scenario_def) coll_name = self.get_scenario_coll_name(scenario_def) db = async_client_context.client.get_database(db_name, codec_options=OPTS) - coll = await db.drop_collection(coll_name, encrypted_fields=encrypted_fields) + await db.drop_collection(coll_name, encrypted_fields=encrypted_fields) wc = WriteConcern(w="majority") kwargs: Dict[str, Any] = {} if json_schema: diff --git a/test/test_encryption.py b/test/test_encryption.py index 7851be1547..68469e5d66 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -704,7 +704,7 @@ def setup_scenario(self, scenario_def): db_name = self.get_scenario_db_name(scenario_def) coll_name = self.get_scenario_coll_name(scenario_def) db = client_context.client.get_database(db_name, codec_options=OPTS) - coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields) + db.drop_collection(coll_name, encrypted_fields=encrypted_fields) wc = WriteConcern(w="majority") kwargs: Dict[str, Any] = {} if json_schema: From 6feec443e011d681bab01af4ac53bbbe1e8d522f Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 27 Sep 2024 10:31:54 -0400 Subject: [PATCH 08/16] WIP --- .evergreen/run-tests.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 8d7a9f082a..5e8429dd28 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -257,9 +257,9 @@ if [ -z "$GREEN_FRAMEWORK" ]; then # Use --capture=tee-sys so pytest prints test output inline: # https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html if [ -z "$TEST_SUITES" ]; then - python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS + python -m pytest -v --capture=tee-sys --durations=5 $TEST_ARGS else - python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 -m $TEST_SUITES $TEST_ARGS + python -m pytest -v --capture=tee-sys --durations=5 -m $TEST_SUITES $TEST_ARGS fi else python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS From b472d3a0fc48aef554e7ea905380c17255f0380b Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 3 Oct 2024 16:40:52 -0400 Subject: [PATCH 09/16] Async KMS --- pymongo/asynchronous/encryption.py | 57 ++++++++++++++++++++++-------- pymongo/network_layer.py | 34 ++++++++++++++++++ pymongo/synchronous/encryption.py | 57 ++++++++++++++++++++++-------- 3 files changed, 120 insertions(+), 28 deletions(-) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 9b00c13e10..cb7d6b0ab2 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -142,6 +142,32 @@ def __init__( self.opts = opts self._spawned = False + async def _async_kms_request( + self, kms_context: MongoCryptKmsContext, host, port, opts, message + ) -> None: + from pymongo.network_layer import _async_receive_data_socket + + try: + conn = await _configured_socket((host, port), opts) + try: + await async_sendall(conn, message) + while kms_context.bytes_needed > 0: + # CSOT: update timeout. + conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + data = await _async_receive_data_socket(conn, kms_context.bytes_needed) + if not data: + raise OSError("KMS connection closed") + kms_context.feed(data) + except BLOCKING_IO_ERRORS: + raise socket.timeout("timed out") from None + finally: + conn.close() + except (PyMongoError, MongoCryptError): + raise # Propagate pymongo errors directly. + except Exception as error: + # Wrap I/O errors in PyMongo exceptions. + _raise_connection_failure((host, port), error) + async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: """Complete a KMS request. @@ -174,20 +200,23 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: ) host, port = parse_host(endpoint, _HTTPS_PORT) try: - conn = await _configured_socket((host, port), opts) - try: - await async_sendall(conn, message) - while kms_context.bytes_needed > 0: - # CSOT: update timeout. - conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = conn.recv(kms_context.bytes_needed) - if not data: - raise OSError("KMS connection closed") - kms_context.feed(data) - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - finally: - conn.close() + if _IS_SYNC: + conn = await _configured_socket((host, port), opts) + try: + await async_sendall(conn, message) + while kms_context.bytes_needed > 0: + # CSOT: update timeout. + conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + data = conn.recv(kms_context.bytes_needed) + if not data: + raise OSError("KMS connection closed") + kms_context.feed(data) + except BLOCKING_IO_ERRORS: + raise socket.timeout("timed out") from None + finally: + conn.close() + else: + return await self._async_kms_request(kms_context, host, port, opts, message) except (PyMongoError, MongoCryptError): raise # Propagate pymongo errors directly. except Exception as error: diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 4b57620d83..4e7fd1c4ee 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -275,6 +275,40 @@ async def async_receive_data( sock.settimeout(sock_timeout) +async def _async_receive_data_socket( + sock: socket.socket | _sslConn, length: int, deadline: Optional[float] +) -> memoryview: + sock_timeout = sock.gettimeout() + timeout: Optional[Union[float, int]] + if deadline: + # When the timeout has expired perform one final check to + # see if the socket is readable. This helps avoid spurious + # timeouts on AWS Lambda and other FaaS environments. + timeout = max(deadline - time.monotonic(), 0) + else: + timeout = sock_timeout + + sock.settimeout(0.0) + loop = asyncio.get_event_loop() + try: + if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): + read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] + else: + read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] + tasks = [read_task] + done, pending = await asyncio.wait( + tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED + ) + for task in pending: + task.cancel() + await asyncio.wait(pending) + if read_task in done: + return read_task.result() + raise socket.timeout("timed out") + finally: + sock.settimeout(sock_timeout) + + async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview: mv = memoryview(bytearray(length)) bytes_read = 0 diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index efef6df9e8..77e17876e4 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -142,6 +142,32 @@ def __init__( self.opts = opts self._spawned = False + def _async_kms_request( + self, kms_context: MongoCryptKmsContext, host, port, opts, message + ) -> None: + from pymongo.network_layer import _receive_data_socket + + try: + conn = _configured_socket((host, port), opts) + try: + sendall(conn, message) + while kms_context.bytes_needed > 0: + # CSOT: update timeout. + conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + data = _receive_data_socket(conn, kms_context.bytes_needed) + if not data: + raise OSError("KMS connection closed") + kms_context.feed(data) + except BLOCKING_IO_ERRORS: + raise socket.timeout("timed out") from None + finally: + conn.close() + except (PyMongoError, MongoCryptError): + raise # Propagate pymongo errors directly. + except Exception as error: + # Wrap I/O errors in PyMongo exceptions. + _raise_connection_failure((host, port), error) + def kms_request(self, kms_context: MongoCryptKmsContext) -> None: """Complete a KMS request. @@ -174,20 +200,23 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: ) host, port = parse_host(endpoint, _HTTPS_PORT) try: - conn = _configured_socket((host, port), opts) - try: - sendall(conn, message) - while kms_context.bytes_needed > 0: - # CSOT: update timeout. - conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = conn.recv(kms_context.bytes_needed) - if not data: - raise OSError("KMS connection closed") - kms_context.feed(data) - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - finally: - conn.close() + if _IS_SYNC: + conn = _configured_socket((host, port), opts) + try: + sendall(conn, message) + while kms_context.bytes_needed > 0: + # CSOT: update timeout. + conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + data = conn.recv(kms_context.bytes_needed) + if not data: + raise OSError("KMS connection closed") + kms_context.feed(data) + except BLOCKING_IO_ERRORS: + raise socket.timeout("timed out") from None + finally: + conn.close() + else: + return self._async_kms_request(kms_context, host, port, opts, message) except (PyMongoError, MongoCryptError): raise # Propagate pymongo errors directly. except Exception as error: From 3bdd38113d0a753cad5036ba3fe4312ff2cb7bd6 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 7 Oct 2024 09:46:09 -0400 Subject: [PATCH 10/16] Fix deadline --- pymongo/asynchronous/encryption.py | 2 +- pymongo/synchronous/encryption.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index cb7d6b0ab2..1d9057394d 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -154,7 +154,7 @@ async def _async_kms_request( while kms_context.bytes_needed > 0: # CSOT: update timeout. conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = await _async_receive_data_socket(conn, kms_context.bytes_needed) + data = await _async_receive_data_socket(conn, kms_context.bytes_needed, None) if not data: raise OSError("KMS connection closed") kms_context.feed(data) diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 77e17876e4..7f640c8598 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -154,7 +154,7 @@ def _async_kms_request( while kms_context.bytes_needed > 0: # CSOT: update timeout. conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = _receive_data_socket(conn, kms_context.bytes_needed) + data = _receive_data_socket(conn, kms_context.bytes_needed, None) if not data: raise OSError("KMS connection closed") kms_context.feed(data) From 8e6a30d526ca8537e51eb23f8bcbe446795e7ffb Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 7 Oct 2024 11:27:46 -0400 Subject: [PATCH 11/16] Fix async KMS --- pymongo/asynchronous/encryption.py | 8 +++---- pymongo/network_layer.py | 35 ++++++++++-------------------- pymongo/synchronous/encryption.py | 8 +++---- 3 files changed, 19 insertions(+), 32 deletions(-) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 1d9057394d..37ec3320ee 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -154,10 +154,10 @@ async def _async_kms_request( while kms_context.bytes_needed > 0: # CSOT: update timeout. conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = await _async_receive_data_socket(conn, kms_context.bytes_needed, None) - if not data: - raise OSError("KMS connection closed") + data = await _async_receive_data_socket(conn, kms_context.bytes_needed) kms_context.feed(data) + except OSError as err: + raise OSError("KMS connection closed") from err except BLOCKING_IO_ERRORS: raise socket.timeout("timed out") from None finally: @@ -216,7 +216,7 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: finally: conn.close() else: - return await self._async_kms_request(kms_context, host, port, opts, message) + await self._async_kms_request(kms_context, host, port, opts, message) except (PyMongoError, MongoCryptError): raise # Propagate pymongo errors directly. except Exception as error: diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 4e7fd1c4ee..8d1a6464f2 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -130,7 +130,7 @@ def _is_ready(fut: Future) -> None: loop.remove_writer(fd) async def _async_receive_ssl( - conn: _sslConn, length: int, loop: AbstractEventLoop + conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False ) -> memoryview: mv = memoryview(bytearray(length)) total_read = 0 @@ -145,6 +145,8 @@ def _is_ready(fut: Future) -> None: read = conn.recv_into(mv[total_read:]) if read == 0: raise OSError("connection closed") + if once: + return mv[:read] total_read += read except BLOCKING_IO_ERRORS as exc: fd = conn.fileno() @@ -275,36 +277,21 @@ async def async_receive_data( sock.settimeout(sock_timeout) -async def _async_receive_data_socket( - sock: socket.socket | _sslConn, length: int, deadline: Optional[float] -) -> memoryview: +async def _async_receive_data_socket(sock: socket.socket | _sslConn, length: int) -> memoryview: sock_timeout = sock.gettimeout() - timeout: Optional[Union[float, int]] - if deadline: - # When the timeout has expired perform one final check to - # see if the socket is readable. This helps avoid spurious - # timeouts on AWS Lambda and other FaaS environments. - timeout = max(deadline - time.monotonic(), 0) - else: - timeout = sock_timeout + timeout = sock_timeout sock.settimeout(0.0) loop = asyncio.get_event_loop() try: if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] + return await asyncio.wait_for( + _async_receive_ssl(sock, length, loop, once=True), timeout=timeout + ) # type: ignore[arg-type] else: - read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] - tasks = [read_task] - done, pending = await asyncio.wait( - tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED - ) - for task in pending: - task.cancel() - await asyncio.wait(pending) - if read_task in done: - return read_task.result() - raise socket.timeout("timed out") + return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type] + except asyncio.TimeoutError as err: + raise socket.timeout("timed out") from err finally: sock.settimeout(sock_timeout) diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 7f640c8598..5bc97144da 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -154,10 +154,10 @@ def _async_kms_request( while kms_context.bytes_needed > 0: # CSOT: update timeout. conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = _receive_data_socket(conn, kms_context.bytes_needed, None) - if not data: - raise OSError("KMS connection closed") + data = _receive_data_socket(conn, kms_context.bytes_needed) kms_context.feed(data) + except OSError as err: + raise OSError("KMS connection closed") from err except BLOCKING_IO_ERRORS: raise socket.timeout("timed out") from None finally: @@ -216,7 +216,7 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: finally: conn.close() else: - return self._async_kms_request(kms_context, host, port, opts, message) + self._async_kms_request(kms_context, host, port, opts, message) except (PyMongoError, MongoCryptError): raise # Propagate pymongo errors directly. except Exception as error: From 8ca4b0ec26fc633c110b19c371a99e74e8c9f842 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 7 Oct 2024 13:15:43 -0400 Subject: [PATCH 12/16] Fix _require --- test/__init__.py | 11 ++++++----- test/asynchronous/__init__.py | 11 ++++++----- test/asynchronous/utils_spec_runner.py | 7 +------ test/utils_spec_runner.py | 7 +------ 4 files changed, 14 insertions(+), 22 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index af12bc032a..fd33fde293 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -464,11 +464,12 @@ def wrap(*args, **kwargs): if not self.connected: pair = self.pair raise SkipTest(f"Cannot connect to MongoDB on {pair}") - if iscoroutinefunction(condition) and condition(): - if wraps_async: - return f(*args, **kwargs) - else: - return f(*args, **kwargs) + if iscoroutinefunction(condition): + if condition(): + if wraps_async: + return f(*args, **kwargs) + else: + return f(*args, **kwargs) elif condition(): if wraps_async: return f(*args, **kwargs) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 2a44785b2f..0579828c49 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -466,11 +466,12 @@ async def wrap(*args, **kwargs): if not self.connected: pair = await self.pair raise SkipTest(f"Cannot connect to MongoDB on {pair}") - if iscoroutinefunction(condition) and await condition(): - if wraps_async: - return await f(*args, **kwargs) - else: - return f(*args, **kwargs) + if iscoroutinefunction(condition): + if await condition(): + if wraps_async: + return await f(*args, **kwargs) + else: + return f(*args, **kwargs) elif condition(): if wraps_async: return await f(*args, **kwargs) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 7925c9662b..4d9c4c8f20 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -180,13 +180,11 @@ def serverless_ok(run_on_req): async def should_run_on(self, scenario_def): run_on = scenario_def.get("runOn", []) - print(f"RUN_ON: {run_on}") if not run_on: # Always run these tests. return True for req in run_on: - print(f"REQ: {req}") if ( await self.valid_topology(req) and self.min_server_version(req) @@ -194,10 +192,7 @@ async def should_run_on(self, scenario_def): and self.valid_auth_enabled(req) and self.serverless_ok(req) ): - print(f"REQ passes: {req}") return True - else: - print(f"REQ fails: {req}") return False def ensure_run_on(self, scenario_def, method): @@ -206,7 +201,7 @@ def ensure_run_on(self, scenario_def, method): async def predicate(): return await self.should_run_on(scenario_def) - return async_client_context._require(lambda: predicate, "runOn not satisfied", method) + return async_client_context._require(predicate, "runOn not satisfied", method) def tests(self, scenario_def): """Allow CMAP spec test to override the location of test.""" diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index ef16402b35..8a061de0b1 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -180,13 +180,11 @@ def serverless_ok(run_on_req): def should_run_on(self, scenario_def): run_on = scenario_def.get("runOn", []) - print(f"RUN_ON: {run_on}") if not run_on: # Always run these tests. return True for req in run_on: - print(f"REQ: {req}") if ( self.valid_topology(req) and self.min_server_version(req) @@ -194,10 +192,7 @@ def should_run_on(self, scenario_def): and self.valid_auth_enabled(req) and self.serverless_ok(req) ): - print(f"REQ passes: {req}") return True - else: - print(f"REQ fails: {req}") return False def ensure_run_on(self, scenario_def, method): @@ -206,7 +201,7 @@ def ensure_run_on(self, scenario_def, method): def predicate(): return self.should_run_on(scenario_def) - return client_context._require(lambda: predicate, "runOn not satisfied", method) + return client_context._require(predicate, "runOn not satisfied", method) def tests(self, scenario_def): """Allow CMAP spec test to override the location of test.""" From a229746b9e9365cb633b14df2100fd8f4b7e3981 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 7 Oct 2024 15:59:06 -0400 Subject: [PATCH 13/16] Type error fixes --- pymongo/asynchronous/encryption.py | 11 ++++++++--- pymongo/network_layer.py | 9 ++++++--- pymongo/synchronous/encryption.py | 11 ++++++++--- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 37ec3320ee..79510c4bd2 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -143,9 +143,14 @@ def __init__( self._spawned = False async def _async_kms_request( - self, kms_context: MongoCryptKmsContext, host, port, opts, message + self, + kms_context: MongoCryptKmsContext, + host: str, + port: Optional[int], + opts: PoolOptions, + message: bytes, ) -> None: - from pymongo.network_layer import _async_receive_data_socket + from pymongo.network_layer import async_receive_data_socket # type: ignore[attr-defined] try: conn = await _configured_socket((host, port), opts) @@ -154,7 +159,7 @@ async def _async_kms_request( while kms_context.bytes_needed > 0: # CSOT: update timeout. conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = await _async_receive_data_socket(conn, kms_context.bytes_needed) + data = await async_receive_data_socket(conn, kms_context.bytes_needed) kms_context.feed(data) except OSError as err: raise OSError("KMS connection closed") from err diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 8d1a6464f2..e7aa077f44 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -277,7 +277,9 @@ async def async_receive_data( sock.settimeout(sock_timeout) -async def _async_receive_data_socket(sock: socket.socket | _sslConn, length: int) -> memoryview: +async def async_receive_data_socket( + sock: Union[socket.socket, _sslConn], length: int +) -> memoryview: sock_timeout = sock.gettimeout() timeout = sock_timeout @@ -286,8 +288,9 @@ async def _async_receive_data_socket(sock: socket.socket | _sslConn, length: int try: if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): return await asyncio.wait_for( - _async_receive_ssl(sock, length, loop, once=True), timeout=timeout - ) # type: ignore[arg-type] + _async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type] + timeout=timeout, + ) else: return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type] except asyncio.TimeoutError as err: diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 5bc97144da..68b43c391f 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -143,9 +143,14 @@ def __init__( self._spawned = False def _async_kms_request( - self, kms_context: MongoCryptKmsContext, host, port, opts, message + self, + kms_context: MongoCryptKmsContext, + host: str, + port: Optional[int], + opts: PoolOptions, + message: bytes, ) -> None: - from pymongo.network_layer import _receive_data_socket + from pymongo.network_layer import receive_data_socket # type: ignore[attr-defined] try: conn = _configured_socket((host, port), opts) @@ -154,7 +159,7 @@ def _async_kms_request( while kms_context.bytes_needed > 0: # CSOT: update timeout. conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = _receive_data_socket(conn, kms_context.bytes_needed) + data = receive_data_socket(conn, kms_context.bytes_needed) kms_context.feed(data) except OSError as err: raise OSError("KMS connection closed") from err From 59a95cce214cb50bbd6df669f50b835ce80b9cec Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 9 Oct 2024 13:58:33 -0400 Subject: [PATCH 14/16] 10x timeoutMS listCollection timeouts for socket consistency on Linux --- test/client-side-encryption/spec/legacy/timeoutMS.json | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/client-side-encryption/spec/legacy/timeoutMS.json b/test/client-side-encryption/spec/legacy/timeoutMS.json index b667767cfc..8411306224 100644 --- a/test/client-side-encryption/spec/legacy/timeoutMS.json +++ b/test/client-side-encryption/spec/legacy/timeoutMS.json @@ -110,7 +110,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 60 + "blockTimeMS": 600 } }, "clientOptions": { @@ -119,7 +119,7 @@ "aws": {} } }, - "timeoutMS": 50 + "timeoutMS": 500 }, "operations": [ { From 72303d6a516b56291caabe381ffb12ab2d206e80 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 10 Oct 2024 10:30:38 -0400 Subject: [PATCH 15/16] Cleanup --- pymongo/asynchronous/encryption.py | 70 ++++++++++-------------------- pymongo/synchronous/encryption.py | 70 ++++++++++-------------------- 2 files changed, 46 insertions(+), 94 deletions(-) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 79510c4bd2..f7094f0817 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -142,37 +142,6 @@ def __init__( self.opts = opts self._spawned = False - async def _async_kms_request( - self, - kms_context: MongoCryptKmsContext, - host: str, - port: Optional[int], - opts: PoolOptions, - message: bytes, - ) -> None: - from pymongo.network_layer import async_receive_data_socket # type: ignore[attr-defined] - - try: - conn = await _configured_socket((host, port), opts) - try: - await async_sendall(conn, message) - while kms_context.bytes_needed > 0: - # CSOT: update timeout. - conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = await async_receive_data_socket(conn, kms_context.bytes_needed) - kms_context.feed(data) - except OSError as err: - raise OSError("KMS connection closed") from err - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - finally: - conn.close() - except (PyMongoError, MongoCryptError): - raise # Propagate pymongo errors directly. - except Exception as error: - # Wrap I/O errors in PyMongo exceptions. - _raise_connection_failure((host, port), error) - async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: """Complete a KMS request. @@ -205,23 +174,30 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: ) host, port = parse_host(endpoint, _HTTPS_PORT) try: - if _IS_SYNC: - conn = await _configured_socket((host, port), opts) - try: - await async_sendall(conn, message) - while kms_context.bytes_needed > 0: - # CSOT: update timeout. - conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + conn = await _configured_socket((host, port), opts) + try: + await async_sendall(conn, message) + while kms_context.bytes_needed > 0: + # CSOT: update timeout. + conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + if _IS_SYNC: data = conn.recv(kms_context.bytes_needed) - if not data: - raise OSError("KMS connection closed") - kms_context.feed(data) - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - finally: - conn.close() - else: - await self._async_kms_request(kms_context, host, port, opts, message) + else: + from pymongo.network_layer import ( + async_receive_data_socket, # type: ignore[attr-defined] + ) + + data = await async_receive_data_socket(conn, kms_context.bytes_needed) + if not data: + raise OSError("KMS connection closed") + kms_context.feed(data) + # Async raises an OSError instead of returning empty bytes + except OSError as err: + raise OSError("KMS connection closed") from err + except BLOCKING_IO_ERRORS: + raise socket.timeout("timed out") from None + finally: + conn.close() except (PyMongoError, MongoCryptError): raise # Propagate pymongo errors directly. except Exception as error: diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 68b43c391f..4f74381fed 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -142,37 +142,6 @@ def __init__( self.opts = opts self._spawned = False - def _async_kms_request( - self, - kms_context: MongoCryptKmsContext, - host: str, - port: Optional[int], - opts: PoolOptions, - message: bytes, - ) -> None: - from pymongo.network_layer import receive_data_socket # type: ignore[attr-defined] - - try: - conn = _configured_socket((host, port), opts) - try: - sendall(conn, message) - while kms_context.bytes_needed > 0: - # CSOT: update timeout. - conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = receive_data_socket(conn, kms_context.bytes_needed) - kms_context.feed(data) - except OSError as err: - raise OSError("KMS connection closed") from err - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - finally: - conn.close() - except (PyMongoError, MongoCryptError): - raise # Propagate pymongo errors directly. - except Exception as error: - # Wrap I/O errors in PyMongo exceptions. - _raise_connection_failure((host, port), error) - def kms_request(self, kms_context: MongoCryptKmsContext) -> None: """Complete a KMS request. @@ -205,23 +174,30 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: ) host, port = parse_host(endpoint, _HTTPS_PORT) try: - if _IS_SYNC: - conn = _configured_socket((host, port), opts) - try: - sendall(conn, message) - while kms_context.bytes_needed > 0: - # CSOT: update timeout. - conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + conn = _configured_socket((host, port), opts) + try: + sendall(conn, message) + while kms_context.bytes_needed > 0: + # CSOT: update timeout. + conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) + if _IS_SYNC: data = conn.recv(kms_context.bytes_needed) - if not data: - raise OSError("KMS connection closed") - kms_context.feed(data) - except BLOCKING_IO_ERRORS: - raise socket.timeout("timed out") from None - finally: - conn.close() - else: - self._async_kms_request(kms_context, host, port, opts, message) + else: + from pymongo.network_layer import ( + receive_data_socket, # type: ignore[attr-defined] + ) + + data = receive_data_socket(conn, kms_context.bytes_needed) + if not data: + raise OSError("KMS connection closed") + kms_context.feed(data) + # Async raises an OSError instead of returning empty bytes + except OSError as err: + raise OSError("KMS connection closed") from err + except BLOCKING_IO_ERRORS: + raise socket.timeout("timed out") from None + finally: + conn.close() except (PyMongoError, MongoCryptError): raise # Propagate pymongo errors directly. except Exception as error: From 2a772854c97c53690c009f3efa05c6e65f14c14b Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 10 Oct 2024 10:48:56 -0400 Subject: [PATCH 16/16] Typing fixes + comment --- pymongo/asynchronous/encryption.py | 4 ++-- pymongo/network_layer.py | 1 + pymongo/synchronous/encryption.py | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index f7094f0817..735e543047 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -183,8 +183,8 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: if _IS_SYNC: data = conn.recv(kms_context.bytes_needed) else: - from pymongo.network_layer import ( - async_receive_data_socket, # type: ignore[attr-defined] + from pymongo.network_layer import ( # type: ignore[attr-defined] + async_receive_data_socket, ) data = await async_receive_data_socket(conn, kms_context.bytes_needed) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index e7aa077f44..d14a21f41d 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -145,6 +145,7 @@ def _is_ready(fut: Future) -> None: read = conn.recv_into(mv[total_read:]) if read == 0: raise OSError("connection closed") + # KMS responses update their expected size after the first batch, stop reading after one loop if once: return mv[:read] total_read += read diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 4f74381fed..506ff8bcba 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -183,8 +183,8 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: if _IS_SYNC: data = conn.recv(kms_context.bytes_needed) else: - from pymongo.network_layer import ( - receive_data_socket, # type: ignore[attr-defined] + from pymongo.network_layer import ( # type: ignore[attr-defined] + receive_data_socket, ) data = receive_data_socket(conn, kms_context.bytes_needed)