From 8118aea985f017457259bff78e64656232f08eb5 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 08:29:12 -0400 Subject: [PATCH 01/19] =?UTF-8?q?PYTHON-4844=20-=20Skip=20async=20test=5Fe?= =?UTF-8?q?ncryption.AsyncTestSpec.test=5Flegacy=5Fti=E2=80=A6=20(#1914)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/asynchronous/test_encryption.py | 5 +++++ test/test_encryption.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index c3f6223384..3e52fb9e1b 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -693,6 +693,11 @@ def maybe_skip_scenario(self, test): self.skipTest("PYTHON-3706 flaky test on Windows/macOS") if "type=symbol" in desc: self.skipTest("PyMongo does not support the symbol type") + if ( + "timeoutms applied to listcollections to get collection schema" in desc + and not _IS_SYNC + ): + self.skipTest("PYTHON-4844 flaky test on async") def setup_scenario(self, scenario_def): """Override a test's setup.""" diff --git a/test/test_encryption.py b/test/test_encryption.py index 43c85e2c5b..64aa7ebf50 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -691,6 +691,11 @@ def maybe_skip_scenario(self, test): self.skipTest("PYTHON-3706 flaky test on Windows/macOS") if "type=symbol" in desc: self.skipTest("PyMongo does not support the symbol type") + if ( + "timeoutms applied to listcollections to get collection schema" in desc + and not _IS_SYNC + ): + self.skipTest("PYTHON-4844 flaky test on async") def setup_scenario(self, scenario_def): """Override a test's setup.""" From 3a662291e010cbed832c00aff8ffe7b43d470489 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 10:48:24 -0400 Subject: [PATCH 02/19] PYTHON-4700 - Convert CSFLE tests to async (#1907) --- .evergreen/run-tests.sh | 4 +- pymongo/asynchronous/encryption.py | 12 +- pymongo/network_layer.py | 27 +- pymongo/synchronous/encryption.py | 12 +- test/__init__.py | 11 +- test/asynchronous/__init__.py | 11 +- test/asynchronous/test_encryption.py | 257 +++++++++--------- test/asynchronous/utils_spec_runner.py | 172 +++++++++++- .../spec/legacy/timeoutMS.json | 4 +- test/test_connection_monitoring.py | 3 +- test/test_encryption.py | 255 +++++++++-------- test/test_server_selection_in_window.py | 2 +- test/utils.py | 147 ---------- test/utils_spec_runner.py | 170 +++++++++++- tools/synchro.py | 2 + 15 files changed, 655 insertions(+), 434 deletions(-) diff --git a/.evergreen/run-tests.sh b/.evergreen/run-tests.sh index 8d7a9f082a..5e8429dd28 100755 --- a/.evergreen/run-tests.sh +++ b/.evergreen/run-tests.sh @@ -257,9 +257,9 @@ if [ -z "$GREEN_FRAMEWORK" ]; then # Use --capture=tee-sys so pytest prints test output inline: # https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html if [ -z "$TEST_SUITES" ]; then - python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS + python -m pytest -v --capture=tee-sys --durations=5 $TEST_ARGS else - python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 -m $TEST_SUITES $TEST_ARGS + python -m pytest -v --capture=tee-sys --durations=5 -m $TEST_SUITES $TEST_ARGS fi else python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 9b00c13e10..735e543047 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -180,10 
+180,20 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
                 while kms_context.bytes_needed > 0:
                     # CSOT: update timeout.
                     conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
-                    data = conn.recv(kms_context.bytes_needed)
+                    if _IS_SYNC:
+                        data = conn.recv(kms_context.bytes_needed)
+                    else:
+                        from pymongo.network_layer import (  # type: ignore[attr-defined]
+                            async_receive_data_socket,
+                        )
+
+                        data = await async_receive_data_socket(conn, kms_context.bytes_needed)
                     if not data:
                         raise OSError("KMS connection closed")
                     kms_context.feed(data)
+            # Async raises an OSError instead of returning empty bytes
+            except OSError as err:
+                raise OSError("KMS connection closed") from err
             except BLOCKING_IO_ERRORS:
                 raise socket.timeout("timed out") from None
             finally:
diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py
index 4b57620d83..d14a21f41d 100644
--- a/pymongo/network_layer.py
+++ b/pymongo/network_layer.py
@@ -130,7 +130,7 @@ def _is_ready(fut: Future) -> None:
             loop.remove_writer(fd)
 
 async def _async_receive_ssl(
-    conn: _sslConn, length: int, loop: AbstractEventLoop
+    conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
 ) -> memoryview:
     mv = memoryview(bytearray(length))
     total_read = 0
@@ -145,6 +145,9 @@ def _is_ready(fut: Future) -> None:
             read = conn.recv_into(mv[total_read:])
             if read == 0:
                 raise OSError("connection closed")
+            # KMS responses update their expected size after the first batch; stop reading after one pass.
+            if once:
+                return mv[:read]
             total_read += read
         except BLOCKING_IO_ERRORS as exc:
             fd = conn.fileno()
@@ -275,6 +278,28 @@ async def async_receive_data(
             sock.settimeout(sock_timeout)
 
 
+async def async_receive_data_socket(
+    sock: Union[socket.socket, _sslConn], length: int
+) -> memoryview:
+    sock_timeout = sock.gettimeout()
+    timeout = sock_timeout
+
+    sock.settimeout(0.0)
+    loop = asyncio.get_event_loop()
+    try:
+        if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
+            return await asyncio.wait_for(
+                _async_receive_ssl(sock, length, loop, once=True),  # type: ignore[arg-type]
+                timeout=timeout,
+            )
+        else:
+            return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout)  # type: ignore[arg-type]
+    except asyncio.TimeoutError as err:
+        raise socket.timeout("timed out") from err
+    finally:
+        sock.settimeout(sock_timeout)
+
+
 async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
     mv = memoryview(bytearray(length))
     bytes_read = 0
diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py
index efef6df9e8..506ff8bcba 100644
--- a/pymongo/synchronous/encryption.py
+++ b/pymongo/synchronous/encryption.py
@@ -180,10 +180,20 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
                 while kms_context.bytes_needed > 0:
                     # CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0)) - data = conn.recv(kms_context.bytes_needed) + if _IS_SYNC: + data = conn.recv(kms_context.bytes_needed) + else: + from pymongo.network_layer import ( # type: ignore[attr-defined] + receive_data_socket, + ) + + data = receive_data_socket(conn, kms_context.bytes_needed) if not data: raise OSError("KMS connection closed") kms_context.feed(data) + # Async raises an OSError instead of returning empty bytes + except OSError as err: + raise OSError("KMS connection closed") from err except BLOCKING_IO_ERRORS: raise socket.timeout("timed out") from None finally: diff --git a/test/__init__.py b/test/__init__.py index af12bc032a..fd33fde293 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -464,11 +464,12 @@ def wrap(*args, **kwargs): if not self.connected: pair = self.pair raise SkipTest(f"Cannot connect to MongoDB on {pair}") - if iscoroutinefunction(condition) and condition(): - if wraps_async: - return f(*args, **kwargs) - else: - return f(*args, **kwargs) + if iscoroutinefunction(condition): + if condition(): + if wraps_async: + return f(*args, **kwargs) + else: + return f(*args, **kwargs) elif condition(): if wraps_async: return f(*args, **kwargs) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 2a44785b2f..0579828c49 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -466,11 +466,12 @@ async def wrap(*args, **kwargs): if not self.connected: pair = await self.pair raise SkipTest(f"Cannot connect to MongoDB on {pair}") - if iscoroutinefunction(condition) and await condition(): - if wraps_async: - return await f(*args, **kwargs) - else: - return f(*args, **kwargs) + if iscoroutinefunction(condition): + if await condition(): + if wraps_async: + return await f(*args, **kwargs) + else: + return f(*args, **kwargs) elif condition(): if wraps_async: return await f(*args, **kwargs) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 3e52fb9e1b..88b005c4b3 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -30,6 +30,7 @@ import warnings from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context from test.asynchronous.test_bulk import AsyncBulkTestBase +from test.asynchronous.utils_spec_runner import AsyncSpecRunner, AsyncSpecTestCreator from threading import Thread from typing import Any, Dict, Mapping, Optional @@ -59,7 +60,6 @@ from test.utils import ( AllowListEventListener, OvertCommandListener, - SpecTestCreator, TopologyEventListener, async_wait_until, camel_to_snake_args, @@ -626,137 +626,132 @@ async def test_with_statement(self): KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}} -if _IS_SYNC: - # TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700) - class TestSpec(AsyncSpecRunner): - @classmethod - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def setUpClass(cls): - super().setUpClass() - - def parse_auto_encrypt_opts(self, opts): - """Parse clientOptions.autoEncryptOpts.""" - opts = camel_to_snake_args(opts) - kms_providers = opts["kms_providers"] - if "aws" in kms_providers: - kms_providers["aws"] = AWS_CREDS - if not any(AWS_CREDS.values()): - self.skipTest("AWS environment credentials are not set") - if "awsTemporary" in kms_providers: - kms_providers["aws"] = AWS_TEMP_CREDS - del kms_providers["awsTemporary"] - if not 
any(AWS_TEMP_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "awsTemporaryNoSessionToken" in kms_providers: - kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS - del kms_providers["awsTemporaryNoSessionToken"] - if not any(AWS_TEMP_NO_SESSION_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "azure" in kms_providers: - kms_providers["azure"] = AZURE_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("Azure environment credentials are not set") - if "gcp" in kms_providers: - kms_providers["gcp"] = GCP_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("GCP environment credentials are not set") - if "kmip" in kms_providers: - kms_providers["kmip"] = KMIP_CREDS - opts["kms_tls_options"] = KMS_TLS_OPTS - if "key_vault_namespace" not in opts: - opts["key_vault_namespace"] = "keyvault.datakeys" - if "extra_options" in opts: - opts.update(camel_to_snake_args(opts.pop("extra_options"))) - - opts = dict(opts) - return AutoEncryptionOpts(**opts) - - def parse_client_options(self, opts): - """Override clientOptions parsing to support autoEncryptOpts.""" - encrypt_opts = opts.pop("autoEncryptOpts", None) - if encrypt_opts: - opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) - - return super().parse_client_options(opts) - - def get_object_name(self, op): - """Default object is collection.""" - return op.get("object", "collection") - - def maybe_skip_scenario(self, test): - super().maybe_skip_scenario(test) - desc = test["description"].lower() - if ( - "timeoutms applied to listcollections to get collection schema" in desc - and sys.platform in ("win32", "darwin") - ): - self.skipTest("PYTHON-3706 flaky test on Windows/macOS") - if "type=symbol" in desc: - self.skipTest("PyMongo does not support the symbol type") - if ( - "timeoutms applied to listcollections to get collection schema" in desc - and not _IS_SYNC - ): - self.skipTest("PYTHON-4844 flaky test on async") - - def setup_scenario(self, scenario_def): - """Override a test's setup.""" - key_vault_data = scenario_def["key_vault_data"] - encrypted_fields = scenario_def["encrypted_fields"] - json_schema = scenario_def["json_schema"] - data = scenario_def["data"] - coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)[ - "datakeys" - ] - coll.delete_many({}) - if key_vault_data: - coll.insert_many(key_vault_data) - - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - db = async_client_context.client.get_database(db_name, codec_options=OPTS) - coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields) - wc = WriteConcern(w="majority") - kwargs: Dict[str, Any] = {} - if json_schema: - kwargs["validator"] = {"$jsonSchema": json_schema} - kwargs["codec_options"] = OPTS - if not data: - kwargs["write_concern"] = wc - if encrypted_fields: - kwargs["encryptedFields"] = encrypted_fields - db.create_collection(coll_name, **kwargs) - coll = db[coll_name] - if data: - # Load data. - coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) - - def allowable_errors(self, op): - """Override expected error classes.""" - errors = super().allowable_errors(op) - # An updateOne test expects encryption to error when no $ operator - # appears but pymongo raises a client side ValueError in this case. 
-            if op["name"] == "updateOne":
-                errors += (ValueError,)
-            return errors
-
-    def create_test(scenario_def, test, name):
-        @async_client_context.require_test_commands
-        def run_scenario(self):
-            self.run_scenario(scenario_def, test)
-
-        return run_scenario
-
-    test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
-    test_creator.create_tests()
-
-    if _HAVE_PYMONGOCRYPT:
-        globals().update(
-            generate_test_classes(
-                os.path.join(SPEC_PATH, "unified"),
-                module=__name__,
-            )
+class AsyncTestSpec(AsyncSpecRunner):
+    @classmethod
+    @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
+    async def _setup_class(cls):
+        await super()._setup_class()
+
+    def parse_auto_encrypt_opts(self, opts):
+        """Parse clientOptions.autoEncryptOpts."""
+        opts = camel_to_snake_args(opts)
+        kms_providers = opts["kms_providers"]
+        if "aws" in kms_providers:
+            kms_providers["aws"] = AWS_CREDS
+            if not any(AWS_CREDS.values()):
+                self.skipTest("AWS environment credentials are not set")
+        if "awsTemporary" in kms_providers:
+            kms_providers["aws"] = AWS_TEMP_CREDS
+            del kms_providers["awsTemporary"]
+            if not any(AWS_TEMP_CREDS.values()):
+                self.skipTest("AWS Temp environment credentials are not set")
+        if "awsTemporaryNoSessionToken" in kms_providers:
+            kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
+            del kms_providers["awsTemporaryNoSessionToken"]
+            if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
+                self.skipTest("AWS Temp environment credentials are not set")
+        if "azure" in kms_providers:
+            kms_providers["azure"] = AZURE_CREDS
+            if not any(AZURE_CREDS.values()):
+                self.skipTest("Azure environment credentials are not set")
+        if "gcp" in kms_providers:
+            kms_providers["gcp"] = GCP_CREDS
+            if not any(GCP_CREDS.values()):
+                self.skipTest("GCP environment credentials are not set")
+        if "kmip" in kms_providers:
+            kms_providers["kmip"] = KMIP_CREDS
+            opts["kms_tls_options"] = KMS_TLS_OPTS
+        if "key_vault_namespace" not in opts:
+            opts["key_vault_namespace"] = "keyvault.datakeys"
+        if "extra_options" in opts:
+            opts.update(camel_to_snake_args(opts.pop("extra_options")))
+
+        opts = dict(opts)
+        return AutoEncryptionOpts(**opts)
+
+    def parse_client_options(self, opts):
+        """Override clientOptions parsing to support autoEncryptOpts."""
+        encrypt_opts = opts.pop("autoEncryptOpts", None)
+        if encrypt_opts:
+            opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
+
+        return super().parse_client_options(opts)
+
+    def get_object_name(self, op):
+        """Default object is collection."""
+        return op.get("object", "collection")
+
+    def maybe_skip_scenario(self, test):
+        super().maybe_skip_scenario(test)
+        desc = test["description"].lower()
+        if (
+            "timeoutms applied to listcollections to get collection schema" in desc
+            and sys.platform in ("win32", "darwin")
+        ):
+            self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
+        if "type=symbol" in desc:
+            self.skipTest("PyMongo does not support the symbol type")
+        if "timeoutms applied to listcollections to get collection schema" in desc and not _IS_SYNC:
+            self.skipTest("PYTHON-4844 flaky test on async")
+
+    async def setup_scenario(self, scenario_def):
+        """Override a test's setup."""
+        key_vault_data = scenario_def["key_vault_data"]
+        encrypted_fields = scenario_def["encrypted_fields"]
+        json_schema = scenario_def["json_schema"]
+        data = scenario_def["data"]
+        coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
+        await coll.delete_many({})
+        if key_vault_data:
+            await
coll.insert_many(key_vault_data) + + db_name = self.get_scenario_db_name(scenario_def) + coll_name = self.get_scenario_coll_name(scenario_def) + db = async_client_context.client.get_database(db_name, codec_options=OPTS) + await db.drop_collection(coll_name, encrypted_fields=encrypted_fields) + wc = WriteConcern(w="majority") + kwargs: Dict[str, Any] = {} + if json_schema: + kwargs["validator"] = {"$jsonSchema": json_schema} + kwargs["codec_options"] = OPTS + if not data: + kwargs["write_concern"] = wc + if encrypted_fields: + kwargs["encryptedFields"] = encrypted_fields + await db.create_collection(coll_name, **kwargs) + coll = db[coll_name] + if data: + # Load data. + await coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) + + def allowable_errors(self, op): + """Override expected error classes.""" + errors = super().allowable_errors(op) + # An updateOne test expects encryption to error when no $ operator + # appears but pymongo raises a client side ValueError in this case. + if op["name"] == "updateOne": + errors += (ValueError,) + return errors + + +async def create_test(scenario_def, test, name): + @async_client_context.require_test_commands + async def run_scenario(self): + await self.run_scenario(scenario_def, test) + + return run_scenario + + +test_creator = AsyncSpecTestCreator(create_test, AsyncTestSpec, os.path.join(SPEC_PATH, "legacy")) +test_creator.create_tests() + +if _HAVE_PYMONGOCRYPT: + globals().update( + generate_test_classes( + os.path.join(SPEC_PATH, "unified"), + module=__name__, ) + ) # Prose Tests ALL_KMS_PROVIDERS = { diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 12cb13c2cd..4d9c4c8f20 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -15,8 +15,12 @@ """Utilities for testing driver specs.""" from __future__ import annotations +import asyncio import functools +import os import threading +import unittest +from asyncio import iscoroutinefunction from collections import abc from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs from test.utils import ( @@ -24,6 +28,7 @@ CompareType, EventListener, OvertCommandListener, + ScenarioDict, ServerAndTopologyEventListener, camel_to_snake, camel_to_snake_args, @@ -32,11 +37,12 @@ ) from typing import List -from bson import ObjectId, decode, encode +from bson import ObjectId, decode, encode, json_util from bson.binary import Binary from bson.int64 import Int64 from bson.son import SON from gridfs import GridFSBucket +from gridfs.asynchronous.grid_file import AsyncGridFSBucket from pymongo.asynchronous import client_session from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.cursor import AsyncCursor @@ -83,6 +89,161 @@ def run(self): self.stop() +class AsyncSpecTestCreator: + """Class to create test cases from specifications.""" + + def __init__(self, create_test, test_class, test_path): + """Create a TestCreator object. + + :Parameters: + - `create_test`: callback that returns a test case. The callback + must accept the following arguments - a dictionary containing the + entire test specification (the `scenario_def`), a dictionary + containing the specification for which the test case will be + generated (the `test_def`). + - `test_class`: the unittest.TestCase class in which to create the + test case. + - `test_path`: path to the directory containing the JSON files with + the test specifications. 
+ """ + self._create_test = create_test + self._test_class = test_class + self.test_path = test_path + + def _ensure_min_max_server_version(self, scenario_def, method): + """Test modifier that enforces a version range for the server on a + test case. + """ + if "minServerVersion" in scenario_def: + min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) + if min_ver is not None: + method = async_client_context.require_version_min(*min_ver)(method) + + if "maxServerVersion" in scenario_def: + max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) + if max_ver is not None: + method = async_client_context.require_version_max(*max_ver)(method) + + if "serverless" in scenario_def: + serverless = scenario_def["serverless"] + if serverless == "require": + serverless_satisfied = async_client_context.serverless + elif serverless == "forbid": + serverless_satisfied = not async_client_context.serverless + else: # unset or "allow" + serverless_satisfied = True + method = unittest.skipUnless( + serverless_satisfied, "Serverless requirement not satisfied" + )(method) + + return method + + @staticmethod + async def valid_topology(run_on_req): + return await async_client_context.is_topology_type( + run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) + ) + + @staticmethod + def min_server_version(run_on_req): + version = run_on_req.get("minServerVersion") + if version: + min_ver = tuple(int(elt) for elt in version.split(".")) + return async_client_context.version >= min_ver + return True + + @staticmethod + def max_server_version(run_on_req): + version = run_on_req.get("maxServerVersion") + if version: + max_ver = tuple(int(elt) for elt in version.split(".")) + return async_client_context.version <= max_ver + return True + + @staticmethod + def valid_auth_enabled(run_on_req): + if "authEnabled" in run_on_req: + if run_on_req["authEnabled"]: + return async_client_context.auth_enabled + return not async_client_context.auth_enabled + return True + + @staticmethod + def serverless_ok(run_on_req): + serverless = run_on_req["serverless"] + if serverless == "require": + return async_client_context.serverless + elif serverless == "forbid": + return not async_client_context.serverless + else: # unset or "allow" + return True + + async def should_run_on(self, scenario_def): + run_on = scenario_def.get("runOn", []) + if not run_on: + # Always run these tests. + return True + + for req in run_on: + if ( + await self.valid_topology(req) + and self.min_server_version(req) + and self.max_server_version(req) + and self.valid_auth_enabled(req) + and self.serverless_ok(req) + ): + return True + return False + + def ensure_run_on(self, scenario_def, method): + """Test modifier that enforces a 'runOn' on a test case.""" + + async def predicate(): + return await self.should_run_on(scenario_def) + + return async_client_context._require(predicate, "runOn not satisfied", method) + + def tests(self, scenario_def): + """Allow CMAP spec test to override the location of test.""" + return scenario_def["tests"] + + async def _create_tests(self): + for dirpath, _, filenames in os.walk(self.test_path): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100 + # Use tz_aware=False to match how CodecOptions decodes + # dates. 
+ opts = json_util.JSONOptions(tz_aware=False) + scenario_def = ScenarioDict( + json_util.loads(scenario_stream.read(), json_options=opts) + ) + + test_type = os.path.splitext(filename)[0] + + # Construct test from scenario. + for test_def in self.tests(scenario_def): + test_name = "test_{}_{}_{}".format( + dirname, + test_type.replace("-", "_").replace(".", "_"), + str(test_def["description"].replace(" ", "_").replace(".", "_")), + ) + + new_test = await self._create_test(scenario_def, test_def, test_name) + new_test = self._ensure_min_max_server_version(scenario_def, new_test) + new_test = self.ensure_run_on(scenario_def, new_test) + + new_test.__name__ = test_name + setattr(self._test_class, new_test.__name__, new_test) + + def create_tests(self): + if _IS_SYNC: + self._create_tests() + else: + asyncio.run(self._create_tests()) + + class AsyncSpecRunner(AsyncIntegrationTest): mongos_clients: List knobs: client_knobs @@ -284,7 +445,7 @@ async def run_operation(self, sessions, collection, operation): if object_name == "gridfsbucket": # Only create the GridFSBucket when we need it (for the gridfs # retryable reads tests). - obj = GridFSBucket(database, bucket_name=collection.name) + obj = AsyncGridFSBucket(database, bucket_name=collection.name) else: objects = { "client": database.client, @@ -312,7 +473,10 @@ async def run_operation(self, sessions, collection, operation): args.update(arguments) arguments = args - result = cmd(**dict(arguments)) + if not _IS_SYNC and iscoroutinefunction(cmd): + result = await cmd(**dict(arguments)) + else: + result = cmd(**dict(arguments)) # Cleanup open change stream cursors. if name == "watch": self.addAsyncCleanup(result.close) @@ -588,7 +752,7 @@ async def run_scenario(self, scenario_def, test): read_preference=ReadPreference.PRIMARY, read_concern=ReadConcern("local"), ) - actual_data = await (await outcome_coll.find(sort=[("_id", 1)])).to_list() + actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list() # The expected data needs to be the left hand side here otherwise # CompareType(Binary) doesn't work. 
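
Note on the AsyncSpecTestCreator.create_tests() entry point added above: it has
to run coroutine-based test generation from a synchronous, import-time context,
so it drives _create_tests() to completion with asyncio.run(). A minimal,
self-contained sketch of that pattern (the _GeneratedTests class and _make_test
helper below are hypothetical illustrations, not PyMongo APIs):

    import asyncio
    import unittest


    class _GeneratedTests(unittest.TestCase):
        pass


    async def _make_test(name):
        # Stand-in for the awaited per-test checks (topology, server version).
        await asyncio.sleep(0)

        def test(self):
            self.assertTrue(name.startswith("test_"))

        test.__name__ = name
        return test


    async def _create_tests(cls, names):
        # Build each test coroutine-side, then attach it to the target class,
        # mirroring how the creator calls setattr() on its test_class.
        for name in names:
            setattr(cls, name, await _make_test(name))


    # Synchronous, import-time entry point, as in create_tests():
    asyncio.run(_create_tests(_GeneratedTests, ["test_generated_example"]))

Because asyncio.run() creates and tears down its own event loop, this only
works where no loop is already running, which holds at module import time when
create_tests() is invoked from test_encryption.py.
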
diff --git a/test/client-side-encryption/spec/legacy/timeoutMS.json b/test/client-side-encryption/spec/legacy/timeoutMS.json index b667767cfc..8411306224 100644 --- a/test/client-side-encryption/spec/legacy/timeoutMS.json +++ b/test/client-side-encryption/spec/legacy/timeoutMS.json @@ -110,7 +110,7 @@ "listCollections" ], "blockConnection": true, - "blockTimeMS": 60 + "blockTimeMS": 600 } }, "clientOptions": { @@ -119,7 +119,7 @@ "aws": {} } }, - "timeoutMS": 50 + "timeoutMS": 500 }, "operations": [ { diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 142af0f9a7..d576a1184a 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -25,14 +25,13 @@ from test.pymongo_mocks import DummyMonitor from test.utils import ( CMAPListener, - SpecTestCreator, camel_to_snake, client_context, get_pool, get_pools, wait_until, ) -from test.utils_spec_runner import SpecRunnerThread +from test.utils_spec_runner import SpecRunnerThread, SpecTestCreator from bson.objectid import ObjectId from bson.son import SON diff --git a/test/test_encryption.py b/test/test_encryption.py index 64aa7ebf50..13a69ca9ad 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -30,6 +30,7 @@ import warnings from test import IntegrationTest, PyMongoTestCase, client_context from test.test_bulk import BulkTestBase +from test.utils_spec_runner import SpecRunner, SpecTestCreator from threading import Thread from typing import Any, Dict, Mapping, Optional @@ -58,7 +59,6 @@ from test.utils import ( AllowListEventListener, OvertCommandListener, - SpecTestCreator, TopologyEventListener, camel_to_snake_args, is_greenthread_patched, @@ -624,135 +624,132 @@ def test_with_statement(self): KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}} -if _IS_SYNC: - # TODO: Add synchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700) - class TestSpec(SpecRunner): - @classmethod - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def setUpClass(cls): - super().setUpClass() - - def parse_auto_encrypt_opts(self, opts): - """Parse clientOptions.autoEncryptOpts.""" - opts = camel_to_snake_args(opts) - kms_providers = opts["kms_providers"] - if "aws" in kms_providers: - kms_providers["aws"] = AWS_CREDS - if not any(AWS_CREDS.values()): - self.skipTest("AWS environment credentials are not set") - if "awsTemporary" in kms_providers: - kms_providers["aws"] = AWS_TEMP_CREDS - del kms_providers["awsTemporary"] - if not any(AWS_TEMP_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "awsTemporaryNoSessionToken" in kms_providers: - kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS - del kms_providers["awsTemporaryNoSessionToken"] - if not any(AWS_TEMP_NO_SESSION_CREDS.values()): - self.skipTest("AWS Temp environment credentials are not set") - if "azure" in kms_providers: - kms_providers["azure"] = AZURE_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("Azure environment credentials are not set") - if "gcp" in kms_providers: - kms_providers["gcp"] = GCP_CREDS - if not any(AZURE_CREDS.values()): - self.skipTest("GCP environment credentials are not set") - if "kmip" in kms_providers: - kms_providers["kmip"] = KMIP_CREDS - opts["kms_tls_options"] = KMS_TLS_OPTS - if "key_vault_namespace" not in opts: - opts["key_vault_namespace"] = "keyvault.datakeys" - if "extra_options" in opts: - opts.update(camel_to_snake_args(opts.pop("extra_options"))) - - opts = dict(opts) - 
return AutoEncryptionOpts(**opts) - - def parse_client_options(self, opts): - """Override clientOptions parsing to support autoEncryptOpts.""" - encrypt_opts = opts.pop("autoEncryptOpts", None) - if encrypt_opts: - opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts) - - return super().parse_client_options(opts) - - def get_object_name(self, op): - """Default object is collection.""" - return op.get("object", "collection") - - def maybe_skip_scenario(self, test): - super().maybe_skip_scenario(test) - desc = test["description"].lower() - if ( - "timeoutms applied to listcollections to get collection schema" in desc - and sys.platform in ("win32", "darwin") - ): - self.skipTest("PYTHON-3706 flaky test on Windows/macOS") - if "type=symbol" in desc: - self.skipTest("PyMongo does not support the symbol type") - if ( - "timeoutms applied to listcollections to get collection schema" in desc - and not _IS_SYNC - ): - self.skipTest("PYTHON-4844 flaky test on async") - - def setup_scenario(self, scenario_def): - """Override a test's setup.""" - key_vault_data = scenario_def["key_vault_data"] - encrypted_fields = scenario_def["encrypted_fields"] - json_schema = scenario_def["json_schema"] - data = scenario_def["data"] - coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"] - coll.delete_many({}) - if key_vault_data: - coll.insert_many(key_vault_data) - - db_name = self.get_scenario_db_name(scenario_def) - coll_name = self.get_scenario_coll_name(scenario_def) - db = client_context.client.get_database(db_name, codec_options=OPTS) - coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields) - wc = WriteConcern(w="majority") - kwargs: Dict[str, Any] = {} - if json_schema: - kwargs["validator"] = {"$jsonSchema": json_schema} - kwargs["codec_options"] = OPTS - if not data: - kwargs["write_concern"] = wc - if encrypted_fields: - kwargs["encryptedFields"] = encrypted_fields - db.create_collection(coll_name, **kwargs) - coll = db[coll_name] - if data: - # Load data. - coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) - - def allowable_errors(self, op): - """Override expected error classes.""" - errors = super().allowable_errors(op) - # An updateOne test expects encryption to error when no $ operator - # appears but pymongo raises a client side ValueError in this case. 
-            if op["name"] == "updateOne":
-                errors += (ValueError,)
-            return errors
-
-    def create_test(scenario_def, test, name):
-        @client_context.require_test_commands
-        def run_scenario(self):
-            self.run_scenario(scenario_def, test)
-
-        return run_scenario
-
-    test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
-    test_creator.create_tests()
-
-    if _HAVE_PYMONGOCRYPT:
-        globals().update(
-            generate_test_classes(
-                os.path.join(SPEC_PATH, "unified"),
-                module=__name__,
-            )
+class TestSpec(SpecRunner):
+    @classmethod
+    @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
+    def _setup_class(cls):
+        super()._setup_class()
+
+    def parse_auto_encrypt_opts(self, opts):
+        """Parse clientOptions.autoEncryptOpts."""
+        opts = camel_to_snake_args(opts)
+        kms_providers = opts["kms_providers"]
+        if "aws" in kms_providers:
+            kms_providers["aws"] = AWS_CREDS
+            if not any(AWS_CREDS.values()):
+                self.skipTest("AWS environment credentials are not set")
+        if "awsTemporary" in kms_providers:
+            kms_providers["aws"] = AWS_TEMP_CREDS
+            del kms_providers["awsTemporary"]
+            if not any(AWS_TEMP_CREDS.values()):
+                self.skipTest("AWS Temp environment credentials are not set")
+        if "awsTemporaryNoSessionToken" in kms_providers:
+            kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
+            del kms_providers["awsTemporaryNoSessionToken"]
+            if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
+                self.skipTest("AWS Temp environment credentials are not set")
+        if "azure" in kms_providers:
+            kms_providers["azure"] = AZURE_CREDS
+            if not any(AZURE_CREDS.values()):
+                self.skipTest("Azure environment credentials are not set")
+        if "gcp" in kms_providers:
+            kms_providers["gcp"] = GCP_CREDS
+            if not any(GCP_CREDS.values()):
+                self.skipTest("GCP environment credentials are not set")
+        if "kmip" in kms_providers:
+            kms_providers["kmip"] = KMIP_CREDS
+            opts["kms_tls_options"] = KMS_TLS_OPTS
+        if "key_vault_namespace" not in opts:
+            opts["key_vault_namespace"] = "keyvault.datakeys"
+        if "extra_options" in opts:
+            opts.update(camel_to_snake_args(opts.pop("extra_options")))
+
+        opts = dict(opts)
+        return AutoEncryptionOpts(**opts)
+
+    def parse_client_options(self, opts):
+        """Override clientOptions parsing to support autoEncryptOpts."""
+        encrypt_opts = opts.pop("autoEncryptOpts", None)
+        if encrypt_opts:
+            opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
+
+        return super().parse_client_options(opts)
+
+    def get_object_name(self, op):
+        """Default object is collection."""
+        return op.get("object", "collection")
+
+    def maybe_skip_scenario(self, test):
+        super().maybe_skip_scenario(test)
+        desc = test["description"].lower()
+        if (
+            "timeoutms applied to listcollections to get collection schema" in desc
+            and sys.platform in ("win32", "darwin")
+        ):
+            self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
+        if "type=symbol" in desc:
+            self.skipTest("PyMongo does not support the symbol type")
+        if "timeoutms applied to listcollections to get collection schema" in desc and not _IS_SYNC:
+            self.skipTest("PYTHON-4844 flaky test on async")
+
+    def setup_scenario(self, scenario_def):
+        """Override a test's setup."""
+        key_vault_data = scenario_def["key_vault_data"]
+        encrypted_fields = scenario_def["encrypted_fields"]
+        json_schema = scenario_def["json_schema"]
+        data = scenario_def["data"]
+        coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
+        coll.delete_many({})
+        if key_vault_data:
+            coll.insert_many(key_vault_data)
+
+        db_name = 
self.get_scenario_db_name(scenario_def) + coll_name = self.get_scenario_coll_name(scenario_def) + db = client_context.client.get_database(db_name, codec_options=OPTS) + db.drop_collection(coll_name, encrypted_fields=encrypted_fields) + wc = WriteConcern(w="majority") + kwargs: Dict[str, Any] = {} + if json_schema: + kwargs["validator"] = {"$jsonSchema": json_schema} + kwargs["codec_options"] = OPTS + if not data: + kwargs["write_concern"] = wc + if encrypted_fields: + kwargs["encryptedFields"] = encrypted_fields + db.create_collection(coll_name, **kwargs) + coll = db[coll_name] + if data: + # Load data. + coll.with_options(write_concern=wc).insert_many(scenario_def["data"]) + + def allowable_errors(self, op): + """Override expected error classes.""" + errors = super().allowable_errors(op) + # An updateOne test expects encryption to error when no $ operator + # appears but pymongo raises a client side ValueError in this case. + if op["name"] == "updateOne": + errors += (ValueError,) + return errors + + +def create_test(scenario_def, test, name): + @client_context.require_test_commands + def run_scenario(self): + self.run_scenario(scenario_def, test) + + return run_scenario + + +test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy")) +test_creator.create_tests() + +if _HAVE_PYMONGOCRYPT: + globals().update( + generate_test_classes( + os.path.join(SPEC_PATH, "unified"), + module=__name__, ) + ) # Prose Tests ALL_KMS_PROVIDERS = { diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 7cab42cca2..05772fa385 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -21,11 +21,11 @@ from test.utils import ( CMAPListener, OvertCommandListener, - SpecTestCreator, get_pool, wait_until, ) from test.utils_selection_tests import create_topology +from test.utils_spec_runner import SpecTestCreator from pymongo.common import clean_node from pymongo.monitoring import ConnectionReadyEvent diff --git a/test/utils.py b/test/utils.py index 9c78cff3ad..4575a9fe10 100644 --- a/test/utils.py +++ b/test/utils.py @@ -418,153 +418,6 @@ def call_count(self): return len(self._call_list) -class SpecTestCreator: - """Class to create test cases from specifications.""" - - def __init__(self, create_test, test_class, test_path): - """Create a TestCreator object. - - :Parameters: - - `create_test`: callback that returns a test case. The callback - must accept the following arguments - a dictionary containing the - entire test specification (the `scenario_def`), a dictionary - containing the specification for which the test case will be - generated (the `test_def`). - - `test_class`: the unittest.TestCase class in which to create the - test case. - - `test_path`: path to the directory containing the JSON files with - the test specifications. - """ - self._create_test = create_test - self._test_class = test_class - self.test_path = test_path - - def _ensure_min_max_server_version(self, scenario_def, method): - """Test modifier that enforces a version range for the server on a - test case. 
- """ - if "minServerVersion" in scenario_def: - min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) - if min_ver is not None: - method = client_context.require_version_min(*min_ver)(method) - - if "maxServerVersion" in scenario_def: - max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) - if max_ver is not None: - method = client_context.require_version_max(*max_ver)(method) - - if "serverless" in scenario_def: - serverless = scenario_def["serverless"] - if serverless == "require": - serverless_satisfied = client_context.serverless - elif serverless == "forbid": - serverless_satisfied = not client_context.serverless - else: # unset or "allow" - serverless_satisfied = True - method = unittest.skipUnless( - serverless_satisfied, "Serverless requirement not satisfied" - )(method) - - return method - - @staticmethod - def valid_topology(run_on_req): - return client_context.is_topology_type( - run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) - ) - - @staticmethod - def min_server_version(run_on_req): - version = run_on_req.get("minServerVersion") - if version: - min_ver = tuple(int(elt) for elt in version.split(".")) - return client_context.version >= min_ver - return True - - @staticmethod - def max_server_version(run_on_req): - version = run_on_req.get("maxServerVersion") - if version: - max_ver = tuple(int(elt) for elt in version.split(".")) - return client_context.version <= max_ver - return True - - @staticmethod - def valid_auth_enabled(run_on_req): - if "authEnabled" in run_on_req: - if run_on_req["authEnabled"]: - return client_context.auth_enabled - return not client_context.auth_enabled - return True - - @staticmethod - def serverless_ok(run_on_req): - serverless = run_on_req["serverless"] - if serverless == "require": - return client_context.serverless - elif serverless == "forbid": - return not client_context.serverless - else: # unset or "allow" - return True - - def should_run_on(self, scenario_def): - run_on = scenario_def.get("runOn", []) - if not run_on: - # Always run these tests. - return True - - for req in run_on: - if ( - self.valid_topology(req) - and self.min_server_version(req) - and self.max_server_version(req) - and self.valid_auth_enabled(req) - and self.serverless_ok(req) - ): - return True - return False - - def ensure_run_on(self, scenario_def, method): - """Test modifier that enforces a 'runOn' on a test case.""" - return client_context._require( - lambda: self.should_run_on(scenario_def), "runOn not satisfied", method - ) - - def tests(self, scenario_def): - """Allow CMAP spec test to override the location of test.""" - return scenario_def["tests"] - - def create_tests(self): - for dirpath, _, filenames in os.walk(self.test_path): - dirname = os.path.split(dirpath)[-1] - - for filename in filenames: - with open(os.path.join(dirpath, filename)) as scenario_stream: - # Use tz_aware=False to match how CodecOptions decodes - # dates. - opts = json_util.JSONOptions(tz_aware=False) - scenario_def = ScenarioDict( - json_util.loads(scenario_stream.read(), json_options=opts) - ) - - test_type = os.path.splitext(filename)[0] - - # Construct test from scenario. 
- for test_def in self.tests(scenario_def): - test_name = "test_{}_{}_{}".format( - dirname, - test_type.replace("-", "_").replace(".", "_"), - str(test_def["description"].replace(" ", "_").replace(".", "_")), - ) - - new_test = self._create_test(scenario_def, test_def, test_name) - new_test = self._ensure_min_max_server_version(scenario_def, new_test) - new_test = self.ensure_run_on(scenario_def, new_test) - - new_test.__name__ = test_name - setattr(self._test_class, new_test.__name__, new_test) - - def ensure_all_connected(client: MongoClient) -> None: """Ensure that the client's connection pool has socket connections to all members of a replica set. Raises ConfigurationError when called with a diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 06a40351cd..8a061de0b1 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -15,8 +15,12 @@ """Utilities for testing driver specs.""" from __future__ import annotations +import asyncio import functools +import os import threading +import unittest +from asyncio import iscoroutinefunction from collections import abc from test import IntegrationTest, client_context, client_knobs from test.utils import ( @@ -24,6 +28,7 @@ CompareType, EventListener, OvertCommandListener, + ScenarioDict, ServerAndTopologyEventListener, camel_to_snake, camel_to_snake_args, @@ -32,11 +37,12 @@ ) from typing import List -from bson import ObjectId, decode, encode +from bson import ObjectId, decode, encode, json_util from bson.binary import Binary from bson.int64 import Int64 from bson.son import SON from gridfs import GridFSBucket +from gridfs.synchronous.grid_file import GridFSBucket from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference @@ -83,6 +89,161 @@ def run(self): self.stop() +class SpecTestCreator: + """Class to create test cases from specifications.""" + + def __init__(self, create_test, test_class, test_path): + """Create a TestCreator object. + + :Parameters: + - `create_test`: callback that returns a test case. The callback + must accept the following arguments - a dictionary containing the + entire test specification (the `scenario_def`), a dictionary + containing the specification for which the test case will be + generated (the `test_def`). + - `test_class`: the unittest.TestCase class in which to create the + test case. + - `test_path`: path to the directory containing the JSON files with + the test specifications. + """ + self._create_test = create_test + self._test_class = test_class + self.test_path = test_path + + def _ensure_min_max_server_version(self, scenario_def, method): + """Test modifier that enforces a version range for the server on a + test case. 
+ """ + if "minServerVersion" in scenario_def: + min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) + if min_ver is not None: + method = client_context.require_version_min(*min_ver)(method) + + if "maxServerVersion" in scenario_def: + max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) + if max_ver is not None: + method = client_context.require_version_max(*max_ver)(method) + + if "serverless" in scenario_def: + serverless = scenario_def["serverless"] + if serverless == "require": + serverless_satisfied = client_context.serverless + elif serverless == "forbid": + serverless_satisfied = not client_context.serverless + else: # unset or "allow" + serverless_satisfied = True + method = unittest.skipUnless( + serverless_satisfied, "Serverless requirement not satisfied" + )(method) + + return method + + @staticmethod + def valid_topology(run_on_req): + return client_context.is_topology_type( + run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) + ) + + @staticmethod + def min_server_version(run_on_req): + version = run_on_req.get("minServerVersion") + if version: + min_ver = tuple(int(elt) for elt in version.split(".")) + return client_context.version >= min_ver + return True + + @staticmethod + def max_server_version(run_on_req): + version = run_on_req.get("maxServerVersion") + if version: + max_ver = tuple(int(elt) for elt in version.split(".")) + return client_context.version <= max_ver + return True + + @staticmethod + def valid_auth_enabled(run_on_req): + if "authEnabled" in run_on_req: + if run_on_req["authEnabled"]: + return client_context.auth_enabled + return not client_context.auth_enabled + return True + + @staticmethod + def serverless_ok(run_on_req): + serverless = run_on_req["serverless"] + if serverless == "require": + return client_context.serverless + elif serverless == "forbid": + return not client_context.serverless + else: # unset or "allow" + return True + + def should_run_on(self, scenario_def): + run_on = scenario_def.get("runOn", []) + if not run_on: + # Always run these tests. + return True + + for req in run_on: + if ( + self.valid_topology(req) + and self.min_server_version(req) + and self.max_server_version(req) + and self.valid_auth_enabled(req) + and self.serverless_ok(req) + ): + return True + return False + + def ensure_run_on(self, scenario_def, method): + """Test modifier that enforces a 'runOn' on a test case.""" + + def predicate(): + return self.should_run_on(scenario_def) + + return client_context._require(predicate, "runOn not satisfied", method) + + def tests(self, scenario_def): + """Allow CMAP spec test to override the location of test.""" + return scenario_def["tests"] + + def _create_tests(self): + for dirpath, _, filenames in os.walk(self.test_path): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100 + # Use tz_aware=False to match how CodecOptions decodes + # dates. + opts = json_util.JSONOptions(tz_aware=False) + scenario_def = ScenarioDict( + json_util.loads(scenario_stream.read(), json_options=opts) + ) + + test_type = os.path.splitext(filename)[0] + + # Construct test from scenario. 
+ for test_def in self.tests(scenario_def): + test_name = "test_{}_{}_{}".format( + dirname, + test_type.replace("-", "_").replace(".", "_"), + str(test_def["description"].replace(" ", "_").replace(".", "_")), + ) + + new_test = self._create_test(scenario_def, test_def, test_name) + new_test = self._ensure_min_max_server_version(scenario_def, new_test) + new_test = self.ensure_run_on(scenario_def, new_test) + + new_test.__name__ = test_name + setattr(self._test_class, new_test.__name__, new_test) + + def create_tests(self): + if _IS_SYNC: + self._create_tests() + else: + asyncio.run(self._create_tests()) + + class SpecRunner(IntegrationTest): mongos_clients: List knobs: client_knobs @@ -312,7 +473,10 @@ def run_operation(self, sessions, collection, operation): args.update(arguments) arguments = args - result = cmd(**dict(arguments)) + if not _IS_SYNC and iscoroutinefunction(cmd): + result = cmd(**dict(arguments)) + else: + result = cmd(**dict(arguments)) # Cleanup open change stream cursors. if name == "watch": self.addCleanup(result.close) @@ -583,7 +747,7 @@ def run_scenario(self, scenario_def, test): read_preference=ReadPreference.PRIMARY, read_concern=ReadConcern("local"), ) - actual_data = (outcome_coll.find(sort=[("_id", 1)])).to_list() + actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list() # The expected data needs to be the left hand side here otherwise # CompareType(Binary) doesn't work. diff --git a/tools/synchro.py b/tools/synchro.py index 0ec8985a05..f704919a17 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -105,6 +105,8 @@ "PyMongo|c|async": "PyMongo|c", "AsyncTestGridFile": "TestGridFile", "AsyncTestGridFileNoConnect": "TestGridFileNoConnect", + "AsyncTestSpec": "TestSpec", + "AsyncSpecTestCreator": "SpecTestCreator", "async_set_fail_point": "set_fail_point", "async_ensure_all_connected": "ensure_all_connected", "async_repl_set_step_down": "repl_set_step_down", From 6973d2d2743b7679080b8be70391b767740cf674 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 11:02:06 -0400 Subject: [PATCH 03/19] PYTHON-4528 - Convert unified test runner to async (#1913) --- test/asynchronous/unified_format.py | 1573 +++++++++++++++++++++++++++ test/unified_format.py | 711 +----------- test/unified_format_shared.py | 679 ++++++++++++ tools/synchro.py | 1 + 4 files changed, 2301 insertions(+), 663 deletions(-) create mode 100644 test/asynchronous/unified_format.py create mode 100644 test/unified_format_shared.py diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py new file mode 100644 index 0000000000..4c37422951 --- /dev/null +++ b/test/asynchronous/unified_format.py @@ -0,0 +1,1573 @@ +# Copyright 2020-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unified test format runner. 
+ +https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst +""" +from __future__ import annotations + +import asyncio +import binascii +import copy +import functools +import os +import re +import sys +import time +import traceback +from asyncio import iscoroutinefunction +from collections import defaultdict +from test.asynchronous import ( + AsyncIntegrationTest, + async_client_context, + client_knobs, + unittest, +) +from test.unified_format_shared import ( + IS_INTERRUPTED, + KMS_TLS_OPTS, + PLACEHOLDER_MAP, + SKIP_CSOT_TESTS, + EventListenerUtil, + MatchEvaluatorUtil, + coerce_result, + parse_bulk_write_error_result, + parse_bulk_write_result, + parse_client_bulk_write_error_result, + parse_collection_or_database_options, + with_metaclass, +) +from test.utils import ( + async_get_pool, + camel_to_snake, + camel_to_snake_args, + parse_spec_options, + prepare_spec_arguments, + snake_to_camel, + wait_until, +) +from test.utils_spec_runner import SpecRunnerThread +from test.version import Version +from typing import Any, Dict, List, Mapping, Optional + +import pymongo +from bson import SON, json_util +from bson.codec_options import DEFAULT_CODEC_OPTIONS +from bson.objectid import ObjectId +from gridfs import AsyncGridFSBucket, GridOut +from pymongo import ASCENDING, AsyncMongoClient, CursorType, _csot +from pymongo.asynchronous.change_stream import AsyncChangeStream +from pymongo.asynchronous.client_session import AsyncClientSession, TransactionOptions, _TxnState +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.asynchronous.database import AsyncDatabase +from pymongo.asynchronous.encryption import AsyncClientEncryption +from pymongo.asynchronous.helpers import anext +from pymongo.encryption_options import _HAVE_PYMONGOCRYPT +from pymongo.errors import ( + BulkWriteError, + ClientBulkWriteException, + ConfigurationError, + ConnectionFailure, + EncryptionError, + InvalidOperation, + NotPrimaryError, + OperationFailure, + PyMongoError, +) +from pymongo.monitoring import ( + CommandStartedEvent, +) +from pymongo.operations import ( + SearchIndexModel, +) +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference +from pymongo.server_api import ServerApi +from pymongo.server_selectors import Selection, writable_server_selector +from pymongo.server_type import SERVER_TYPE +from pymongo.topology_description import TopologyDescription +from pymongo.typings import _Address +from pymongo.write_concern import WriteConcern + +_IS_SYNC = False + + +async def is_run_on_requirement_satisfied(requirement): + topology_satisfied = True + req_topologies = requirement.get("topologies") + if req_topologies: + topology_satisfied = await async_client_context.is_topology_type(req_topologies) + + server_version = Version(*async_client_context.version[:3]) + + min_version_satisfied = True + req_min_server_version = requirement.get("minServerVersion") + if req_min_server_version: + min_version_satisfied = Version.from_string(req_min_server_version) <= server_version + + max_version_satisfied = True + req_max_server_version = requirement.get("maxServerVersion") + if req_max_server_version: + max_version_satisfied = Version.from_string(req_max_server_version) >= server_version + + serverless = requirement.get("serverless") + if serverless == "require": + serverless_satisfied = async_client_context.serverless + elif serverless == "forbid": 
+ serverless_satisfied = not async_client_context.serverless + else: # unset or "allow" + serverless_satisfied = True + + params_satisfied = True + params = requirement.get("serverParameters") + if params: + for param, val in params.items(): + if param not in async_client_context.server_parameters: + params_satisfied = False + elif async_client_context.server_parameters[param] != val: + params_satisfied = False + + auth_satisfied = True + req_auth = requirement.get("auth") + if req_auth is not None: + if req_auth: + auth_satisfied = async_client_context.auth_enabled + if auth_satisfied and "authMechanism" in requirement: + auth_satisfied = async_client_context.check_auth_type(requirement["authMechanism"]) + else: + auth_satisfied = not async_client_context.auth_enabled + + csfle_satisfied = True + req_csfle = requirement.get("csfle") + if req_csfle is True: + min_version_satisfied = Version.from_string("4.2") <= server_version + csfle_satisfied = _HAVE_PYMONGOCRYPT and min_version_satisfied + + return ( + topology_satisfied + and min_version_satisfied + and max_version_satisfied + and serverless_satisfied + and params_satisfied + and auth_satisfied + and csfle_satisfied + ) + + +class NonLazyCursor: + """A find cursor proxy that creates the remote cursor when initialized.""" + + def __init__(self, find_cursor, client): + self.client = client + self.find_cursor = find_cursor + # Create the server side cursor. + self.first_result = None + + @classmethod + async def create(cls, find_cursor, client): + cursor = cls(find_cursor, client) + try: + cursor.first_result = await anext(cursor.find_cursor) + except StopAsyncIteration: + cursor.first_result = None + return cursor + + @property + def alive(self): + return self.first_result is not None or self.find_cursor.alive + + async def __anext__(self): + if self.first_result is not None: + first = self.first_result + self.first_result = None + return first + return await anext(self.find_cursor) + + # Added to support the iterateOnce operation. + try_next = __anext__ + + async def close(self): + await self.find_cursor.close() + self.client = None + + +class EntityMapUtil: + """Utility class that implements an entity map as per the unified + test format specification. 
+ """ + + def __init__(self, test_class): + self._entities: Dict[str, Any] = {} + self._listeners: Dict[str, EventListenerUtil] = {} + self._session_lsids: Dict[str, Mapping[str, Any]] = {} + self.test: UnifiedSpecTestMixinV1 = test_class + self._cluster_time: Mapping[str, Any] = {} + + def __contains__(self, item): + return item in self._entities + + def __len__(self): + return len(self._entities) + + def __getitem__(self, item): + try: + return self._entities[item] + except KeyError: + self.test.fail(f"Could not find entity named {item} in map") + + def __setitem__(self, key, value): + if not isinstance(key, str): + self.test.fail("Expected entity name of type str, got %s" % (type(key))) + + if key in self._entities: + self.test.fail(f"Entity named {key} already in map") + + self._entities[key] = value + + def _handle_placeholders(self, spec: dict, current: dict, path: str) -> Any: + if "$$placeholder" in current: + if path not in PLACEHOLDER_MAP: + raise ValueError(f"Could not find a placeholder value for {path}") + return PLACEHOLDER_MAP[path] + + for key in list(current): + value = current[key] + if isinstance(value, dict): + subpath = f"{path}/{key}" + current[key] = self._handle_placeholders(spec, value, subpath) + return current + + async def _create_entity(self, entity_spec, uri=None): + if len(entity_spec) != 1: + self.test.fail(f"Entity spec {entity_spec} did not contain exactly one top-level key") + + entity_type, spec = next(iter(entity_spec.items())) + spec = self._handle_placeholders(spec, spec, "") + if entity_type == "client": + kwargs: dict = {} + observe_events = spec.get("observeEvents", []) + + # The unified tests use topologyOpeningEvent, we use topologyOpenedEvent + for i in range(len(observe_events)): + if "topologyOpeningEvent" == observe_events[i]: + observe_events[i] = "topologyOpenedEvent" + ignore_commands = spec.get("ignoreCommandMonitoringEvents", []) + observe_sensitive_commands = spec.get("observeSensitiveCommands", False) + ignore_commands = [cmd.lower() for cmd in ignore_commands] + listener = EventListenerUtil( + observe_events, + ignore_commands, + observe_sensitive_commands, + spec.get("storeEventsAsEntities"), + self, + ) + self._listeners[spec["id"]] = listener + kwargs["event_listeners"] = [listener] + if spec.get("useMultipleMongoses"): + if async_client_context.load_balancer or async_client_context.serverless: + kwargs["h"] = async_client_context.MULTI_MONGOS_LB_URI + elif async_client_context.is_mongos: + kwargs["h"] = async_client_context.mongos_seeds() + kwargs.update(spec.get("uriOptions", {})) + server_api = spec.get("serverApi") + if "waitQueueSize" in kwargs: + raise unittest.SkipTest("PyMongo does not support waitQueueSize") + if "waitQueueMultiple" in kwargs: + raise unittest.SkipTest("PyMongo does not support waitQueueMultiple") + if server_api: + kwargs["server_api"] = ServerApi( + server_api["version"], + strict=server_api.get("strict"), + deprecation_errors=server_api.get("deprecationErrors"), + ) + if uri: + kwargs["h"] = uri + client = await self.test.async_rs_or_single_client(**kwargs) + self[spec["id"]] = client + self.test.addAsyncCleanup(client.close) + return + elif entity_type == "database": + client = self[spec["client"]] + if type(client).__name__ != "AsyncMongoClient": + self.test.fail( + "Expected entity {} to be of type AsyncMongoClient, got {}".format( + spec["client"], type(client) + ) + ) + options = parse_collection_or_database_options(spec.get("databaseOptions", {})) + self[spec["id"]] = 
client.get_database(spec["databaseName"], **options) + return + elif entity_type == "collection": + database = self[spec["database"]] + if not isinstance(database, AsyncDatabase): + self.test.fail( + "Expected entity {} to be of type AsyncDatabase, got {}".format( + spec["database"], type(database) + ) + ) + options = parse_collection_or_database_options(spec.get("collectionOptions", {})) + self[spec["id"]] = database.get_collection(spec["collectionName"], **options) + return + elif entity_type == "session": + client = self[spec["client"]] + if type(client).__name__ != "AsyncMongoClient": + self.test.fail( + "Expected entity {} to be of type AsyncMongoClient, got {}".format( + spec["client"], type(client) + ) + ) + opts = camel_to_snake_args(spec.get("sessionOptions", {})) + if "default_transaction_options" in opts: + txn_opts = parse_spec_options(opts["default_transaction_options"]) + txn_opts = TransactionOptions(**txn_opts) + opts = copy.deepcopy(opts) + opts["default_transaction_options"] = txn_opts + session = client.start_session(**dict(opts)) + self[spec["id"]] = session + self._session_lsids[spec["id"]] = copy.deepcopy(session.session_id) + self.test.addAsyncCleanup(session.end_session) + return + elif entity_type == "bucket": + db = self[spec["database"]] + kwargs = parse_spec_options(spec.get("bucketOptions", {}).copy()) + bucket = AsyncGridFSBucket(db, **kwargs) + + # PyMongo does not support AsyncGridFSBucket.drop(), emulate it. + @_csot.apply + async def drop(self: AsyncGridFSBucket, *args: Any, **kwargs: Any) -> None: + await self._files.drop(*args, **kwargs) + await self._chunks.drop(*args, **kwargs) + + if not hasattr(bucket, "drop"): + bucket.drop = drop.__get__(bucket) + self[spec["id"]] = bucket + return + elif entity_type == "clientEncryption": + opts = camel_to_snake_args(spec["clientEncryptionOpts"].copy()) + if isinstance(opts["key_vault_client"], str): + opts["key_vault_client"] = self[opts["key_vault_client"]] + # Set TLS options for providers like "kmip:name1". + kms_tls_options = {} + for provider in opts["kms_providers"]: + provider_type = provider.split(":")[0] + if provider_type in KMS_TLS_OPTS: + kms_tls_options[provider] = KMS_TLS_OPTS[provider_type] + self[spec["id"]] = AsyncClientEncryption( + opts["kms_providers"], + opts["key_vault_namespace"], + opts["key_vault_client"], + DEFAULT_CODEC_OPTIONS, + opts.get("kms_tls_options", kms_tls_options), + ) + return + elif entity_type == "thread": + name = spec["id"] + thread = SpecRunnerThread(name) + thread.start() + self[name] = thread + return + + self.test.fail(f"Unable to create entity of unknown type {entity_type}") + + async def create_entities_from_spec(self, entity_spec, uri=None): + for spec in entity_spec: + await self._create_entity(spec, uri=uri) + + def get_listener_for_client(self, client_name: str) -> EventListenerUtil: + client = self[client_name] + if type(client).__name__ != "AsyncMongoClient": + self.test.fail( + f"Expected entity {client_name} to be of type AsyncMongoClient, got {type(client)}" + ) + + listener = self._listeners.get(client_name) + if not listener: + self.test.fail(f"No listeners configured for client {client_name}") + + return listener + + def get_lsid_for_session(self, session_name): + session = self[session_name] + if not isinstance(session, AsyncClientSession): + self.test.fail( + f"Expected entity {session_name} to be of type AsyncClientSession, got {type(session)}" + ) + + try: + return session.session_id + except InvalidOperation: + # session has been closed. 
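+            # Fall back to the lsid recorded when the session was created;
+            # lsid assertions may still reference it after end_session.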
+            return self._session_lsids[session_name]
+
+    async def advance_cluster_times(self) -> None:
+        """Manually synchronize entities when desired"""
+        if not self._cluster_time:
+            self._cluster_time = (await self.test.client.admin.command("ping")).get("$clusterTime")
+        for entity in self._entities.values():
+            if isinstance(entity, AsyncClientSession) and self._cluster_time:
+                entity.advance_cluster_time(self._cluster_time)
+
+
+class UnifiedSpecTestMixinV1(AsyncIntegrationTest):
+    """Mixin class to run test cases from test specification files.
+
+    Assumes that tests conform to the `unified test format
+    <https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst>`_.
+
+    Specification of the test suite being currently run is available as
+    a class attribute ``TEST_SPEC``.
+    """
+
+    SCHEMA_VERSION = Version.from_string("1.21")
+    RUN_ON_LOAD_BALANCER = True
+    RUN_ON_SERVERLESS = True
+    TEST_SPEC: Any
+    mongos_clients: list[AsyncMongoClient] = []
+
+    @staticmethod
+    async def should_run_on(run_on_spec):
+        if not run_on_spec:
+            # Always run these tests.
+            return True
+
+        for req in run_on_spec:
+            if await is_run_on_requirement_satisfied(req):
+                return True
+        return False
+
+    async def insert_initial_data(self, initial_data):
+        for i, collection_data in enumerate(initial_data):
+            coll_name = collection_data["collectionName"]
+            db_name = collection_data["databaseName"]
+            opts = collection_data.get("createOptions", {})
+            documents = collection_data["documents"]
+
+            # Set up the collection with as few majority writes as possible.
+            db = self.client[db_name]
+            await db.drop_collection(coll_name)
+            # Only use majority write concern on the final write.
+            if i == len(initial_data) - 1:
+                wc = WriteConcern(w="majority")
+            else:
+                wc = WriteConcern(w=1)
+            if documents:
+                if opts:
+                    await db.create_collection(coll_name, **opts)
+                await db.get_collection(coll_name, write_concern=wc).insert_many(documents)
+            else:
+                # Ensure collection exists
+                await db.create_collection(coll_name, write_concern=wc, **opts)
+
+    @classmethod
+    async def _setup_class(cls):
+        # super call creates internal client cls.client
+        await super()._setup_class()
+        # process file-level runOnRequirements
+        run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
+        if not await cls.should_run_on(run_on_spec):
+            raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied")
+
+        # add any special-casing for skipping tests here
+        if async_client_context.storage_engine == "mmapv1":
+            if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str(
+                cls.TEST_PATH
+            ):
+                raise unittest.SkipTest("MMAPv1 does not support retryWrites=True")
+
+        # Handle mongos_clients for transactions tests.
+        cls.mongos_clients = []
+        if (
+            async_client_context.supports_transactions()
+            and not async_client_context.load_balancer
+            and not async_client_context.serverless
+        ):
+            for address in async_client_context.mongoses:
+                cls.mongos_clients.append(
+                    await cls.unmanaged_async_single_client("{}:{}".format(*address))
+                )
+
+        # Speed up the tests by decreasing the heartbeat frequency.
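+        # client_knobs shortens PyMongo's internal polling intervals
+        # (heartbeats, cursor cleanup, event queue flushing) so the suite
+        # does not wait on production-length timers.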
+ cls.knobs = client_knobs( + heartbeat_frequency=0.1, + min_heartbeat_interval=0.1, + kill_cursor_frequency=0.1, + events_queue_frequency=0.1, + ) + cls.knobs.enable() + + @classmethod + async def _tearDown_class(cls): + cls.knobs.disable() + for client in cls.mongos_clients: + await client.close() + await super()._tearDown_class() + + async def asyncSetUp(self): + await super().asyncSetUp() + # process schemaVersion + # note: we check major schema version during class generation + # note: we do this here because we cannot run assertions in setUpClass + version = Version.from_string(self.TEST_SPEC["schemaVersion"]) + self.assertLessEqual( + version, + self.SCHEMA_VERSION, + f"expected schema version {self.SCHEMA_VERSION} or lower, got {version}", + ) + + # initialize internals + self.match_evaluator = MatchEvaluatorUtil(self) + + def maybe_skip_test(self, spec): + # add any special-casing for skipping tests here + if async_client_context.storage_engine == "mmapv1": + if ( + "Dirty explicit session is discarded" in spec["description"] + or "Dirty implicit session is discarded" in spec["description"] + or "Cancel server check" in spec["description"] + ): + self.skipTest("MMAPv1 does not support retryWrites=True") + if ( + "AsyncDatabase-level aggregate with $out includes read preference for 5.0+ server" + in spec["description"] + ): + if async_client_context.version[0] == 8: + self.skipTest("waiting on PYTHON-4356") + if "Aggregate with $out includes read preference for 5.0+ server" in spec["description"]: + if async_client_context.version[0] == 8: + self.skipTest("waiting on PYTHON-4356") + if "Client side error in command starting transaction" in spec["description"]: + self.skipTest("Implement PYTHON-1894") + if "timeoutMS applied to entire download" in spec["description"]: + self.skipTest("PyMongo's open_download_stream does not cap the stream's lifetime") + + class_name = self.__class__.__name__.lower() + description = spec["description"].lower() + if "csot" in class_name: + if "gridfs" in class_name and sys.platform == "win32": + self.skipTest("PYTHON-3522 CSOT GridFS tests are flaky on Windows") + if async_client_context.storage_engine == "mmapv1": + self.skipTest( + "MMAPv1 does not support retryable writes which is required for CSOT tests" + ) + if "change" in description or "change" in class_name: + self.skipTest("CSOT not implemented for watch()") + if "cursors" in class_name: + self.skipTest("CSOT not implemented for cursors") + if "tailable" in class_name: + self.skipTest("CSOT not implemented for tailable cursors") + if "sessions" in class_name: + self.skipTest("CSOT not implemented for sessions") + if "withtransaction" in description: + self.skipTest("CSOT not implemented for with_transaction") + if "transaction" in class_name or "transaction" in description: + self.skipTest("CSOT not implemented for transactions") + + # Some tests need to be skipped based on the operations they try to run. 
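+        # (e.g. helpers PyMongo does not implement, MMAPv1 limitations, and
+        # fail points when test commands are disabled; see the checks below.)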
+ for op in spec["operations"]: + name = op["name"] + if name == "count": + self.skipTest("PyMongo does not support count()") + if name == "listIndexNames": + self.skipTest("PyMongo does not support list_index_names()") + if async_client_context.storage_engine == "mmapv1": + if name == "createChangeStream": + self.skipTest("MMAPv1 does not support change streams") + if name == "withTransaction" or name == "startTransaction": + self.skipTest("MMAPv1 does not support document-level locking") + if not async_client_context.test_commands_enabled: + if name == "failPoint" or name == "targetedFailPoint": + self.skipTest("Test commands must be enabled to use fail points") + if name == "modifyCollection": + self.skipTest("PyMongo does not support modifyCollection") + if "timeoutMode" in op.get("arguments", {}): + self.skipTest("PyMongo does not support timeoutMode") + + def process_error(self, exception, spec): + if isinstance(exception, unittest.SkipTest): + raise + is_error = spec.get("isError") + is_client_error = spec.get("isClientError") + is_timeout_error = spec.get("isTimeoutError") + error_contains = spec.get("errorContains") + error_code = spec.get("errorCode") + error_code_name = spec.get("errorCodeName") + error_labels_contain = spec.get("errorLabelsContain") + error_labels_omit = spec.get("errorLabelsOmit") + expect_result = spec.get("expectResult") + error_response = spec.get("errorResponse") + if error_response: + if isinstance(exception, ClientBulkWriteException): + self.match_evaluator.match_result(error_response, exception.error.details) + else: + self.match_evaluator.match_result(error_response, exception.details) + + if is_error: + # already satisfied because exception was raised + pass + + if is_client_error: + if isinstance(exception, ClientBulkWriteException): + error = exception.error + else: + error = exception + # Connection errors are considered client errors. + if isinstance(error, ConnectionFailure): + self.assertNotIsInstance(error, NotPrimaryError) + elif isinstance(error, (InvalidOperation, ConfigurationError, EncryptionError)): + pass + else: + self.assertNotIsInstance(error, PyMongoError) + + if is_timeout_error: + self.assertIsInstance(exception, PyMongoError) + if not exception.timeout: + # Re-raise the exception for better diagnostics. 
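+                # (PyMongoError.timeout is True only when the error was
+                # caused by a timeout, e.g. a CSOT deadline.)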
+                raise exception
+
+        if error_contains:
+            if isinstance(exception, BulkWriteError):
+                errmsg = str(exception.details).lower()
+            elif isinstance(exception, ClientBulkWriteException):
+                errmsg = str(exception.details).lower()
+            else:
+                errmsg = str(exception).lower()
+            self.assertIn(error_contains.lower(), errmsg)
+
+        if error_code:
+            if isinstance(exception, ClientBulkWriteException):
+                self.assertEqual(error_code, exception.error.details.get("code"))
+            else:
+                self.assertEqual(error_code, exception.details.get("code"))
+
+        if error_code_name:
+            if isinstance(exception, ClientBulkWriteException):
+                self.assertEqual(error_code_name, exception.error.details.get("codeName"))
+            else:
+                self.assertEqual(error_code_name, exception.details.get("codeName"))
+
+        if error_labels_contain:
+            if isinstance(exception, ClientBulkWriteException):
+                error = exception.error
+            else:
+                error = exception
+            labels = [
+                err_label for err_label in error_labels_contain if error.has_error_label(err_label)
+            ]
+            self.assertEqual(labels, error_labels_contain)
+
+        if error_labels_omit:
+            for err_label in error_labels_omit:
+                if exception.has_error_label(err_label):
+                    self.fail(f"Exception '{exception}' unexpectedly had label '{err_label}'")
+
+        if expect_result:
+            if isinstance(exception, BulkWriteError):
+                result = parse_bulk_write_error_result(exception)
+                self.match_evaluator.match_result(expect_result, result)
+            elif isinstance(exception, ClientBulkWriteException):
+                result = parse_client_bulk_write_error_result(exception)
+                self.match_evaluator.match_result(expect_result, result)
+            else:
+                self.fail(
+                    f"expectResult can only be specified with {BulkWriteError} or {ClientBulkWriteException} exceptions"
+                )
+
+        return exception
+
+    def __raise_if_unsupported(self, opname, target, *target_types):
+        if not isinstance(target, target_types):
+            self.fail(f"Operation {opname} not supported for entity of type {type(target)}")
+
+    async def __entityOperation_createChangeStream(self, target, *args, **kwargs):
+        if async_client_context.storage_engine == "mmapv1":
+            self.skipTest("MMAPv1 does not support change streams")
+        self.__raise_if_unsupported(
+            "createChangeStream", target, AsyncMongoClient, AsyncDatabase, AsyncCollection
+        )
+        stream = await target.watch(*args, **kwargs)
+        self.addAsyncCleanup(stream.close)
+        return stream
+
+    async def _clientOperation_createChangeStream(self, target, *args, **kwargs):
+        return await self.__entityOperation_createChangeStream(target, *args, **kwargs)
+
+    async def _databaseOperation_createChangeStream(self, target, *args, **kwargs):
+        return await self.__entityOperation_createChangeStream(target, *args, **kwargs)
+
+    async def _collectionOperation_createChangeStream(self, target, *args, **kwargs):
+        return await self.__entityOperation_createChangeStream(target, *args, **kwargs)
+
+    async def _databaseOperation_runCommand(self, target, **kwargs):
+        self.__raise_if_unsupported("runCommand", target, AsyncDatabase)
+        # Ensure the first key is the command name.
+        ordered_command = SON([(kwargs.pop("command_name"), 1)])
+        ordered_command.update(kwargs["command"])
+        kwargs["command"] = ordered_command
+        return await target.command(**kwargs)
+
+    async def _databaseOperation_runCursorCommand(self, target, **kwargs):
+        cursor = await self._databaseOperation_createCommandCursor(target, **kwargs)
+        return await cursor.to_list()
+
+    async def _databaseOperation_createCommandCursor(self, target, **kwargs):
+        self.__raise_if_unsupported("createCommandCursor", target, AsyncDatabase)
+        # Ensure the first key is the command name.
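+        # (SON preserves key order; the server reads the first key of the
+        # command document as the command name.)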
+        ordered_command = SON([(kwargs.pop("command_name"), 1)])
+        ordered_command.update(kwargs["command"])
+        kwargs["command"] = ordered_command
+        batch_size = 0
+
+        cursor_type = kwargs.pop("cursor_type", "nonTailable")
+        if cursor_type == CursorType.TAILABLE:
+            ordered_command["tailable"] = True
+        elif cursor_type == CursorType.TAILABLE_AWAIT:
+            ordered_command["tailable"] = True
+            ordered_command["awaitData"] = True
+        elif cursor_type != "nonTailable":
+            self.fail(f"unknown cursorType: {cursor_type}")
+
+        if "maxTimeMS" in kwargs:
+            kwargs["max_await_time_ms"] = kwargs.pop("maxTimeMS")
+
+        if "batch_size" in kwargs:
+            batch_size = kwargs.pop("batch_size")
+
+        cursor = await target.cursor_command(**kwargs)
+
+        if batch_size > 0:
+            cursor.batch_size(batch_size)
+
+        return cursor
+
+    async def kill_all_sessions(self):
+        if getattr(self, "client", None) is None:
+            return
+        clients = self.mongos_clients if self.mongos_clients else [self.client]
+        for client in clients:
+            try:
+                await client.admin.command("killAllSessions", [])
+            except OperationFailure:
+                # "operation was interrupted" by killing the command's
+                # own session.
+                pass
+
+    async def _databaseOperation_listCollections(self, target, *args, **kwargs):
+        if "batch_size" in kwargs:
+            kwargs["cursor"] = {"batchSize": kwargs.pop("batch_size")}
+        cursor = await target.list_collections(*args, **kwargs)
+        return await cursor.to_list()
+
+    async def _databaseOperation_createCollection(self, target, *args, **kwargs):
+        # PYTHON-1936 Ignore the listCollections event from create_collection.
+        kwargs["check_exists"] = False
+        ret = await target.create_collection(*args, **kwargs)
+        return ret
+
+    async def __entityOperation_aggregate(self, target, *args, **kwargs):
+        self.__raise_if_unsupported("aggregate", target, AsyncDatabase, AsyncCollection)
+        return await (await target.aggregate(*args, **kwargs)).to_list()
+
+    async def _databaseOperation_aggregate(self, target, *args, **kwargs):
+        return await self.__entityOperation_aggregate(target, *args, **kwargs)
+
+    async def _collectionOperation_aggregate(self, target, *args, **kwargs):
+        return await self.__entityOperation_aggregate(target, *args, **kwargs)
+
+    async def _collectionOperation_find(self, target, *args, **kwargs):
+        self.__raise_if_unsupported("find", target, AsyncCollection)
+        find_cursor = target.find(*args, **kwargs)
+        return await find_cursor.to_list()
+
+    async def _collectionOperation_createFindCursor(self, target, *args, **kwargs):
+        self.__raise_if_unsupported("find", target, AsyncCollection)
+        if "filter" not in kwargs:
+            self.fail('createFindCursor requires a "filter" argument')
+        cursor = await NonLazyCursor.create(target.find(*args, **kwargs), target.database.client)
+        self.addAsyncCleanup(cursor.close)
+        return cursor
+
+    def _collectionOperation_count(self, target, *args, **kwargs):
+        self.skipTest("PyMongo does not support collection.count()")
+
+    async def _collectionOperation_listIndexes(self, target, *args, **kwargs):
+        if "batch_size" in kwargs:
+            self.skipTest("PyMongo does not support batch_size for list_indexes")
+        return await (await target.list_indexes(*args, **kwargs)).to_list()
+
+    def _collectionOperation_listIndexNames(self, target, *args, **kwargs):
+        self.skipTest("PyMongo does not support list_index_names")
+
+    async def _collectionOperation_createSearchIndexes(self, target, *args, **kwargs):
+        models = [SearchIndexModel(**i) for i in kwargs["models"]]
+        return await target.create_search_indexes(models)
+
+    async def _collectionOperation_listSearchIndexes(self,
target, *args, **kwargs): + name = kwargs.get("name") + agg_kwargs = kwargs.get("aggregation_options", dict()) + return await (await target.list_search_indexes(name, **agg_kwargs)).to_list() + + async def _sessionOperation_withTransaction(self, target, *args, **kwargs): + if async_client_context.storage_engine == "mmapv1": + self.skipTest("MMAPv1 does not support document-level locking") + self.__raise_if_unsupported("withTransaction", target, AsyncClientSession) + return await target.with_transaction(*args, **kwargs) + + async def _sessionOperation_startTransaction(self, target, *args, **kwargs): + if async_client_context.storage_engine == "mmapv1": + self.skipTest("MMAPv1 does not support document-level locking") + self.__raise_if_unsupported("startTransaction", target, AsyncClientSession) + return await target.start_transaction(*args, **kwargs) + + async def _changeStreamOperation_iterateUntilDocumentOrError(self, target, *args, **kwargs): + self.__raise_if_unsupported("iterateUntilDocumentOrError", target, AsyncChangeStream) + return await anext(target) + + async def _cursor_iterateUntilDocumentOrError(self, target, *args, **kwargs): + self.__raise_if_unsupported( + "iterateUntilDocumentOrError", target, NonLazyCursor, AsyncCommandCursor + ) + while target.alive: + try: + return await anext(target) + except StopAsyncIteration: + pass + return None + + async def _cursor_close(self, target, *args, **kwargs): + self.__raise_if_unsupported("close", target, NonLazyCursor, AsyncCommandCursor) + return await target.close() + + async def _clientEncryptionOperation_createDataKey(self, target, *args, **kwargs): + if "opts" in kwargs: + kwargs.update(camel_to_snake_args(kwargs.pop("opts"))) + + return await target.create_data_key(*args, **kwargs) + + async def _clientEncryptionOperation_getKeys(self, target, *args, **kwargs): + return await (await target.get_keys(*args, **kwargs)).to_list() + + async def _clientEncryptionOperation_deleteKey(self, target, *args, **kwargs): + result = await target.delete_key(*args, **kwargs) + response = result.raw_result + response["deletedCount"] = result.deleted_count + return response + + async def _clientEncryptionOperation_rewrapManyDataKey(self, target, *args, **kwargs): + if "opts" in kwargs: + kwargs.update(camel_to_snake_args(kwargs.pop("opts"))) + data = await target.rewrap_many_data_key(*args, **kwargs) + if data.bulk_write_result: + return {"bulkWriteResult": parse_bulk_write_result(data.bulk_write_result)} + return {} + + async def _clientEncryptionOperation_encrypt(self, target, *args, **kwargs): + if "opts" in kwargs: + kwargs.update(camel_to_snake_args(kwargs.pop("opts"))) + return await target.encrypt(*args, **kwargs) + + async def _bucketOperation_download( + self, target: AsyncGridFSBucket, *args: Any, **kwargs: Any + ) -> bytes: + async with await target.open_download_stream(*args, **kwargs) as gout: + return await gout.read() + + async def _bucketOperation_downloadByName( + self, target: AsyncGridFSBucket, *args: Any, **kwargs: Any + ) -> bytes: + async with await target.open_download_stream_by_name(*args, **kwargs) as gout: + return await gout.read() + + async def _bucketOperation_upload( + self, target: AsyncGridFSBucket, *args: Any, **kwargs: Any + ) -> ObjectId: + kwargs["source"] = binascii.unhexlify(kwargs.pop("source")["$$hexBytes"]) + if "content_type" in kwargs: + kwargs.setdefault("metadata", {})["contentType"] = kwargs.pop("content_type") + return await target.upload_from_stream(*args, **kwargs) + + async def 
_bucketOperation_uploadWithId( + self, target: AsyncGridFSBucket, *args: Any, **kwargs: Any + ) -> Any: + kwargs["source"] = binascii.unhexlify(kwargs.pop("source")["$$hexBytes"]) + if "content_type" in kwargs: + kwargs.setdefault("metadata", {})["contentType"] = kwargs.pop("content_type") + return await target.upload_from_stream_with_id(*args, **kwargs) + + async def _bucketOperation_find( + self, target: AsyncGridFSBucket, *args: Any, **kwargs: Any + ) -> List[GridOut]: + return await target.find(*args, **kwargs).to_list() + + async def run_entity_operation(self, spec): + target = self.entity_map[spec["object"]] + opname = spec["name"] + opargs = spec.get("arguments") + expect_error = spec.get("expectError") + save_as_entity = spec.get("saveResultAsEntity") + expect_result = spec.get("expectResult") + ignore = spec.get("ignoreResultAndError") + if ignore and (expect_error or save_as_entity or expect_result): + raise ValueError( + "ignoreResultAndError is incompatible with saveResultAsEntity" + ", expectError, and expectResult" + ) + if opargs: + arguments = parse_spec_options(copy.deepcopy(opargs)) + prepare_spec_arguments( + spec, + arguments, + camel_to_snake(opname), + self.entity_map, + self.run_operations_and_throw, + ) + else: + arguments = {} + + if isinstance(target, AsyncMongoClient): + method_name = f"_clientOperation_{opname}" + elif isinstance(target, AsyncDatabase): + method_name = f"_databaseOperation_{opname}" + elif isinstance(target, AsyncCollection): + method_name = f"_collectionOperation_{opname}" + # contentType is always stored in metadata in pymongo. + if target.name.endswith(".files") and opname == "find": + for doc in spec.get("expectResult", []): + if "contentType" in doc: + doc.setdefault("metadata", {})["contentType"] = doc.pop("contentType") + elif isinstance(target, AsyncChangeStream): + method_name = f"_changeStreamOperation_{opname}" + elif isinstance(target, (NonLazyCursor, AsyncCommandCursor)): + method_name = f"_cursor_{opname}" + elif isinstance(target, AsyncClientSession): + method_name = f"_sessionOperation_{opname}" + elif isinstance(target, AsyncGridFSBucket): + method_name = f"_bucketOperation_{opname}" + if "id" in arguments: + arguments["file_id"] = arguments.pop("id") + # MD5 is always disabled in pymongo. + arguments.pop("disable_md5", None) + elif isinstance(target, AsyncClientEncryption): + method_name = f"_clientEncryptionOperation_{opname}" + else: + method_name = "doesNotExist" + + try: + method = getattr(self, method_name) + except AttributeError: + target_opname = camel_to_snake(opname) + if target_opname == "iterate_once": + target_opname = "try_next" + if target_opname == "client_bulk_write": + target_opname = "bulk_write" + try: + cmd = getattr(target, target_opname) + except AttributeError: + self.fail(f"Unsupported operation {opname} on entity {target}") + else: + cmd = functools.partial(method, target) + + try: + # CSOT: Translate the spec test "timeout" arg into pymongo's context timeout API. + if "timeout" in arguments: + timeout = arguments.pop("timeout") + with pymongo.timeout(timeout): + result = await cmd(**dict(arguments)) + else: + result = await cmd(**dict(arguments)) + except Exception as exc: + # Ignore all operation errors but to avoid masking bugs don't + # ignore things like TypeError and ValueError. 
+            if ignore and isinstance(exc, (PyMongoError,)):
+                return exc
+            if expect_error:
+                if method_name == "_collectionOperation_bulkWrite":
+                    self.skipTest("Skipping test pending PYTHON-4598")
+                return self.process_error(exc, expect_error)
+            raise
+        else:
+            if method_name == "_collectionOperation_bulkWrite":
+                self.skipTest("Skipping test pending PYTHON-4598")
+        if expect_error:
+            self.fail(f'Expected error {expect_error} but "{opname}" succeeded: {result}')
+
+        if expect_result:
+            actual = coerce_result(opname, result)
+            self.match_evaluator.match_result(expect_result, actual)
+
+        if save_as_entity:
+            self.entity_map[save_as_entity] = result
+            return None
+        return None
+
+    async def __set_fail_point(self, client, command_args):
+        if not async_client_context.test_commands_enabled:
+            self.skipTest("Test commands must be enabled")
+
+        cmd_on = SON([("configureFailPoint", "failCommand")])
+        cmd_on.update(command_args)
+        await client.admin.command(cmd_on)
+        self.addAsyncCleanup(
+            client.admin.command, "configureFailPoint", cmd_on["configureFailPoint"], mode="off"
+        )
+
+    async def _testOperation_failPoint(self, spec):
+        await self.__set_fail_point(
+            client=self.entity_map[spec["client"]], command_args=spec["failPoint"]
+        )
+
+    async def _testOperation_targetedFailPoint(self, spec):
+        session = self.entity_map[spec["session"]]
+        if not session._pinned_address:
+            self.fail(
+                "Cannot use targetedFailPoint operation with unpinned session {}".format(
+                    spec["session"]
+                )
+            )
+
+        client = await self.async_single_client("{}:{}".format(*session._pinned_address))
+        self.addAsyncCleanup(client.close)
+        await self.__set_fail_point(client=client, command_args=spec["failPoint"])
+
+    async def _testOperation_createEntities(self, spec):
+        await self.entity_map.create_entities_from_spec(spec["entities"], uri=self._uri)
+        await self.entity_map.advance_cluster_times()
+
+    def _testOperation_assertSessionTransactionState(self, spec):
+        session = self.entity_map[spec["session"]]
+        expected_state = getattr(_TxnState, spec["state"].upper())
+        self.assertEqual(expected_state, session._transaction.state)
+
+    def _testOperation_assertSessionPinned(self, spec):
+        session = self.entity_map[spec["session"]]
+        self.assertIsNotNone(session._transaction.pinned_address)
+
+    def _testOperation_assertSessionUnpinned(self, spec):
+        session = self.entity_map[spec["session"]]
+        self.assertIsNone(session._pinned_address)
+        self.assertIsNone(session._transaction.pinned_address)
+
+    def __get_last_two_command_lsids(self, listener):
+        cmd_started_events = []
+        for event in reversed(listener.events):
+            if isinstance(event, CommandStartedEvent):
+                cmd_started_events.append(event)
+        if len(cmd_started_events) < 2:
+            self.fail(
+                "Needed 2 CommandStartedEvents to compare lsids, got %s"
+                % (len(cmd_started_events))
+            )
+        return tuple([e.command["lsid"] for e in cmd_started_events][:2])
+
+    def _testOperation_assertDifferentLsidOnLastTwoCommands(self, spec):
+        listener = self.entity_map.get_listener_for_client(spec["client"])
+        self.assertNotEqual(*self.__get_last_two_command_lsids(listener))
+
+    def _testOperation_assertSameLsidOnLastTwoCommands(self, spec):
+        listener = self.entity_map.get_listener_for_client(spec["client"])
+        self.assertEqual(*self.__get_last_two_command_lsids(listener))
+
+    def _testOperation_assertSessionDirty(self, spec):
+        session = self.entity_map[spec["session"]]
+        self.assertTrue(session._server_session.dirty)
+
+    def _testOperation_assertSessionNotDirty(self, spec):
+        session =
self.entity_map[spec["session"]] + return self.assertFalse(session._server_session.dirty) + + async def _testOperation_assertCollectionExists(self, spec): + database_name = spec["databaseName"] + collection_name = spec["collectionName"] + collection_name_list = list( + await self.client.get_database(database_name).list_collection_names() + ) + self.assertIn(collection_name, collection_name_list) + + async def _testOperation_assertCollectionNotExists(self, spec): + database_name = spec["databaseName"] + collection_name = spec["collectionName"] + collection_name_list = list( + await self.client.get_database(database_name).list_collection_names() + ) + self.assertNotIn(collection_name, collection_name_list) + + async def _testOperation_assertIndexExists(self, spec): + collection = self.client[spec["databaseName"]][spec["collectionName"]] + index_names = [idx["name"] async for idx in await collection.list_indexes()] + self.assertIn(spec["indexName"], index_names) + + async def _testOperation_assertIndexNotExists(self, spec): + collection = self.client[spec["databaseName"]][spec["collectionName"]] + async for index in await collection.list_indexes(): + self.assertNotEqual(spec["indexName"], index["name"]) + + async def _testOperation_assertNumberConnectionsCheckedOut(self, spec): + client = self.entity_map[spec["client"]] + pool = await async_get_pool(client) + self.assertEqual(spec["connections"], pool.active_sockets) + + def _event_count(self, client_name, event): + listener = self.entity_map.get_listener_for_client(client_name) + actual_events = listener.get_events("all") + count = 0 + for actual in actual_events: + try: + self.match_evaluator.match_event(event, actual) + except AssertionError: + continue + else: + count += 1 + return count + + def _testOperation_assertEventCount(self, spec): + """Run the assertEventCount test operation. + + Assert the given event was published exactly `count` times. + """ + client, event, count = spec["client"], spec["event"], spec["count"] + self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}") + + def _testOperation_waitForEvent(self, spec): + """Run the waitForEvent test operation. + + Wait for a number of events to be published, or fail. 
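+
+        Polling is done with wait_until, which fails the test if the expected
+        count is not reached before its default timeout.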
+ """ + client, event, count = spec["client"], spec["event"], spec["count"] + wait_until( + lambda: self._event_count(client, event) >= count, + f"find {count} {event} event(s)", + ) + + async def _testOperation_wait(self, spec): + """Run the "wait" test operation.""" + await asyncio.sleep(spec["ms"] / 1000.0) + + def _testOperation_recordTopologyDescription(self, spec): + """Run the recordTopologyDescription test operation.""" + self.entity_map[spec["id"]] = self.entity_map[spec["client"]].topology_description + + def _testOperation_assertTopologyType(self, spec): + """Run the assertTopologyType test operation.""" + description = self.entity_map[spec["topologyDescription"]] + self.assertIsInstance(description, TopologyDescription) + self.assertEqual(description.topology_type_name, spec["topologyType"]) + + def _testOperation_waitForPrimaryChange(self, spec: dict) -> None: + """Run the waitForPrimaryChange test operation.""" + client = self.entity_map[spec["client"]] + old_description: TopologyDescription = self.entity_map[spec["priorTopologyDescription"]] + timeout = spec["timeoutMS"] / 1000.0 + + def get_primary(td: TopologyDescription) -> Optional[_Address]: + servers = writable_server_selector(Selection.from_topology_description(td)) + if servers and servers[0].server_type == SERVER_TYPE.RSPrimary: + return servers[0].address + return None + + old_primary = get_primary(old_description) + + def primary_changed() -> bool: + primary = client.primary + if primary is None: + return False + return primary != old_primary + + wait_until(primary_changed, "change primary", timeout=timeout) + + def _testOperation_runOnThread(self, spec): + """Run the 'runOnThread' operation.""" + thread = self.entity_map[spec["thread"]] + thread.schedule(lambda: self.run_entity_operation(spec["operation"])) + + def _testOperation_waitForThread(self, spec): + """Run the 'waitForThread' operation.""" + thread = self.entity_map[spec["thread"]] + thread.stop() + thread.join(10) + if thread.exc: + raise thread.exc + self.assertFalse(thread.is_alive(), "Thread {} is still running".format(spec["thread"])) + + async def _testOperation_loop(self, spec): + failure_key = spec.get("storeFailuresAsEntity") + error_key = spec.get("storeErrorsAsEntity") + successes_key = spec.get("storeSuccessesAsEntity") + iteration_key = spec.get("storeIterationsAsEntity") + iteration_limiter_key = spec.get("numIterations") + for i in [failure_key, error_key]: + if i: + self.entity_map[i] = [] + for i in [successes_key, iteration_key]: + if i: + self.entity_map[i] = 0 + i = 0 + global IS_INTERRUPTED + while True: + if iteration_limiter_key and i >= iteration_limiter_key: + break + i += 1 + if IS_INTERRUPTED: + break + try: + if iteration_key: + self.entity_map._entities[iteration_key] += 1 + for op in spec["operations"]: + await self.run_entity_operation(op) + if successes_key: + self.entity_map._entities[successes_key] += 1 + except Exception as exc: + if isinstance(exc, AssertionError): + key = failure_key or error_key + else: + key = error_key or failure_key + if not key: + raise + self.entity_map[key].append( + {"error": str(exc), "time": time.time(), "type": type(exc).__name__} + ) + + async def run_special_operation(self, spec): + opname = spec["name"] + method_name = f"_testOperation_{opname}" + try: + method = getattr(self, method_name) + except AttributeError: + self.fail(f"Unsupported special test operation {opname}") + else: + if iscoroutinefunction(method): + await method(spec["arguments"]) + else: + method(spec["arguments"]) + + 
async def run_operations(self, spec): + for op in spec: + if op["object"] == "testRunner": + await self.run_special_operation(op) + else: + await self.run_entity_operation(op) + + async def run_operations_and_throw(self, spec): + for op in spec: + if op["object"] == "testRunner": + await self.run_special_operation(op) + else: + result = await self.run_entity_operation(op) + if isinstance(result, Exception): + raise result + + def check_events(self, spec): + for event_spec in spec: + client_name = event_spec["client"] + events = event_spec["events"] + event_type = event_spec.get("eventType", "command") + ignore_extra_events = event_spec.get("ignoreExtraEvents", False) + server_connection_id = event_spec.get("serverConnectionId") + has_server_connection_id = event_spec.get("hasServerConnectionId", False) + listener = self.entity_map.get_listener_for_client(client_name) + actual_events = listener.get_events(event_type) + if ignore_extra_events: + actual_events = actual_events[: len(events)] + + if len(events) == 0: + self.assertEqual(actual_events, []) + continue + + if len(actual_events) != len(events): + expected = "\n".join(str(e) for e in events) + actual = "\n".join(str(a) for a in actual_events) + self.assertEqual( + len(actual_events), + len(events), + f"expected events:\n{expected}\nactual events:\n{actual}", + ) + + for idx, expected_event in enumerate(events): + self.match_evaluator.match_event(expected_event, actual_events[idx]) + + if has_server_connection_id: + assert server_connection_id is not None + assert server_connection_id >= 0 + else: + assert server_connection_id is None + + def process_ignore_messages(self, ignore_logs, actual_logs): + final_logs = [] + for log in actual_logs: + ignored = False + for ignore_log in ignore_logs: + if log["data"]["message"] == ignore_log["data"][ + "message" + ] and self.match_evaluator.match_result(ignore_log, log, test=False): + ignored = True + break + if not ignored: + final_logs.append(log) + return final_logs + + async def check_log_messages(self, operations, spec): + def format_logs(log_list): + client_to_log = defaultdict(list) + for log in log_list: + if log.module == "ocsp_support": + continue + data = json_util.loads(log.getMessage()) + client = data.pop("clientId") if "clientId" in data else data.pop("topologyId") + client_to_log[client].append( + { + "level": log.levelname.lower(), + "component": log.name.replace("pymongo.", "", 1), + "data": data, + } + ) + return client_to_log + + with self.assertLogs("pymongo", level="DEBUG") as cm: + await self.run_operations(operations) + formatted_logs = format_logs(cm.records) + for client in spec: + components = set() + for message in client["messages"]: + components.add(message["component"]) + + clientid = self.entity_map[client["client"]]._topology_settings._topology_id + actual_logs = formatted_logs[clientid] + actual_logs = [log for log in actual_logs if log["component"] in components] + + ignore_logs = client.get("ignoreMessages", []) + if ignore_logs: + actual_logs = self.process_ignore_messages(ignore_logs, actual_logs) + + if client.get("ignoreExtraMessages", False): + actual_logs = actual_logs[: len(client["messages"])] + self.assertEqual( + len(client["messages"]), + len(actual_logs), + f"expected {client['messages']} but got {actual_logs}", + ) + for expected_msg, actual_msg in zip(client["messages"], actual_logs): + expected_data, actual_data = expected_msg.pop("data"), actual_msg.pop("data") + + if "failureIsRedacted" in expected_msg: + self.assertIn("failure", 
actual_data)
+                    should_redact = expected_msg.pop("failureIsRedacted")
+                    if should_redact:
+                        actual_fields = set(json_util.loads(actual_data["failure"]).keys())
+                        self.assertTrue(
+                            {"code", "codeName", "errorLabels"}.issuperset(actual_fields)
+                        )
+
+                self.match_evaluator.match_result(expected_data, actual_data)
+                self.match_evaluator.match_result(expected_msg, actual_msg)
+
+    async def verify_outcome(self, spec):
+        for collection_data in spec:
+            coll_name = collection_data["collectionName"]
+            db_name = collection_data["databaseName"]
+            expected_documents = collection_data["documents"]
+
+            coll = self.client.get_database(db_name).get_collection(
+                coll_name,
+                read_preference=ReadPreference.PRIMARY,
+                read_concern=ReadConcern(level="local"),
+            )
+
+            if expected_documents:
+                sorted_expected_documents = sorted(expected_documents, key=lambda doc: doc["_id"])
+                actual_documents = await coll.find({}, sort=[("_id", ASCENDING)]).to_list()
+                self.assertListEqual(sorted_expected_documents, actual_documents)
+
+    async def run_scenario(self, spec, uri=None):
+        if "csot" in self.id().lower() and SKIP_CSOT_TESTS:
+            raise unittest.SkipTest("SKIP_CSOT_TESTS is set, skipping...")
+
+        # Kill all sessions before and after each test to prevent an open
+        # transaction (from a test failure) from blocking collection/database
+        # operations during test set up and tear down.
+        await self.kill_all_sessions()
+        self.addAsyncCleanup(self.kill_all_sessions)
+
+        if "csot" in self.id().lower():
+            # Retry CSOT tests up to 2 times to deal with flaky tests.
+            attempts = 3
+            for i in range(attempts):
+                try:
+                    return await self._run_scenario(spec, uri)
+                except AssertionError:
+                    if i < attempts - 1:
+                        print(
+                            f"Retrying after attempt {i+1} of {self.id()} failed with:\n"
+                            f"{traceback.format_exc()}",
+                            file=sys.stderr,
+                        )
+                        await self.asyncSetUp()
+                        continue
+                    raise
+            return None
+        else:
+            await self._run_scenario(spec, uri)
+            return None
+
+    async def _run_scenario(self, spec, uri=None):
+        # maybe skip test manually
+        self.maybe_skip_test(spec)
+
+        # process test-level runOnRequirements
+        run_on_spec = spec.get("runOnRequirements", [])
+        if not await self.should_run_on(run_on_spec):
+            raise unittest.SkipTest("runOnRequirements not satisfied")
+
+        # process skipReason
+        skip_reason = spec.get("skipReason", None)
+        if skip_reason is not None:
+            raise unittest.SkipTest(f"{skip_reason}")
+
+        # process createEntities
+        self._uri = uri
+        self.entity_map = EntityMapUtil(self)
+        await self.entity_map.create_entities_from_spec(
+            self.TEST_SPEC.get("createEntities", []), uri=uri
+        )
+        # process initialData
+        if "initialData" in self.TEST_SPEC:
+            await self.insert_initial_data(self.TEST_SPEC["initialData"])
+            self._cluster_time = (await self.client.admin.command("ping")).get("$clusterTime")
+            await self.entity_map.advance_cluster_times()
+
+        if "expectLogMessages" in spec:
+            expect_log_messages = spec["expectLogMessages"]
+            self.assertTrue(expect_log_messages, "expectLogMessages must be non-empty")
+            await self.check_log_messages(spec["operations"], expect_log_messages)
+        else:
+            # process operations
+            await self.run_operations(spec["operations"])
+
+        # process expectEvents
+        if "expectEvents" in spec:
+            expect_events = spec["expectEvents"]
+            self.assertTrue(expect_events, "expectEvents must be non-empty")
+            self.check_events(expect_events)
+
+        # process outcome
+        await self.verify_outcome(spec.get("outcome", []))
+
+
+class UnifiedSpecTestMeta(type):
+    """Metaclass for generating test classes."""
+
+    TEST_SPEC: Any
+    EXPECTED_FAILURES:
Any + + def __init__(cls, *args, **kwargs): + super().__init__(*args, **kwargs) + + def create_test(spec): + async def test_case(self): + await self.run_scenario(spec) + + return test_case + + for test_spec in cls.TEST_SPEC["tests"]: + description = test_spec["description"] + test_name = "test_{}".format( + description.strip(". ").replace(" ", "_").replace(".", "_") + ) + test_method = create_test(copy.deepcopy(test_spec)) + test_method.__name__ = str(test_name) + + for fail_pattern in cls.EXPECTED_FAILURES: + if re.search(fail_pattern, description): + test_method = unittest.expectedFailure(test_method) + break + + setattr(cls, test_name, test_method) + + +_ALL_MIXIN_CLASSES = [ + UnifiedSpecTestMixinV1, + # add mixin classes for new schema major versions here +] + + +_SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS = { + KLASS.SCHEMA_VERSION[0]: KLASS for KLASS in _ALL_MIXIN_CLASSES +} + + +def generate_test_classes( + test_path, + module=__name__, + class_name_prefix="", + expected_failures=[], # noqa: B006 + bypass_test_generation_errors=False, + **kwargs, +): + """Method for generating test classes. Returns a dictionary where keys are + the names of test classes and values are the test class objects. + """ + test_klasses = {} + + def test_base_class_factory(test_spec): + """Utility that creates the base class to use for test generation. + This is needed to ensure that cls.TEST_SPEC is appropriately set when + the metaclass __init__ is invoked. + """ + + class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # type: ignore + TEST_SPEC = test_spec + EXPECTED_FAILURES = expected_failures + + return SpecTestBase + + for dirpath, _, filenames in os.walk(test_path): + dirname = os.path.split(dirpath)[-1] + + for filename in filenames: + fpath = os.path.join(dirpath, filename) + with open(fpath) as scenario_stream: + # Use tz_aware=False to match how CodecOptions decodes + # dates. 
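+                # (bson.CodecOptions also defaults to naive datetimes, so
+                # expected values in the spec files compare equal to values
+                # decoded by the driver.)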
+ opts = json_util.JSONOptions(tz_aware=False) + scenario_def = json_util.loads(scenario_stream.read(), json_options=opts) + + test_type = os.path.splitext(filename)[0] + snake_class_name = "Test{}_{}_{}".format( + class_name_prefix, + dirname.replace("-", "_"), + test_type.replace("-", "_").replace(".", "_"), + ) + class_name = snake_to_camel(snake_class_name) + + try: + schema_version = Version.from_string(scenario_def["schemaVersion"]) + mixin_class = _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS.get(schema_version[0]) + if mixin_class is None: + raise ValueError( + f"test file '{fpath}' has unsupported schemaVersion '{schema_version}'" + ) + module_dict = {"__module__": module, "TEST_PATH": test_path} + module_dict.update(kwargs) + test_klasses[class_name] = type( + class_name, + ( + mixin_class, + test_base_class_factory(scenario_def), + ), + module_dict, + ) + except Exception: + if bypass_test_generation_errors: + continue + raise + + return test_klasses diff --git a/test/unified_format.py b/test/unified_format.py index 62211d3d25..6a19082b86 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -18,41 +18,41 @@ """ from __future__ import annotations +import asyncio import binascii -import collections import copy -import datetime import functools import os import re import sys import time import traceback -import types -from collections import abc, defaultdict +from asyncio import iscoroutinefunction +from collections import defaultdict from test import ( IntegrationTest, client_context, client_knobs, unittest, ) -from test.helpers import ( - AWS_CREDS, - AWS_CREDS_2, - AZURE_CREDS, - CA_PEM, - CLIENT_PEM, - GCP_CREDS, - KMIP_CREDS, - LOCAL_MASTER_KEY, - client_knobs, +from test.unified_format_shared import ( + IS_INTERRUPTED, + KMS_TLS_OPTS, + PLACEHOLDER_MAP, + SKIP_CSOT_TESTS, + EventListenerUtil, + MatchEvaluatorUtil, + coerce_result, + parse_bulk_write_error_result, + parse_bulk_write_result, + parse_client_bulk_write_error_result, + parse_collection_or_database_options, + with_metaclass, ) from test.utils import ( - CMAPListener, camel_to_snake, camel_to_snake_args, get_pool, - parse_collection_options, parse_spec_options, prepare_spec_arguments, snake_to_camel, @@ -60,14 +60,12 @@ ) from test.utils_spec_runner import SpecRunnerThread from test.version import Version -from typing import Any, Dict, List, Mapping, Optional, Union +from typing import Any, Dict, List, Mapping, Optional import pymongo -from bson import SON, Code, DBRef, Decimal128, Int64, MaxKey, MinKey, json_util -from bson.binary import Binary +from bson import SON, json_util from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.objectid import ObjectId -from bson.regex import RE_TYPE, Regex from gridfs import GridFSBucket, GridOut from pymongo import ASCENDING, CursorType, MongoClient, _csot from pymongo.encryption_options import _HAVE_PYMONGOCRYPT @@ -83,55 +81,14 @@ PyMongoError, ) from pymongo.monitoring import ( - _SENSITIVE_COMMANDS, - CommandFailedEvent, - CommandListener, CommandStartedEvent, - CommandSucceededEvent, - ConnectionCheckedInEvent, - ConnectionCheckedOutEvent, - ConnectionCheckOutFailedEvent, - ConnectionCheckOutStartedEvent, - ConnectionClosedEvent, - ConnectionCreatedEvent, - ConnectionReadyEvent, - PoolClearedEvent, - PoolClosedEvent, - PoolCreatedEvent, - PoolReadyEvent, - ServerClosedEvent, - ServerDescriptionChangedEvent, - ServerHeartbeatFailedEvent, - ServerHeartbeatListener, - ServerHeartbeatStartedEvent, - ServerHeartbeatSucceededEvent, - ServerListener, - 
ServerOpeningEvent, - TopologyClosedEvent, - TopologyDescriptionChangedEvent, - TopologyEvent, - TopologyListener, - TopologyOpenedEvent, - _CommandEvent, - _ConnectionEvent, - _PoolEvent, - _ServerEvent, - _ServerHeartbeatEvent, ) from pymongo.operations import ( - DeleteMany, - DeleteOne, - InsertOne, - ReplaceOne, SearchIndexModel, - UpdateMany, - UpdateOne, ) from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference -from pymongo.results import BulkWriteResult, ClientBulkWriteResult from pymongo.server_api import ServerApi -from pymongo.server_description import ServerDescription from pymongo.server_selectors import Selection, writable_server_selector from pymongo.server_type import SERVER_TYPE from pymongo.synchronous.change_stream import ChangeStream @@ -140,85 +97,12 @@ from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.database import Database from pymongo.synchronous.encryption import ClientEncryption +from pymongo.synchronous.helpers import next from pymongo.topology_description import TopologyDescription from pymongo.typings import _Address from pymongo.write_concern import WriteConcern -SKIP_CSOT_TESTS = os.getenv("SKIP_CSOT_TESTS") - -JSON_OPTS = json_util.JSONOptions(tz_aware=False) - -IS_INTERRUPTED = False - -KMS_TLS_OPTS = { - "kmip": { - "tlsCAFile": CA_PEM, - "tlsCertificateKeyFile": CLIENT_PEM, - } -} - - -# Build up a placeholder maps. -PLACEHOLDER_MAP = {} -for provider_name, provider_data in [ - ("local", {"key": LOCAL_MASTER_KEY}), - ("local:name1", {"key": LOCAL_MASTER_KEY}), - ("aws", AWS_CREDS), - ("aws:name1", AWS_CREDS), - ("aws:name2", AWS_CREDS_2), - ("azure", AZURE_CREDS), - ("azure:name1", AZURE_CREDS), - ("gcp", GCP_CREDS), - ("gcp:name1", GCP_CREDS), - ("kmip", KMIP_CREDS), - ("kmip:name1", KMIP_CREDS), -]: - for key, value in provider_data.items(): - placeholder = f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}" - PLACEHOLDER_MAP[placeholder] = value - -OIDC_ENV = os.environ.get("OIDC_ENV", "test") -if OIDC_ENV == "test": - PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"ENVIRONMENT": "test"} -elif OIDC_ENV == "azure": - PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { - "ENVIRONMENT": "azure", - "TOKEN_RESOURCE": os.environ["AZUREOIDC_RESOURCE"], - } -elif OIDC_ENV == "gcp": - PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { - "ENVIRONMENT": "gcp", - "TOKEN_RESOURCE": os.environ["GCPOIDC_AUDIENCE"], - } - - -def interrupt_loop(): - global IS_INTERRUPTED - IS_INTERRUPTED = True - - -def with_metaclass(meta, *bases): - """Create a base class with a metaclass. - - Vendored from six: https://github.com/benjaminp/six/blob/master/six.py - """ - - # This requires a bit of explanation: the basic idea is to make a dummy - # metaclass for one level of class instantiation that replaces itself with - # the actual metaclass. - class metaclass(type): - def __new__(cls, name, this_bases, d): - # __orig_bases__ is required by PEP 560. 
- resolved_bases = types.resolve_bases(bases) - if resolved_bases is not bases: - d["__orig_bases__"] = bases - return meta(name, resolved_bases, d) - - @classmethod - def __prepare__(cls, name, this_bases): - return meta.__prepare__(name, bases) - - return type.__new__(metaclass, "temporary_class", (), {}) +_IS_SYNC = True def is_run_on_requirement_satisfied(requirement): @@ -283,77 +167,6 @@ def is_run_on_requirement_satisfied(requirement): ) -def parse_collection_or_database_options(options): - return parse_collection_options(options) - - -def parse_bulk_write_result(result): - upserted_ids = {str(int_idx): result.upserted_ids[int_idx] for int_idx in result.upserted_ids} - return { - "deletedCount": result.deleted_count, - "insertedCount": result.inserted_count, - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - "upsertedCount": result.upserted_count, - "upsertedIds": upserted_ids, - } - - -def parse_client_bulk_write_individual(op_type, result): - if op_type == "insert": - return {"insertedId": result.inserted_id} - if op_type == "update": - if result.upserted_id: - return { - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - "upsertedId": result.upserted_id, - } - else: - return { - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - } - if op_type == "delete": - return { - "deletedCount": result.deleted_count, - } - - -def parse_client_bulk_write_result(result): - insert_results, update_results, delete_results = {}, {}, {} - if result.has_verbose_results: - for idx, res in result.insert_results.items(): - insert_results[str(idx)] = parse_client_bulk_write_individual("insert", res) - for idx, res in result.update_results.items(): - update_results[str(idx)] = parse_client_bulk_write_individual("update", res) - for idx, res in result.delete_results.items(): - delete_results[str(idx)] = parse_client_bulk_write_individual("delete", res) - - return { - "deletedCount": result.deleted_count, - "insertedCount": result.inserted_count, - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - "upsertedCount": result.upserted_count, - "insertResults": insert_results, - "updateResults": update_results, - "deleteResults": delete_results, - } - - -def parse_bulk_write_error_result(error): - write_result = BulkWriteResult(error.details, True) - return parse_bulk_write_result(write_result) - - -def parse_client_bulk_write_error_result(error): - write_result = error.partial_result - if not write_result: - return None - return parse_client_bulk_write_result(write_result) - - class NonLazyCursor: """A find cursor proxy that creates the remote cursor when initialized.""" @@ -361,7 +174,16 @@ def __init__(self, find_cursor, client): self.client = client self.find_cursor = find_cursor # Create the server side cursor. 
- self.first_result = next(find_cursor, None) + self.first_result = None + + @classmethod + def create(cls, find_cursor, client): + cursor = cls(find_cursor, client) + try: + cursor.first_result = next(cursor.find_cursor) + except StopIteration: + cursor.first_result = None + return cursor @property def alive(self): @@ -382,105 +204,6 @@ def close(self): self.client = None -class EventListenerUtil( - CMAPListener, CommandListener, ServerListener, ServerHeartbeatListener, TopologyListener -): - def __init__( - self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map - ): - self._event_types = {name.lower() for name in observe_events} - if observe_sensitive_commands: - self._observe_sensitive_commands = True - self._ignore_commands = set(ignore_commands) - else: - self._observe_sensitive_commands = False - self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands) - self._ignore_commands.add("configurefailpoint") - self._event_mapping = collections.defaultdict(list) - self.entity_map = entity_map - if store_events: - for i in store_events: - id = i["id"] - events = (i.lower() for i in i["events"]) - for i in events: - self._event_mapping[i].append(id) - self.entity_map[id] = [] - super().__init__() - - def get_events(self, event_type): - assert event_type in ("command", "cmap", "sdam", "all"), event_type - if event_type == "all": - return list(self.events) - if event_type == "command": - return [e for e in self.events if isinstance(e, _CommandEvent)] - if event_type == "cmap": - return [e for e in self.events if isinstance(e, (_ConnectionEvent, _PoolEvent))] - return [ - e - for e in self.events - if isinstance(e, (_ServerEvent, TopologyEvent, _ServerHeartbeatEvent)) - ] - - def add_event(self, event): - event_name = type(event).__name__.lower() - if event_name in self._event_types: - super().add_event(event) - for id in self._event_mapping[event_name]: - self.entity_map[id].append( - { - "name": type(event).__name__, - "observedAt": time.time(), - "description": repr(event), - } - ) - - def _command_event(self, event): - if event.command_name.lower() not in self._ignore_commands: - self.add_event(event) - - def started(self, event): - if isinstance(event, CommandStartedEvent): - if event.command == {}: - # Command is redacted. Observe only if flag is set. - if self._observe_sensitive_commands: - self._command_event(event) - else: - self._command_event(event) - else: - self.add_event(event) - - def succeeded(self, event): - if isinstance(event, CommandSucceededEvent): - if event.reply == {}: - # Command is redacted. Observe only if flag is set. - if self._observe_sensitive_commands: - self._command_event(event) - else: - self._command_event(event) - else: - self.add_event(event) - - def failed(self, event): - if isinstance(event, CommandFailedEvent): - self._command_event(event) - else: - self.add_event(event) - - def opened(self, event: Union[ServerOpeningEvent, TopologyOpenedEvent]) -> None: - self.add_event(event) - - def description_changed( - self, event: Union[ServerDescriptionChangedEvent, TopologyDescriptionChangedEvent] - ) -> None: - self.add_event(event) - - def topology_changed(self, event: TopologyDescriptionChangedEvent) -> None: - self.add_event(event) - - def closed(self, event: Union[ServerClosedEvent, TopologyClosedEvent]) -> None: - self.add_event(event) - - class EntityMapUtil: """Utility class that implements an entity map as per the unified test format specification. 
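
A note on the NonLazyCursor hunk above: the first next() call moves out of
__init__ and into a create() classmethod because the async variant of this
class cannot perform I/O in a constructor; only a factory coroutine can be
awaited. A minimal, self-contained sketch of how the async side presumably
reads, assuming anext/StopAsyncIteration as the counterparts of the sync
hunk's next/StopIteration pair (the anext builtin needs Python 3.10+):

    import asyncio

    class AsyncNonLazyCursor:
        """A find cursor proxy that creates the remote cursor via a factory."""

        def __init__(self, find_cursor, client):
            self.client = client
            self.find_cursor = find_cursor
            self.first_result = None

        @classmethod
        async def create(cls, find_cursor, client):
            cursor = cls(find_cursor, client)
            try:
                # Fetch the first document here; __init__ cannot await.
                cursor.first_result = await anext(cursor.find_cursor)
            except StopAsyncIteration:
                cursor.first_result = None
            return cursor

    async def _demo():
        async def docs():
            yield {"_id": 1}

        cursor = await AsyncNonLazyCursor.create(docs(), client=None)
        print(cursor.first_result)  # {'_id': 1}

    asyncio.run(_demo())
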
@@ -692,353 +415,12 @@ def get_lsid_for_session(self, session_name): def advance_cluster_times(self) -> None: """Manually synchronize entities when desired""" if not self._cluster_time: - self._cluster_time = self.test.client.admin.command("ping").get("$clusterTime") + self._cluster_time = (self.test.client.admin.command("ping")).get("$clusterTime") for entity in self._entities.values(): if isinstance(entity, ClientSession) and self._cluster_time: entity.advance_cluster_time(self._cluster_time) -binary_types = (Binary, bytes) -long_types = (Int64,) -unicode_type = str - - -BSON_TYPE_ALIAS_MAP = { - # https://mongodb.com/docs/manual/reference/operator/query/type/ - # https://pymongo.readthedocs.io/en/stable/api/bson/index.html - "double": (float,), - "string": (str,), - "object": (abc.Mapping,), - "array": (abc.MutableSequence,), - "binData": binary_types, - "undefined": (type(None),), - "objectId": (ObjectId,), - "bool": (bool,), - "date": (datetime.datetime,), - "null": (type(None),), - "regex": (Regex, RE_TYPE), - "dbPointer": (DBRef,), - "javascript": (unicode_type, Code), - "symbol": (unicode_type,), - "javascriptWithScope": (unicode_type, Code), - "int": (int,), - "long": (Int64,), - "decimal": (Decimal128,), - "maxKey": (MaxKey,), - "minKey": (MinKey,), -} - - -class MatchEvaluatorUtil: - """Utility class that implements methods for evaluating matches as per - the unified test format specification. - """ - - def __init__(self, test_class): - self.test = test_class - - def _operation_exists(self, spec, actual, key_to_compare): - if spec is True: - if key_to_compare is None: - assert actual is not None - else: - self.test.assertIn(key_to_compare, actual) - elif spec is False: - if key_to_compare is None: - assert actual is None - else: - self.test.assertNotIn(key_to_compare, actual) - else: - self.test.fail(f"Expected boolean value for $$exists operator, got {spec}") - - def __type_alias_to_type(self, alias): - if alias not in BSON_TYPE_ALIAS_MAP: - self.test.fail(f"Unrecognized BSON type alias {alias}") - return BSON_TYPE_ALIAS_MAP[alias] - - def _operation_type(self, spec, actual, key_to_compare): - if isinstance(spec, abc.MutableSequence): - permissible_types = tuple( - [t for alias in spec for t in self.__type_alias_to_type(alias)] - ) - else: - permissible_types = self.__type_alias_to_type(spec) - value = actual[key_to_compare] if key_to_compare else actual - self.test.assertIsInstance(value, permissible_types) - - def _operation_matchesEntity(self, spec, actual, key_to_compare): - expected_entity = self.test.entity_map[spec] - self.test.assertEqual(expected_entity, actual[key_to_compare]) - - def _operation_matchesHexBytes(self, spec, actual, key_to_compare): - expected = binascii.unhexlify(spec) - value = actual[key_to_compare] if key_to_compare else actual - self.test.assertEqual(value, expected) - - def _operation_unsetOrMatches(self, spec, actual, key_to_compare): - if key_to_compare is None and not actual: - # top-level document can be None when unset - return - - if key_to_compare not in actual: - # we add a dummy value for the compared key to pass map size check - actual[key_to_compare] = "dummyValue" - return - self.match_result(spec, actual[key_to_compare], in_recursive_call=True) - - def _operation_sessionLsid(self, spec, actual, key_to_compare): - expected_lsid = self.test.entity_map.get_lsid_for_session(spec) - self.test.assertEqual(expected_lsid, actual[key_to_compare]) - - def _operation_lte(self, spec, actual, key_to_compare): - if key_to_compare not in actual: - 
self.test.fail(f"Actual command is missing the {key_to_compare} field: {spec}") - self.test.assertLessEqual(actual[key_to_compare], spec) - - def _operation_matchAsDocument(self, spec, actual, key_to_compare): - self._match_document(spec, json_util.loads(actual[key_to_compare]), False) - - def _operation_matchAsRoot(self, spec, actual, key_to_compare): - self._match_document(spec, actual, True) - - def _evaluate_special_operation(self, opname, spec, actual, key_to_compare): - method_name = "_operation_{}".format(opname.strip("$")) - try: - method = getattr(self, method_name) - except AttributeError: - self.test.fail(f"Unsupported special matching operator {opname}") - else: - method(spec, actual, key_to_compare) - - def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=None): - """Returns True if a special operation is evaluated, False - otherwise. If the ``expectation`` map contains a single key, - value pair we check it for a special operation. - If given, ``key_to_compare`` is assumed to be the key in - ``expectation`` whose corresponding value needs to be - evaluated for a possible special operation. ``key_to_compare`` - is ignored when ``expectation`` has only one key. - """ - if not isinstance(expectation, abc.Mapping): - return False - - is_special_op, opname, spec = False, False, False - - if key_to_compare is not None: - if key_to_compare.startswith("$$"): - is_special_op = True - opname = key_to_compare - spec = expectation[key_to_compare] - key_to_compare = None - else: - nested = expectation[key_to_compare] - if isinstance(nested, abc.Mapping) and len(nested) == 1: - opname, spec = next(iter(nested.items())) - if opname.startswith("$$"): - is_special_op = True - elif len(expectation) == 1: - opname, spec = next(iter(expectation.items())) - if opname.startswith("$$"): - is_special_op = True - key_to_compare = None - - if is_special_op: - self._evaluate_special_operation( - opname=opname, spec=spec, actual=actual, key_to_compare=key_to_compare - ) - return True - - return False - - def _match_document(self, expectation, actual, is_root, test=False): - if self._evaluate_if_special_operation(expectation, actual): - return - - self.test.assertIsInstance(actual, abc.Mapping) - for key, value in expectation.items(): - if self._evaluate_if_special_operation(expectation, actual, key): - continue - - self.test.assertIn(key, actual) - if not self.match_result(value, actual[key], in_recursive_call=True, test=test): - return False - - if not is_root: - expected_keys = set(expectation.keys()) - for key, value in expectation.items(): - if value == {"$$exists": False}: - expected_keys.remove(key) - if test: - self.test.assertEqual(expected_keys, set(actual.keys())) - else: - return set(expected_keys).issubset(set(actual.keys())) - return True - - def match_result(self, expectation, actual, in_recursive_call=False, test=True): - if isinstance(expectation, abc.Mapping): - return self._match_document( - expectation, actual, is_root=not in_recursive_call, test=test - ) - - if isinstance(expectation, abc.MutableSequence): - self.test.assertIsInstance(actual, abc.MutableSequence) - for e, a in zip(expectation, actual): - if isinstance(e, abc.Mapping): - self._match_document(e, a, is_root=not in_recursive_call, test=test) - else: - self.match_result(e, a, in_recursive_call=True, test=test) - return None - - # account for flexible numerics in element-wise comparison - if isinstance(expectation, int) or isinstance(expectation, float): - if test: - 
self.test.assertEqual(expectation, actual) - else: - return expectation == actual - return None - else: - if test: - self.test.assertIsInstance(actual, type(expectation)) - self.test.assertEqual(expectation, actual) - else: - return isinstance(actual, type(expectation)) and expectation == actual - return None - - def match_server_description(self, actual: ServerDescription, spec: dict) -> None: - for field, expected in spec.items(): - field = camel_to_snake(field) - if field == "type": - field = "server_type_name" - self.test.assertEqual(getattr(actual, field), expected) - - def match_topology_description(self, actual: TopologyDescription, spec: dict) -> None: - for field, expected in spec.items(): - field = camel_to_snake(field) - if field == "type": - field = "topology_type_name" - self.test.assertEqual(getattr(actual, field), expected) - - def match_event_fields(self, actual: Any, spec: dict) -> None: - for field, expected in spec.items(): - if field == "command" and isinstance(actual, CommandStartedEvent): - command = spec["command"] - if command: - self.match_result(command, actual.command) - continue - if field == "reply" and isinstance(actual, CommandSucceededEvent): - reply = spec["reply"] - if reply: - self.match_result(reply, actual.reply) - continue - if field == "hasServiceId": - if spec["hasServiceId"]: - self.test.assertIsNotNone(actual.service_id) - self.test.assertIsInstance(actual.service_id, ObjectId) - else: - self.test.assertIsNone(actual.service_id) - continue - if field == "hasServerConnectionId": - if spec["hasServerConnectionId"]: - self.test.assertIsNotNone(actual.server_connection_id) - self.test.assertIsInstance(actual.server_connection_id, int) - else: - self.test.assertIsNone(actual.server_connection_id) - continue - if field in ("previousDescription", "newDescription"): - if isinstance(actual, ServerDescriptionChangedEvent): - self.match_server_description( - getattr(actual, camel_to_snake(field)), spec[field] - ) - continue - if isinstance(actual, TopologyDescriptionChangedEvent): - self.match_topology_description( - getattr(actual, camel_to_snake(field)), spec[field] - ) - continue - - if field == "interruptInUseConnections": - field = "interrupt_connections" - else: - field = camel_to_snake(field) - self.test.assertEqual(getattr(actual, field), expected) - - def match_event(self, expectation, actual): - name, spec = next(iter(expectation.items())) - if name == "commandStartedEvent": - self.test.assertIsInstance(actual, CommandStartedEvent) - elif name == "commandSucceededEvent": - self.test.assertIsInstance(actual, CommandSucceededEvent) - elif name == "commandFailedEvent": - self.test.assertIsInstance(actual, CommandFailedEvent) - elif name == "poolCreatedEvent": - self.test.assertIsInstance(actual, PoolCreatedEvent) - elif name == "poolReadyEvent": - self.test.assertIsInstance(actual, PoolReadyEvent) - elif name == "poolClearedEvent": - self.test.assertIsInstance(actual, PoolClearedEvent) - self.test.assertIsInstance(actual.interrupt_connections, bool) - elif name == "poolClosedEvent": - self.test.assertIsInstance(actual, PoolClosedEvent) - elif name == "connectionCreatedEvent": - self.test.assertIsInstance(actual, ConnectionCreatedEvent) - elif name == "connectionReadyEvent": - self.test.assertIsInstance(actual, ConnectionReadyEvent) - elif name == "connectionClosedEvent": - self.test.assertIsInstance(actual, ConnectionClosedEvent) - elif name == "connectionCheckOutStartedEvent": - self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent) - elif 
name == "connectionCheckOutFailedEvent": - self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent) - elif name == "connectionCheckedOutEvent": - self.test.assertIsInstance(actual, ConnectionCheckedOutEvent) - elif name == "connectionCheckedInEvent": - self.test.assertIsInstance(actual, ConnectionCheckedInEvent) - elif name == "serverDescriptionChangedEvent": - self.test.assertIsInstance(actual, ServerDescriptionChangedEvent) - elif name == "serverHeartbeatStartedEvent": - self.test.assertIsInstance(actual, ServerHeartbeatStartedEvent) - elif name == "serverHeartbeatSucceededEvent": - self.test.assertIsInstance(actual, ServerHeartbeatSucceededEvent) - elif name == "serverHeartbeatFailedEvent": - self.test.assertIsInstance(actual, ServerHeartbeatFailedEvent) - elif name == "topologyDescriptionChangedEvent": - self.test.assertIsInstance(actual, TopologyDescriptionChangedEvent) - elif name == "topologyOpeningEvent": - self.test.assertIsInstance(actual, TopologyOpenedEvent) - elif name == "topologyClosedEvent": - self.test.assertIsInstance(actual, TopologyClosedEvent) - else: - raise Exception(f"Unsupported event type {name}") - - self.match_event_fields(actual, spec) - - -def coerce_result(opname, result): - """Convert a pymongo result into the spec's result format.""" - if hasattr(result, "acknowledged") and not result.acknowledged: - return {"acknowledged": False} - if opname == "bulkWrite": - return parse_bulk_write_result(result) - if opname == "clientBulkWrite": - return parse_client_bulk_write_result(result) - if opname == "insertOne": - return {"insertedId": result.inserted_id} - if opname == "insertMany": - return dict(enumerate(result.inserted_ids)) - if opname in ("deleteOne", "deleteMany"): - return {"deletedCount": result.deleted_count} - if opname in ("updateOne", "updateMany", "replaceOne"): - value = { - "matchedCount": result.matched_count, - "modifiedCount": result.modified_count, - "upsertedCount": 0 if result.upserted_id is None else 1, - } - if result.upserted_id is not None: - value["upsertedId"] = result.upserted_id - return value - return result - - class UnifiedSpecTestMixinV1(IntegrationTest): """Mixin class to run test cases from test specification files. 
@@ -1090,9 +472,9 @@ def insert_initial_data(self, initial_data): db.create_collection(coll_name, write_concern=wc, **opts) @classmethod - def setUpClass(cls): + def _setup_class(cls): # super call creates internal client cls.client - super().setUpClass() + super()._setup_class() # process file-level runOnRequirements run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) if not cls.should_run_on(run_on_spec): @@ -1125,11 +507,11 @@ def setUpClass(cls): cls.knobs.enable() @classmethod - def tearDownClass(cls): + def _tearDown_class(cls): cls.knobs.disable() for client in cls.mongos_clients: client.close() - super().tearDownClass() + super()._tearDown_class() def setUp(self): super().setUp() @@ -1391,7 +773,7 @@ def _databaseOperation_createCollection(self, target, *args, **kwargs): def __entityOperation_aggregate(self, target, *args, **kwargs): self.__raise_if_unsupported("aggregate", target, Database, Collection) - return list(target.aggregate(*args, **kwargs)) + return (target.aggregate(*args, **kwargs)).to_list() def _databaseOperation_aggregate(self, target, *args, **kwargs): return self.__entityOperation_aggregate(target, *args, **kwargs) @@ -1402,13 +784,13 @@ def _collectionOperation_aggregate(self, target, *args, **kwargs): def _collectionOperation_find(self, target, *args, **kwargs): self.__raise_if_unsupported("find", target, Collection) find_cursor = target.find(*args, **kwargs) - return list(find_cursor) + return find_cursor.to_list() def _collectionOperation_createFindCursor(self, target, *args, **kwargs): self.__raise_if_unsupported("find", target, Collection) if "filter" not in kwargs: self.fail('createFindCursor requires a "filter" argument') - cursor = NonLazyCursor(target.find(*args, **kwargs), target.database.client) + cursor = NonLazyCursor.create(target.find(*args, **kwargs), target.database.client) self.addCleanup(cursor.close) return cursor @@ -1418,7 +800,7 @@ def _collectionOperation_count(self, target, *args, **kwargs): def _collectionOperation_listIndexes(self, target, *args, **kwargs): if "batch_size" in kwargs: self.skipTest("PyMongo does not support batch_size for list_indexes") - return list(target.list_indexes(*args, **kwargs)) + return (target.list_indexes(*args, **kwargs)).to_list() def _collectionOperation_listIndexNames(self, target, *args, **kwargs): self.skipTest("PyMongo does not support list_index_names") @@ -1430,7 +812,7 @@ def _collectionOperation_createSearchIndexes(self, target, *args, **kwargs): def _collectionOperation_listSearchIndexes(self, target, *args, **kwargs): name = kwargs.get("name") agg_kwargs = kwargs.get("aggregation_options", dict()) - return list(target.list_search_indexes(name, **agg_kwargs)) + return (target.list_search_indexes(name, **agg_kwargs)).to_list() def _sessionOperation_withTransaction(self, target, *args, **kwargs): if client_context.storage_engine == "mmapv1": @@ -1470,7 +852,7 @@ def _clientEncryptionOperation_createDataKey(self, target, *args, **kwargs): return target.create_data_key(*args, **kwargs) def _clientEncryptionOperation_getKeys(self, target, *args, **kwargs): - return list(target.get_keys(*args, **kwargs)) + return (target.get_keys(*args, **kwargs)).to_list() def _clientEncryptionOperation_deleteKey(self, target, *args, **kwargs): result = target.delete_key(*args, **kwargs) @@ -1516,7 +898,7 @@ def _bucketOperation_uploadWithId(self, target: GridFSBucket, *args: Any, **kwar def _bucketOperation_find( self, target: GridFSBucket, *args: Any, **kwargs: Any ) -> List[GridOut]: - return 
list(target.find(*args, **kwargs)) + return target.find(*args, **kwargs).to_list() def run_entity_operation(self, spec): target = self.entity_map[spec["object"]] @@ -1849,7 +1231,10 @@ def run_special_operation(self, spec): except AttributeError: self.fail(f"Unsupported special test operation {opname}") else: - method(spec["arguments"]) + if iscoroutinefunction(method): + method(spec["arguments"]) + else: + method(spec["arguments"]) def run_operations(self, spec): for op in spec: @@ -1985,7 +1370,7 @@ def verify_outcome(self, spec): if expected_documents: sorted_expected_documents = sorted(expected_documents, key=lambda doc: doc["_id"]) - actual_documents = list(coll.find({}, sort=[("_id", ASCENDING)])) + actual_documents = coll.find({}, sort=[("_id", ASCENDING)]).to_list() self.assertListEqual(sorted_expected_documents, actual_documents) def run_scenario(self, spec, uri=None): @@ -2040,7 +1425,7 @@ def _run_scenario(self, spec, uri=None): # process initialData if "initialData" in self.TEST_SPEC: self.insert_initial_data(self.TEST_SPEC["initialData"]) - self._cluster_time = self.client.admin.command("ping").get("$clusterTime") + self._cluster_time = (self.client.admin.command("ping")).get("$clusterTime") self.entity_map.advance_cluster_times() if "expectLogMessages" in spec: diff --git a/test/unified_format_shared.py b/test/unified_format_shared.py new file mode 100644 index 0000000000..d11624476d --- /dev/null +++ b/test/unified_format_shared.py @@ -0,0 +1,679 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared utility functions and constants for the unified test format runner. 
+ +https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst +""" +from __future__ import annotations + +import binascii +import collections +import datetime +import os +import time +import types +from collections import abc +from test.helpers import ( + AWS_CREDS, + AWS_CREDS_2, + AZURE_CREDS, + CA_PEM, + CLIENT_PEM, + GCP_CREDS, + KMIP_CREDS, + LOCAL_MASTER_KEY, +) +from test.utils import CMAPListener, camel_to_snake, parse_collection_options +from typing import Any, Union + +from bson import ( + RE_TYPE, + Binary, + Code, + DBRef, + Decimal128, + Int64, + MaxKey, + MinKey, + ObjectId, + Regex, + json_util, +) +from pymongo.monitoring import ( + _SENSITIVE_COMMANDS, + CommandFailedEvent, + CommandListener, + CommandStartedEvent, + CommandSucceededEvent, + ConnectionCheckedInEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutStartedEvent, + ConnectionClosedEvent, + ConnectionCreatedEvent, + ConnectionReadyEvent, + PoolClearedEvent, + PoolClosedEvent, + PoolCreatedEvent, + PoolReadyEvent, + ServerClosedEvent, + ServerDescriptionChangedEvent, + ServerHeartbeatFailedEvent, + ServerHeartbeatListener, + ServerHeartbeatStartedEvent, + ServerHeartbeatSucceededEvent, + ServerListener, + ServerOpeningEvent, + TopologyClosedEvent, + TopologyDescriptionChangedEvent, + TopologyEvent, + TopologyListener, + TopologyOpenedEvent, + _CommandEvent, + _ConnectionEvent, + _PoolEvent, + _ServerEvent, + _ServerHeartbeatEvent, +) +from pymongo.results import BulkWriteResult +from pymongo.server_description import ServerDescription +from pymongo.topology_description import TopologyDescription + +SKIP_CSOT_TESTS = os.getenv("SKIP_CSOT_TESTS") + +JSON_OPTS = json_util.JSONOptions(tz_aware=False) + +IS_INTERRUPTED = False + +KMS_TLS_OPTS = { + "kmip": { + "tlsCAFile": CA_PEM, + "tlsCertificateKeyFile": CLIENT_PEM, + } +} + + +# Build up a placeholder maps. +PLACEHOLDER_MAP = {} +for provider_name, provider_data in [ + ("local", {"key": LOCAL_MASTER_KEY}), + ("local:name1", {"key": LOCAL_MASTER_KEY}), + ("aws", AWS_CREDS), + ("aws:name1", AWS_CREDS), + ("aws:name2", AWS_CREDS_2), + ("azure", AZURE_CREDS), + ("azure:name1", AZURE_CREDS), + ("gcp", GCP_CREDS), + ("gcp:name1", GCP_CREDS), + ("kmip", KMIP_CREDS), + ("kmip:name1", KMIP_CREDS), +]: + for key, value in provider_data.items(): + placeholder = f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}" + PLACEHOLDER_MAP[placeholder] = value + +OIDC_ENV = os.environ.get("OIDC_ENV", "test") +if OIDC_ENV == "test": + PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"ENVIRONMENT": "test"} +elif OIDC_ENV == "azure": + PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": os.environ["AZUREOIDC_RESOURCE"], + } +elif OIDC_ENV == "gcp": + PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = { + "ENVIRONMENT": "gcp", + "TOKEN_RESOURCE": os.environ["GCPOIDC_AUDIENCE"], + } + + +def interrupt_loop(): + global IS_INTERRUPTED + IS_INTERRUPTED = True + + +def with_metaclass(meta, *bases): + """Create a base class with a metaclass. + + Vendored from six: https://github.com/benjaminp/six/blob/master/six.py + """ + + # This requires a bit of explanation: the basic idea is to make a dummy + # metaclass for one level of class instantiation that replaces itself with + # the actual metaclass. + class metaclass(type): + def __new__(cls, name, this_bases, d): + # __orig_bases__ is required by PEP 560. 
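+            # types.resolve_bases() resolves entries that define __mro_entries__ (PEP 560).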
+ resolved_bases = types.resolve_bases(bases) + if resolved_bases is not bases: + d["__orig_bases__"] = bases + return meta(name, resolved_bases, d) + + @classmethod + def __prepare__(cls, name, this_bases): + return meta.__prepare__(name, bases) + + return type.__new__(metaclass, "temporary_class", (), {}) + + +def parse_collection_or_database_options(options): + return parse_collection_options(options) + + +def parse_bulk_write_result(result): + upserted_ids = {str(int_idx): result.upserted_ids[int_idx] for int_idx in result.upserted_ids} + return { + "deletedCount": result.deleted_count, + "insertedCount": result.inserted_count, + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + "upsertedCount": result.upserted_count, + "upsertedIds": upserted_ids, + } + + +def parse_client_bulk_write_individual(op_type, result): + if op_type == "insert": + return {"insertedId": result.inserted_id} + if op_type == "update": + if result.upserted_id: + return { + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + "upsertedId": result.upserted_id, + } + else: + return { + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + } + if op_type == "delete": + return { + "deletedCount": result.deleted_count, + } + + +def parse_client_bulk_write_result(result): + insert_results, update_results, delete_results = {}, {}, {} + if result.has_verbose_results: + for idx, res in result.insert_results.items(): + insert_results[str(idx)] = parse_client_bulk_write_individual("insert", res) + for idx, res in result.update_results.items(): + update_results[str(idx)] = parse_client_bulk_write_individual("update", res) + for idx, res in result.delete_results.items(): + delete_results[str(idx)] = parse_client_bulk_write_individual("delete", res) + + return { + "deletedCount": result.deleted_count, + "insertedCount": result.inserted_count, + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + "upsertedCount": result.upserted_count, + "insertResults": insert_results, + "updateResults": update_results, + "deleteResults": delete_results, + } + + +def parse_bulk_write_error_result(error): + write_result = BulkWriteResult(error.details, True) + return parse_bulk_write_result(write_result) + + +def parse_client_bulk_write_error_result(error): + write_result = error.partial_result + if not write_result: + return None + return parse_client_bulk_write_result(write_result) + + +class EventListenerUtil( + CMAPListener, CommandListener, ServerListener, ServerHeartbeatListener, TopologyListener +): + def __init__( + self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map + ): + self._event_types = {name.lower() for name in observe_events} + if observe_sensitive_commands: + self._observe_sensitive_commands = True + self._ignore_commands = set(ignore_commands) + else: + self._observe_sensitive_commands = False + self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands) + self._ignore_commands.add("configurefailpoint") + self._event_mapping = collections.defaultdict(list) + self.entity_map = entity_map + if store_events: + for i in store_events: + id = i["id"] + events = (i.lower() for i in i["events"]) + for i in events: + self._event_mapping[i].append(id) + self.entity_map[id] = [] + super().__init__() + + def get_events(self, event_type): + assert event_type in ("command", "cmap", "sdam", "all"), event_type + if event_type == "all": + return list(self.events) + if event_type == 
"command": + return [e for e in self.events if isinstance(e, _CommandEvent)] + if event_type == "cmap": + return [e for e in self.events if isinstance(e, (_ConnectionEvent, _PoolEvent))] + return [ + e + for e in self.events + if isinstance(e, (_ServerEvent, TopologyEvent, _ServerHeartbeatEvent)) + ] + + def add_event(self, event): + event_name = type(event).__name__.lower() + if event_name in self._event_types: + super().add_event(event) + for id in self._event_mapping[event_name]: + self.entity_map[id].append( + { + "name": type(event).__name__, + "observedAt": time.time(), + "description": repr(event), + } + ) + + def _command_event(self, event): + if event.command_name.lower() not in self._ignore_commands: + self.add_event(event) + + def started(self, event): + if isinstance(event, CommandStartedEvent): + if event.command == {}: + # Command is redacted. Observe only if flag is set. + if self._observe_sensitive_commands: + self._command_event(event) + else: + self._command_event(event) + else: + self.add_event(event) + + def succeeded(self, event): + if isinstance(event, CommandSucceededEvent): + if event.reply == {}: + # Command is redacted. Observe only if flag is set. + if self._observe_sensitive_commands: + self._command_event(event) + else: + self._command_event(event) + else: + self.add_event(event) + + def failed(self, event): + if isinstance(event, CommandFailedEvent): + self._command_event(event) + else: + self.add_event(event) + + def opened(self, event: Union[ServerOpeningEvent, TopologyOpenedEvent]) -> None: + self.add_event(event) + + def description_changed( + self, event: Union[ServerDescriptionChangedEvent, TopologyDescriptionChangedEvent] + ) -> None: + self.add_event(event) + + def topology_changed(self, event: TopologyDescriptionChangedEvent) -> None: + self.add_event(event) + + def closed(self, event: Union[ServerClosedEvent, TopologyClosedEvent]) -> None: + self.add_event(event) + + +binary_types = (Binary, bytes) +long_types = (Int64,) +unicode_type = str + + +BSON_TYPE_ALIAS_MAP = { + # https://mongodb.com/docs/manual/reference/operator/query/type/ + # https://pymongo.readthedocs.io/en/stable/api/bson/index.html + "double": (float,), + "string": (str,), + "object": (abc.Mapping,), + "array": (abc.MutableSequence,), + "binData": binary_types, + "undefined": (type(None),), + "objectId": (ObjectId,), + "bool": (bool,), + "date": (datetime.datetime,), + "null": (type(None),), + "regex": (Regex, RE_TYPE), + "dbPointer": (DBRef,), + "javascript": (unicode_type, Code), + "symbol": (unicode_type,), + "javascriptWithScope": (unicode_type, Code), + "int": (int,), + "long": (Int64,), + "decimal": (Decimal128,), + "maxKey": (MaxKey,), + "minKey": (MinKey,), +} + + +class MatchEvaluatorUtil: + """Utility class that implements methods for evaluating matches as per + the unified test format specification. 
+ """ + + def __init__(self, test_class): + self.test = test_class + + def _operation_exists(self, spec, actual, key_to_compare): + if spec is True: + if key_to_compare is None: + assert actual is not None + else: + self.test.assertIn(key_to_compare, actual) + elif spec is False: + if key_to_compare is None: + assert actual is None + else: + self.test.assertNotIn(key_to_compare, actual) + else: + self.test.fail(f"Expected boolean value for $$exists operator, got {spec}") + + def __type_alias_to_type(self, alias): + if alias not in BSON_TYPE_ALIAS_MAP: + self.test.fail(f"Unrecognized BSON type alias {alias}") + return BSON_TYPE_ALIAS_MAP[alias] + + def _operation_type(self, spec, actual, key_to_compare): + if isinstance(spec, abc.MutableSequence): + permissible_types = tuple( + [t for alias in spec for t in self.__type_alias_to_type(alias)] + ) + else: + permissible_types = self.__type_alias_to_type(spec) + value = actual[key_to_compare] if key_to_compare else actual + self.test.assertIsInstance(value, permissible_types) + + def _operation_matchesEntity(self, spec, actual, key_to_compare): + expected_entity = self.test.entity_map[spec] + self.test.assertEqual(expected_entity, actual[key_to_compare]) + + def _operation_matchesHexBytes(self, spec, actual, key_to_compare): + expected = binascii.unhexlify(spec) + value = actual[key_to_compare] if key_to_compare else actual + self.test.assertEqual(value, expected) + + def _operation_unsetOrMatches(self, spec, actual, key_to_compare): + if key_to_compare is None and not actual: + # top-level document can be None when unset + return + + if key_to_compare not in actual: + # we add a dummy value for the compared key to pass map size check + actual[key_to_compare] = "dummyValue" + return + self.match_result(spec, actual[key_to_compare], in_recursive_call=True) + + def _operation_sessionLsid(self, spec, actual, key_to_compare): + expected_lsid = self.test.entity_map.get_lsid_for_session(spec) + self.test.assertEqual(expected_lsid, actual[key_to_compare]) + + def _operation_lte(self, spec, actual, key_to_compare): + if key_to_compare not in actual: + self.test.fail(f"Actual command is missing the {key_to_compare} field: {spec}") + self.test.assertLessEqual(actual[key_to_compare], spec) + + def _operation_matchAsDocument(self, spec, actual, key_to_compare): + self._match_document(spec, json_util.loads(actual[key_to_compare]), False) + + def _operation_matchAsRoot(self, spec, actual, key_to_compare): + self._match_document(spec, actual, True) + + def _evaluate_special_operation(self, opname, spec, actual, key_to_compare): + method_name = "_operation_{}".format(opname.strip("$")) + try: + method = getattr(self, method_name) + except AttributeError: + self.test.fail(f"Unsupported special matching operator {opname}") + else: + method(spec, actual, key_to_compare) + + def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=None): + """Returns True if a special operation is evaluated, False + otherwise. If the ``expectation`` map contains a single key, + value pair we check it for a special operation. + If given, ``key_to_compare`` is assumed to be the key in + ``expectation`` whose corresponding value needs to be + evaluated for a possible special operation. ``key_to_compare`` + is ignored when ``expectation`` has only one key. 
+ """ + if not isinstance(expectation, abc.Mapping): + return False + + is_special_op, opname, spec = False, False, False + + if key_to_compare is not None: + if key_to_compare.startswith("$$"): + is_special_op = True + opname = key_to_compare + spec = expectation[key_to_compare] + key_to_compare = None + else: + nested = expectation[key_to_compare] + if isinstance(nested, abc.Mapping) and len(nested) == 1: + opname, spec = next(iter(nested.items())) + if opname.startswith("$$"): + is_special_op = True + elif len(expectation) == 1: + opname, spec = next(iter(expectation.items())) + if opname.startswith("$$"): + is_special_op = True + key_to_compare = None + + if is_special_op: + self._evaluate_special_operation( + opname=opname, spec=spec, actual=actual, key_to_compare=key_to_compare + ) + return True + + return False + + def _match_document(self, expectation, actual, is_root, test=False): + if self._evaluate_if_special_operation(expectation, actual): + return + + self.test.assertIsInstance(actual, abc.Mapping) + for key, value in expectation.items(): + if self._evaluate_if_special_operation(expectation, actual, key): + continue + + self.test.assertIn(key, actual) + if not self.match_result(value, actual[key], in_recursive_call=True, test=test): + return False + + if not is_root: + expected_keys = set(expectation.keys()) + for key, value in expectation.items(): + if value == {"$$exists": False}: + expected_keys.remove(key) + if test: + self.test.assertEqual(expected_keys, set(actual.keys())) + else: + return set(expected_keys).issubset(set(actual.keys())) + return True + + def match_result(self, expectation, actual, in_recursive_call=False, test=True): + if isinstance(expectation, abc.Mapping): + return self._match_document( + expectation, actual, is_root=not in_recursive_call, test=test + ) + + if isinstance(expectation, abc.MutableSequence): + self.test.assertIsInstance(actual, abc.MutableSequence) + for e, a in zip(expectation, actual): + if isinstance(e, abc.Mapping): + self._match_document(e, a, is_root=not in_recursive_call, test=test) + else: + self.match_result(e, a, in_recursive_call=True, test=test) + return None + + # account for flexible numerics in element-wise comparison + if isinstance(expectation, int) or isinstance(expectation, float): + if test: + self.test.assertEqual(expectation, actual) + else: + return expectation == actual + return None + else: + if test: + self.test.assertIsInstance(actual, type(expectation)) + self.test.assertEqual(expectation, actual) + else: + return isinstance(actual, type(expectation)) and expectation == actual + return None + + def match_server_description(self, actual: ServerDescription, spec: dict) -> None: + for field, expected in spec.items(): + field = camel_to_snake(field) + if field == "type": + field = "server_type_name" + self.test.assertEqual(getattr(actual, field), expected) + + def match_topology_description(self, actual: TopologyDescription, spec: dict) -> None: + for field, expected in spec.items(): + field = camel_to_snake(field) + if field == "type": + field = "topology_type_name" + self.test.assertEqual(getattr(actual, field), expected) + + def match_event_fields(self, actual: Any, spec: dict) -> None: + for field, expected in spec.items(): + if field == "command" and isinstance(actual, CommandStartedEvent): + command = spec["command"] + if command: + self.match_result(command, actual.command) + continue + if field == "reply" and isinstance(actual, CommandSucceededEvent): + reply = spec["reply"] + if reply: + 
self.match_result(reply, actual.reply) + continue + if field == "hasServiceId": + if spec["hasServiceId"]: + self.test.assertIsNotNone(actual.service_id) + self.test.assertIsInstance(actual.service_id, ObjectId) + else: + self.test.assertIsNone(actual.service_id) + continue + if field == "hasServerConnectionId": + if spec["hasServerConnectionId"]: + self.test.assertIsNotNone(actual.server_connection_id) + self.test.assertIsInstance(actual.server_connection_id, int) + else: + self.test.assertIsNone(actual.server_connection_id) + continue + if field in ("previousDescription", "newDescription"): + if isinstance(actual, ServerDescriptionChangedEvent): + self.match_server_description( + getattr(actual, camel_to_snake(field)), spec[field] + ) + continue + if isinstance(actual, TopologyDescriptionChangedEvent): + self.match_topology_description( + getattr(actual, camel_to_snake(field)), spec[field] + ) + continue + + if field == "interruptInUseConnections": + field = "interrupt_connections" + else: + field = camel_to_snake(field) + self.test.assertEqual(getattr(actual, field), expected) + + def match_event(self, expectation, actual): + name, spec = next(iter(expectation.items())) + if name == "commandStartedEvent": + self.test.assertIsInstance(actual, CommandStartedEvent) + elif name == "commandSucceededEvent": + self.test.assertIsInstance(actual, CommandSucceededEvent) + elif name == "commandFailedEvent": + self.test.assertIsInstance(actual, CommandFailedEvent) + elif name == "poolCreatedEvent": + self.test.assertIsInstance(actual, PoolCreatedEvent) + elif name == "poolReadyEvent": + self.test.assertIsInstance(actual, PoolReadyEvent) + elif name == "poolClearedEvent": + self.test.assertIsInstance(actual, PoolClearedEvent) + self.test.assertIsInstance(actual.interrupt_connections, bool) + elif name == "poolClosedEvent": + self.test.assertIsInstance(actual, PoolClosedEvent) + elif name == "connectionCreatedEvent": + self.test.assertIsInstance(actual, ConnectionCreatedEvent) + elif name == "connectionReadyEvent": + self.test.assertIsInstance(actual, ConnectionReadyEvent) + elif name == "connectionClosedEvent": + self.test.assertIsInstance(actual, ConnectionClosedEvent) + elif name == "connectionCheckOutStartedEvent": + self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent) + elif name == "connectionCheckOutFailedEvent": + self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent) + elif name == "connectionCheckedOutEvent": + self.test.assertIsInstance(actual, ConnectionCheckedOutEvent) + elif name == "connectionCheckedInEvent": + self.test.assertIsInstance(actual, ConnectionCheckedInEvent) + elif name == "serverDescriptionChangedEvent": + self.test.assertIsInstance(actual, ServerDescriptionChangedEvent) + elif name == "serverHeartbeatStartedEvent": + self.test.assertIsInstance(actual, ServerHeartbeatStartedEvent) + elif name == "serverHeartbeatSucceededEvent": + self.test.assertIsInstance(actual, ServerHeartbeatSucceededEvent) + elif name == "serverHeartbeatFailedEvent": + self.test.assertIsInstance(actual, ServerHeartbeatFailedEvent) + elif name == "topologyDescriptionChangedEvent": + self.test.assertIsInstance(actual, TopologyDescriptionChangedEvent) + elif name == "topologyOpeningEvent": + self.test.assertIsInstance(actual, TopologyOpenedEvent) + elif name == "topologyClosedEvent": + self.test.assertIsInstance(actual, TopologyClosedEvent) + else: + raise Exception(f"Unsupported event type {name}") + + self.match_event_fields(actual, spec) + + +def coerce_result(opname, 
result): + """Convert a pymongo result into the spec's result format.""" + if hasattr(result, "acknowledged") and not result.acknowledged: + return {"acknowledged": False} + if opname == "bulkWrite": + return parse_bulk_write_result(result) + if opname == "clientBulkWrite": + return parse_client_bulk_write_result(result) + if opname == "insertOne": + return {"insertedId": result.inserted_id} + if opname == "insertMany": + return dict(enumerate(result.inserted_ids)) + if opname in ("deleteOne", "deleteMany"): + return {"deletedCount": result.deleted_count} + if opname in ("updateOne", "updateMany", "replaceOne"): + value = { + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + "upsertedCount": 0 if result.upserted_id is None else 1, + } + if result.upserted_id is not None: + value["upsertedId"] = result.upserted_id + return value + return result diff --git a/tools/synchro.py b/tools/synchro.py index f704919a17..e0af5efa44 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -205,6 +205,7 @@ def async_only_test(f: str) -> bool: "test_retryable_writes.py", "test_session.py", "test_transactions.py", + "unified_format.py", ] sync_test_files = [ From 7e86d24c7bffe4da0a4d32580b5da0e6230b78d2 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 13:59:37 -0400 Subject: [PATCH 04/19] PYTHON-4849 - Convert test.test_connection_logging.py to async (#1918) --- test/asynchronous/test_connection_logging.py | 45 ++++++++++++++++++++ test/test_connection_logging.py | 8 +++- tools/synchro.py | 1 + 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 test/asynchronous/test_connection_logging.py diff --git a/test/asynchronous/test_connection_logging.py b/test/asynchronous/test_connection_logging.py new file mode 100644 index 0000000000..6bc9835b70 --- /dev/null +++ b/test/asynchronous/test_connection_logging.py @@ -0,0 +1,45 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the connection logging unified format spec tests.""" +from __future__ import annotations + +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. 
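+# The async copy of this module lives one directory deeper (test/asynchronous/),
+# so the non-sync branch climbs an extra level to reach the shared JSON specs.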
+if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging") + + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_connection_logging.py b/test/test_connection_logging.py index 262ce821eb..253193cc43 100644 --- a/test/test_connection_logging.py +++ b/test/test_connection_logging.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys sys.path[0:0] = [""] @@ -23,8 +24,13 @@ from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_logging") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging") globals().update( diff --git a/tools/synchro.py b/tools/synchro.py index e0af5efa44..dbaf0a15e9 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -193,6 +193,7 @@ def async_only_test(f: str) -> bool: "test_collation.py", "test_collection.py", "test_common.py", + "test_connection_logging.py", "test_connections_survive_primary_stepdown_spec.py", "test_cursor.py", "test_database.py", From e0fde2338126ee3e8ca7771b3f88c4a2706638f2 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 13:59:44 -0400 Subject: [PATCH 05/19] PYTHON-4850 - Convert test.test_crud_unified to async (#1920) --- test/asynchronous/test_crud_unified.py | 39 ++++++++++++++++++++++++++ test/test_crud_unified.py | 10 +++++-- tools/synchro.py | 1 + 3 files changed, 48 insertions(+), 2 deletions(-) create mode 100644 test/asynchronous/test_crud_unified.py diff --git a/test/asynchronous/test_crud_unified.py b/test/asynchronous/test_crud_unified.py new file mode 100644 index 0000000000..3d8deb36e9 --- /dev/null +++ b/test/asynchronous/test_crud_unified.py @@ -0,0 +1,39 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the CRUD unified spec tests.""" +from __future__ import annotations + +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified") + +# Generate unified tests. 
+globals().update(generate_test_classes(_TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True)) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_crud_unified.py b/test/test_crud_unified.py index 92a60a47fc..26f34cba88 100644 --- a/test/test_crud_unified.py +++ b/test/test_crud_unified.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys sys.path[0:0] = [""] @@ -23,11 +24,16 @@ from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "crud", "unified") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified") # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True)) +globals().update(generate_test_classes(_TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True)) if __name__ == "__main__": unittest.main() diff --git a/tools/synchro.py b/tools/synchro.py index dbaf0a15e9..39ce7fbdd0 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -195,6 +195,7 @@ def async_only_test(f: str) -> bool: "test_common.py", "test_connection_logging.py", "test_connections_survive_primary_stepdown_spec.py", + "test_crud_unified.py", "test_cursor.py", "test_database.py", "test_encryption.py", From b2332b2aaeb26ecd7efa4992f037ca4dc56583db Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 13:59:49 -0400 Subject: [PATCH 06/19] PYTHON-4846 - Convert test.test_command_logging.py to async (#1915) --- test/asynchronous/test_command_logging.py | 44 +++++++++++++++++++++++ test/test_command_logging.py | 9 ++++- tools/synchro.py | 1 + 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 test/asynchronous/test_command_logging.py diff --git a/test/asynchronous/test_command_logging.py b/test/asynchronous/test_command_logging.py new file mode 100644 index 0000000000..f9b459c152 --- /dev/null +++ b/test/asynchronous/test_command_logging.py @@ -0,0 +1,44 @@ +# Copyright 2023-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the command monitoring unified format spec tests.""" +from __future__ import annotations + +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. 
+if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging") + + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_command_logging.py b/test/test_command_logging.py index 9b2d52e66b..cf865920ca 100644 --- a/test/test_command_logging.py +++ b/test/test_command_logging.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys sys.path[0:0] = [""] @@ -23,8 +24,14 @@ from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "command_logging") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging") + globals().update( generate_test_classes( diff --git a/tools/synchro.py b/tools/synchro.py index 39ce7fbdd0..f40a64e4c2 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -192,6 +192,7 @@ def async_only_test(f: str) -> bool: "test_client_context.py", "test_collation.py", "test_collection.py", + "test_command_logging.py", "test_common.py", "test_connection_logging.py", "test_connections_survive_primary_stepdown_spec.py", From 4eeaa4b7be9e814fd207166904f42556c10ce63b Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 11 Oct 2024 14:56:43 -0400 Subject: [PATCH 07/19] PYTHON-4848 - Convert test.test_command_monitoring.py to async (#1917) --- test/asynchronous/test_command_monitoring.py | 45 ++++++++++++++++++++ test/test_command_monitoring.py | 8 +++- tools/synchro.py | 1 + 3 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 test/asynchronous/test_command_monitoring.py diff --git a/test/asynchronous/test_command_monitoring.py b/test/asynchronous/test_command_monitoring.py new file mode 100644 index 0000000000..311fd1fdc1 --- /dev/null +++ b/test/asynchronous/test_command_monitoring.py @@ -0,0 +1,45 @@ +# Copyright 2015-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run the command monitoring unified format spec tests.""" +from __future__ import annotations + +import os +import pathlib +import sys + +sys.path[0:0] = [""] + +from test import unittest +from test.asynchronous.unified_format import generate_test_classes + +_IS_SYNC = False + +# Location of JSON test specifications. 
+if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring") + + +globals().update( + generate_test_classes( + _TEST_PATH, + module=__name__, + ) +) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_command_monitoring.py b/test/test_command_monitoring.py index d2f578824d..4f5ef06f28 100644 --- a/test/test_command_monitoring.py +++ b/test/test_command_monitoring.py @@ -16,6 +16,7 @@ from __future__ import annotations import os +import pathlib import sys sys.path[0:0] = [""] @@ -23,8 +24,13 @@ from test import unittest from test.unified_format import generate_test_classes +_IS_SYNC = True + # Location of JSON test specifications. -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "command_monitoring") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring") globals().update( diff --git a/tools/synchro.py b/tools/synchro.py index f40a64e4c2..b6812e9be6 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -193,6 +193,7 @@ def async_only_test(f: str) -> bool: "test_collation.py", "test_collection.py", "test_command_logging.py", + "test_command_monitoring.py", "test_common.py", "test_connection_logging.py", "test_connections_survive_primary_stepdown_spec.py", From 33163ecc0d4fe7dc8f7bfc12ef93d89513203fe2 Mon Sep 17 00:00:00 2001 From: Iris <58442094+sleepyStick@users.noreply.github.com> Date: Fri, 11 Oct 2024 16:02:13 -0700 Subject: [PATCH 08/19] PYTHON-4804 Migrate test_comment.py to async (#1887) --- test/asynchronous/test_comment.py | 159 ++++++++++++++++++++++++++++++ test/test_comment.py | 60 ++++------- tools/synchro.py | 2 + 3 files changed, 179 insertions(+), 42 deletions(-) create mode 100644 test/asynchronous/test_comment.py diff --git a/test/asynchronous/test_comment.py b/test/asynchronous/test_comment.py new file mode 100644 index 0000000000..be3626a8b8 --- /dev/null +++ b/test/asynchronous/test_comment.py @@ -0,0 +1,159 @@ +# Copyright 2022-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test the keyword argument 'comment' in various helpers.""" + +from __future__ import annotations + +import inspect +import sys + +sys.path[0:0] = [""] +from asyncio import iscoroutinefunction +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.utils import OvertCommandListener + +from bson.dbref import DBRef +from pymongo.asynchronous.command_cursor import AsyncCommandCursor +from pymongo.operations import IndexModel + +_IS_SYNC = False + + +class AsyncTestComment(AsyncIntegrationTest): + async def _test_ops( + self, + helpers, + already_supported, + listener, + ): + for h, args in helpers: + c = "testing comment with " + h.__name__ + with self.subTest("collection-" + h.__name__ + "-comment"): + for cc in [c, {"key": c}, ["any", 1]]: + listener.reset() + kwargs = {"comment": cc} + try: + maybe_cursor = await h(*args, **kwargs) + except Exception: + maybe_cursor = None + self.assertIn( + "comment", + inspect.signature(h).parameters, + msg="Could not find 'comment' in the " + "signature of function %s" % (h.__name__), + ) + self.assertEqual( + inspect.signature(h).parameters["comment"].annotation, "Optional[Any]" + ) + if isinstance(maybe_cursor, AsyncCommandCursor): + await maybe_cursor.close() + + cmd = listener.started_events[0] + self.assertEqual(cc, cmd.command.get("comment"), msg=cmd) + + if h.__name__ != "aggregate_raw_batches": + self.assertIn( + ":param comment:", + h.__doc__, + ) + if h not in already_supported: + self.assertIn( + "Added ``comment`` parameter", + h.__doc__, + ) + else: + self.assertNotIn( + "Added ``comment`` parameter", + h.__doc__, + ) + + listener.reset() + + @async_client_context.require_version_min(4, 7, -1) + @async_client_context.require_replica_set + async def test_database_helpers(self): + listener = OvertCommandListener() + db = (await self.async_rs_or_single_client(event_listeners=[listener])).db + helpers = [ + (db.watch, []), + (db.command, ["hello"]), + (db.list_collections, []), + (db.list_collection_names, []), + (db.drop_collection, ["hello"]), + (db.validate_collection, ["test"]), + (db.dereference, [DBRef("collection", 1)]), + ] + already_supported = [db.command, db.list_collections, db.list_collection_names] + await self._test_ops(helpers, already_supported, listener) + + @async_client_context.require_version_min(4, 7, -1) + @async_client_context.require_replica_set + async def test_client_helpers(self): + listener = OvertCommandListener() + cli = await self.async_rs_or_single_client(event_listeners=[listener]) + helpers = [ + (cli.watch, []), + (cli.list_databases, []), + (cli.list_database_names, []), + (cli.drop_database, ["test"]), + ] + already_supported = [ + cli.list_databases, + ] + await self._test_ops(helpers, already_supported, listener) + + @async_client_context.require_version_min(4, 7, -1) + async def test_collection_helpers(self): + listener = OvertCommandListener() + db = (await self.async_rs_or_single_client(event_listeners=[listener]))[self.db.name] + coll = db.get_collection("test") + + helpers = [ + (coll.list_indexes, []), + (coll.drop, []), + (coll.index_information, []), + (coll.options, []), + (coll.aggregate, [[{"$set": {"x": 1}}]]), + (coll.aggregate_raw_batches, [[{"$set": {"x": 1}}]]), + (coll.rename, ["temp_temp_temp"]), + (coll.distinct, ["_id"]), + (coll.find_one_and_delete, [{}]), + (coll.find_one_and_replace, [{}, {}]), + (coll.find_one_and_update, [{}, {"$set": {"a": 1}}]), + (coll.estimated_document_count, []), + (coll.count_documents, [{}]), + 
(coll.create_indexes, [[IndexModel("a")]]), + (coll.create_index, ["a"]), + (coll.drop_index, [[("a", 1)]]), + (coll.drop_indexes, []), + ] + already_supported = [ + coll.estimated_document_count, + coll.count_documents, + coll.create_indexes, + coll.drop_indexes, + coll.options, + coll.find_one_and_replace, + coll.drop_index, + coll.rename, + coll.distinct, + coll.find_one_and_delete, + coll.find_one_and_update, + ] + await self._test_ops(helpers, already_supported, listener) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_comment.py b/test/test_comment.py index c0f037ea44..9f9bf98640 100644 --- a/test/test_comment.py +++ b/test/test_comment.py @@ -20,24 +20,15 @@ import sys sys.path[0:0] = [""] - +from asyncio import iscoroutinefunction from test import IntegrationTest, client_context, unittest -from test.utils import EventListener +from test.utils import OvertCommandListener from bson.dbref import DBRef from pymongo.operations import IndexModel from pymongo.synchronous.command_cursor import CommandCursor - -class Empty: - def __getattr__(self, item): - try: - self.__dict__[item] - except KeyError: - return self.empty - - def empty(self, *args, **kwargs): - return Empty() +_IS_SYNC = True class TestComment(IntegrationTest): @@ -46,8 +37,6 @@ def _test_ops( helpers, already_supported, listener, - db=Empty(), # noqa: B008 - coll=Empty(), # noqa: B008 ): for h, args in helpers: c = "testing comment with " + h.__name__ @@ -55,19 +44,10 @@ def _test_ops( for cc in [c, {"key": c}, ["any", 1]]: listener.reset() kwargs = {"comment": cc} - if h == coll.rename: - _ = db.get_collection("temp_temp_temp").drop() - destruct_coll = db.get_collection("test_temp") - destruct_coll.insert_one({}) - maybe_cursor = destruct_coll.rename(*args, **kwargs) - destruct_coll.drop() - elif h == db.validate_collection: - coll = db.get_collection("test") - coll.insert_one({}) - maybe_cursor = db.validate_collection(*args, **kwargs) - else: - coll.create_index("a") + try: maybe_cursor = h(*args, **kwargs) + except Exception: + maybe_cursor = None self.assertIn( "comment", inspect.signature(h).parameters, @@ -79,15 +59,11 @@ def _test_ops( ) if isinstance(maybe_cursor, CommandCursor): maybe_cursor.close() - tested = False - # For some reason collection.list_indexes creates two commands and the first - # one doesn't contain 'comment'. 
-                for i in listener.started_events:
-                    if cc == i.command.get("comment", ""):
-                        self.assertEqual(cc, i.command["comment"])
-                        tested = True
-                self.assertTrue(tested)
-                if h not in [coll.aggregate_raw_batches]:
+
+                cmd = listener.started_events[0]
+                self.assertEqual(cc, cmd.command.get("comment"), msg=cmd)
+
+                if h.__name__ != "aggregate_raw_batches":
                     self.assertIn(
                         ":param comment:",
                         h.__doc__,
@@ -108,8 +84,8 @@
     @client_context.require_version_min(4, 7, -1)
     @client_context.require_replica_set
     def test_database_helpers(self):
-        listener = EventListener()
-        db = self.rs_or_single_client(event_listeners=[listener]).db
+        listener = OvertCommandListener()
+        db = (self.rs_or_single_client(event_listeners=[listener])).db
         helpers = [
             (db.watch, []),
             (db.command, ["hello"]),
@@ -120,12 +96,12 @@ def test_database_helpers(self):
             (db.dereference, [DBRef("collection", 1)]),
         ]
         already_supported = [db.command, db.list_collections, db.list_collection_names]
-        self._test_ops(helpers, already_supported, listener, db=db, coll=db.get_collection("test"))
+        self._test_ops(helpers, already_supported, listener)

     @client_context.require_version_min(4, 7, -1)
     @client_context.require_replica_set
     def test_client_helpers(self):
-        listener = EventListener()
+        listener = OvertCommandListener()
         cli = self.rs_or_single_client(event_listeners=[listener])
         helpers = [
             (cli.watch, []),
@@ -140,8 +116,8 @@ def test_client_helpers(self):

     @client_context.require_version_min(4, 7, -1)
     def test_collection_helpers(self):
-        listener = EventListener()
-        db = self.rs_or_single_client(event_listeners=[listener])[self.db.name]
+        listener = OvertCommandListener()
+        db = (self.rs_or_single_client(event_listeners=[listener]))[self.db.name]
         coll = db.get_collection("test")

         helpers = [
@@ -176,7 +152,7 @@ def test_collection_helpers(self):
             coll.find_one_and_delete,
             coll.find_one_and_update,
         ]
-        self._test_ops(helpers, already_supported, listener, coll=coll, db=db)
+        self._test_ops(helpers, already_supported, listener)


 if __name__ == "__main__":
diff --git a/tools/synchro.py b/tools/synchro.py
index b6812e9be6..25f506ed5a 100644
--- a/tools/synchro.py
+++ b/tools/synchro.py
@@ -193,7 +193,8 @@ def async_only_test(f: str) -> bool:
     "test_collation.py",
     "test_collection.py",
     "test_command_logging.py",
     "test_command_monitoring.py",
+    "test_comment.py",
     "test_common.py",
     "test_connection_logging.py",
     "test_connections_survive_primary_stepdown_spec.py",

From 3c5e71a1cb28b695bc2eec4c3927ef6af56835a8 Mon Sep 17 00:00:00 2001
From: Steven Silvester
Date: Mon, 14 Oct 2024 07:32:38 -0500
Subject: [PATCH 09/19] PYTHON-4862 Fix handling of interrupt_loop in unified
 test runner (#1924)

---
 test/asynchronous/unified_format.py | 8 +++++++-
 test/unified_format.py              | 8 +++++++-
 test/unified_format_shared.py       | 5 -----
 3 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py
index 4c37422951..42bda59cb2 100644
--- a/test/asynchronous/unified_format.py
+++ b/test/asynchronous/unified_format.py
@@ -36,7 +36,6 @@
     unittest,
 )
 from test.unified_format_shared import (
-    IS_INTERRUPTED,
     KMS_TLS_OPTS,
     PLACEHOLDER_MAP,
     SKIP_CSOT_TESTS,
@@ -104,6 +103,13 @@

 _IS_SYNC = False

+IS_INTERRUPTED = False
+
+
+def interrupt_loop():
+    global IS_INTERRUPTED
+    IS_INTERRUPTED = True
+

 async def is_run_on_requirement_satisfied(requirement):
     topology_satisfied = True
diff --git a/test/unified_format.py b/test/unified_format.py
index 6a19082b86..13ab0af69b 100644
---
a/test/unified_format.py +++ b/test/unified_format.py @@ -36,7 +36,6 @@ unittest, ) from test.unified_format_shared import ( - IS_INTERRUPTED, KMS_TLS_OPTS, PLACEHOLDER_MAP, SKIP_CSOT_TESTS, @@ -104,6 +103,13 @@ _IS_SYNC = True +IS_INTERRUPTED = False + + +def interrupt_loop(): + global IS_INTERRUPTED + IS_INTERRUPTED = True + def is_run_on_requirement_satisfied(requirement): topology_satisfied = True diff --git a/test/unified_format_shared.py b/test/unified_format_shared.py index d11624476d..f1b908a7a6 100644 --- a/test/unified_format_shared.py +++ b/test/unified_format_shared.py @@ -139,11 +139,6 @@ } -def interrupt_loop(): - global IS_INTERRUPTED - IS_INTERRUPTED = True - - def with_metaclass(meta, *bases): """Create a base class with a metaclass. From 9ba780cac256720be5c3c5051c7f8a19d27693d5 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 14 Oct 2024 07:34:01 -0500 Subject: [PATCH 10/19] PYTHON-4861 Ensure hatch is isolated in Evergreen (#1923) --- .evergreen/hatch.sh | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/.evergreen/hatch.sh b/.evergreen/hatch.sh index db0da2f4d0..8f862c39d2 100644 --- a/.evergreen/hatch.sh +++ b/.evergreen/hatch.sh @@ -18,17 +18,22 @@ if [ -n "$SKIP_HATCH" ]; then run_hatch() { bash ./.evergreen/run-tests.sh } -elif $PYTHON_BINARY -m hatch --version; then - run_hatch() { - $PYTHON_BINARY -m hatch run "$@" - } -else # No toolchain hatch present, set up virtualenv before installing hatch +else # Set up virtualenv before installing hatch # Use a random venv name because the encryption tasks run this script multiple times in the same run. ENV_NAME=hatchenv-$RANDOM createvirtualenv "$PYTHON_BINARY" $ENV_NAME # shellcheck disable=SC2064 trap "deactivate; rm -rf $ENV_NAME" EXIT HUP python -m pip install -q hatch + + # Ensure hatch does not write to user or global locations. + touch hatch_config.toml + HATCH_CONFIG=$(pwd)/hatch_config.toml + export HATCH_CONFIG + hatch config restore + hatch config set dirs.data ".hatch/data" + hatch config set dirs.cache ".hatch/cache" + run_hatch() { python -m hatch run "$@" } From 3cc722e9105d5818d57739d623d985d69b0eb626 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Mon, 14 Oct 2024 14:05:22 -0500 Subject: [PATCH 11/19] PYTHON-4838 Generate OCSP build variants using shrub.py (#1910) --- .evergreen/config.yml | 174 +++++++++++++++++++++----- .evergreen/scripts/generate_config.py | 167 ++++++++++++++++++++++++ 2 files changed, 308 insertions(+), 33 deletions(-) create mode 100644 .evergreen/scripts/generate_config.py diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 1ef8751501..dee4b608ec 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -2826,42 +2826,150 @@ buildvariants: - "test-6.0-standalone" - "test-5.0-standalone" -- matrix_name: "ocsp-test" - matrix_spec: - platform: rhel8 - python-version: ["3.9", "3.10", "pypy3.9", "pypy3.10"] - mongodb-version: ["4.4", "5.0", "6.0", "7.0", "8.0", "latest"] - auth: "noauth" - ssl: "ssl" - display_name: "OCSP test ${platform} ${python-version} ${mongodb-version}" - batchtime: 20160 # 14 days +# OCSP test matrix. 
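
The variant blocks that follow are no longer written by hand: they are the output of the new .evergreen/scripts/generate_config.py added at the end of this patch (per its PEP 723 header, runnable with hatch run, pipx run, or uv run, with the emitted YAML pasted back into this section). The version-to-python pairing visible below falls out of the script's zip_cycle helper; a quick check of that pairing logic (inlined here for illustration, trimmed of its unused empty_default parameter):

    from itertools import cycle, zip_longest

    def zip_cycle(*iterables):
        cycles = [cycle(i) for i in iterables]
        for _ in zip_longest(*iterables):
            yield tuple(next(i) for i in cycles)

    versions = ["4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"]
    pythons = ["3.9", "3.10", "3.11", "3.12", "3.13", "pypy3.9", "pypy3.10"]
    pairs = list(zip_cycle(versions, pythons))
    assert pairs[0] == ("4.4", "3.9")
    assert pairs[-1] == ("latest", "pypy3.10")
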
+- name: ocsp-test-rhel8-v4.4-py3.9 tasks: - - name: ".ocsp" - -- matrix_name: "ocsp-test-windows" - matrix_spec: - platform: windows - python-version-windows: ["3.9", "3.10"] - mongodb-version: ["4.4", "5.0", "6.0", "7.0", "8.0", "latest"] - auth: "noauth" - ssl: "ssl" - display_name: "OCSP test ${platform} ${python-version-windows} ${mongodb-version}" - batchtime: 20160 # 14 days + - name: .ocsp + display_name: OCSP test RHEL8 v4.4 py3.9 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "4.4" + PYTHON_BINARY: /opt/python/3.9/bin/python3 +- name: ocsp-test-rhel8-v5.0-py3.10 tasks: - # Windows MongoDB servers do not staple OCSP responses and only support RSA. - - name: ".ocsp-rsa !.ocsp-staple" - -- matrix_name: "ocsp-test-macos" - matrix_spec: - platform: macos - mongodb-version: ["4.4", "5.0", "6.0", "7.0", "8.0", "latest"] - auth: "noauth" - ssl: "ssl" - display_name: "OCSP test ${platform} ${mongodb-version}" - batchtime: 20160 # 14 days + - name: .ocsp + display_name: OCSP test RHEL8 v5.0 py3.10 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "5.0" + PYTHON_BINARY: /opt/python/3.10/bin/python3 +- name: ocsp-test-rhel8-v6.0-py3.11 tasks: - # macOS MongoDB servers do not staple OCSP responses and only support RSA. - - name: ".ocsp-rsa !.ocsp-staple" + - name: .ocsp + display_name: OCSP test RHEL8 v6.0 py3.11 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "6.0" + PYTHON_BINARY: /opt/python/3.11/bin/python3 +- name: ocsp-test-rhel8-v7.0-py3.12 + tasks: + - name: .ocsp + display_name: OCSP test RHEL8 v7.0 py3.12 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "7.0" + PYTHON_BINARY: /opt/python/3.12/bin/python3 +- name: ocsp-test-rhel8-v8.0-py3.13 + tasks: + - name: .ocsp + display_name: OCSP test RHEL8 v8.0 py3.13 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "8.0" + PYTHON_BINARY: /opt/python/3.13/bin/python3 +- name: ocsp-test-rhel8-rapid-pypy3.9 + tasks: + - name: .ocsp + display_name: OCSP test RHEL8 rapid pypy3.9 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: rapid + PYTHON_BINARY: /opt/python/pypy3.9/bin/python3 +- name: ocsp-test-rhel8-latest-pypy3.10 + tasks: + - name: .ocsp + display_name: OCSP test RHEL8 latest pypy3.10 + run_on: + - rhel87-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: latest + PYTHON_BINARY: /opt/python/pypy3.10/bin/python3 +- name: ocsp-test-win64-v4.4-py3.9 + tasks: + - name: .ocsp-rsa !.ocsp-staple + display_name: OCSP test Win64 v4.4 py3.9 + run_on: + - windows-64-vsMulti-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "4.4" + PYTHON_BINARY: C:/python/Python39/python.exe +- name: ocsp-test-win64-v8.0-py3.13 + tasks: + - name: .ocsp-rsa !.ocsp-staple + display_name: OCSP test Win64 v8.0 py3.13 + run_on: + - windows-64-vsMulti-small + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "8.0" + PYTHON_BINARY: C:/python/Python313/python.exe +- name: ocsp-test-macos-v4.4-py3.9 + tasks: + - name: .ocsp-rsa !.ocsp-staple + display_name: OCSP test macOS v4.4 py3.9 + run_on: + - macos-14 + batchtime: 20160 + expansions: + AUTH: noauth + 
SSL: ssl + TOPOLOGY: server + VERSION: "4.4" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.9/bin/python3 +- name: ocsp-test-macos-v8.0-py3.13 + tasks: + - name: .ocsp-rsa !.ocsp-staple + display_name: OCSP test macOS v8.0 py3.13 + run_on: + - macos-14 + batchtime: 20160 + expansions: + AUTH: noauth + SSL: ssl + TOPOLOGY: server + VERSION: "8.0" + PYTHON_BINARY: /Library/Frameworks/Python.Framework/Versions/3.13/bin/python3 - matrix_name: "oidc-auth-test" matrix_spec: diff --git a/.evergreen/scripts/generate_config.py b/.evergreen/scripts/generate_config.py new file mode 100644 index 0000000000..e98e527b72 --- /dev/null +++ b/.evergreen/scripts/generate_config.py @@ -0,0 +1,167 @@ +# /// script +# requires-python = ">=3.9" +# dependencies = [ +# "shrub.py>=3.2.0", +# "pyyaml>=6.0.2" +# ] +# /// + +# Note: Run this file with `hatch run`, `pipx run`, or `uv run`. +from __future__ import annotations + +from dataclasses import dataclass +from itertools import cycle, product, zip_longest +from typing import Any + +from shrub.v3.evg_build_variant import BuildVariant +from shrub.v3.evg_project import EvgProject +from shrub.v3.evg_task import EvgTaskRef +from shrub.v3.shrub_service import ShrubService + +############## +# Globals +############## + +ALL_VERSIONS = ["4.0", "4.4", "5.0", "6.0", "7.0", "8.0", "rapid", "latest"] +CPYTHONS = ["3.9", "3.10", "3.11", "3.12", "3.13"] +PYPYS = ["pypy3.9", "pypy3.10"] +ALL_PYTHONS = CPYTHONS + PYPYS +BATCHTIME_WEEK = 10080 +HOSTS = dict() + + +@dataclass +class Host: + name: str + run_on: str + display_name: str + + +HOSTS["rhel8"] = Host("rhel8", "rhel87-small", "RHEL8") +HOSTS["win64"] = Host("win64", "windows-64-vsMulti-small", "Win64") +HOSTS["macos"] = Host("macos", "macos-14", "macOS") + + +############## +# Helpers +############## + + +def create_variant( + task_names: list[str], + display_name: str, + *, + python: str | None = None, + version: str | None = None, + host: str | None = None, + **kwargs: Any, +) -> BuildVariant: + """Create a build variant for the given inputs.""" + task_refs = [EvgTaskRef(name=n) for n in task_names] + kwargs.setdefault("expansions", dict()) + expansions = kwargs.pop("expansions", dict()).copy() + host = host or "rhel8" + run_on = [HOSTS[host].run_on] + name = display_name.replace(" ", "-").lower() + if python: + expansions["PYTHON_BINARY"] = get_python_binary(python, host) + if version: + expansions["VERSION"] = version + expansions = expansions or None + return BuildVariant( + name=name, + display_name=display_name, + tasks=task_refs, + expansions=expansions, + run_on=run_on, + **kwargs, + ) + + +def get_python_binary(python: str, host: str) -> str: + """Get the appropriate python binary given a python version and host.""" + if host == "win64": + is_32 = python.startswith("32-bit") + if is_32: + _, python = python.split() + base = "C:/python/32" + else: + base = "C:/python" + python = python.replace(".", "") + return f"{base}/Python{python}/python.exe" + + if host == "rhel8": + return f"/opt/python/{python}/bin/python3" + + if host == "macos": + return f"/Library/Frameworks/Python.Framework/Versions/{python}/bin/python3" + + raise ValueError(f"no match found for python {python} on {host}") + + +def get_display_name(base: str, host: str, version: str, python: str) -> str: + """Get the display name of a variant.""" + if version not in ["rapid", "latest"]: + version = f"v{version}" + if not python.startswith("pypy"): + python = f"py{python}" + return f"{base} {HOSTS[host].display_name} {version} 
{python}" + + +def zip_cycle(*iterables, empty_default=None): + """Get all combinations of the inputs, cycling over the shorter list(s).""" + cycles = [cycle(i) for i in iterables] + for _ in zip_longest(*iterables): + yield tuple(next(i, empty_default) for i in cycles) + + +############## +# Variants +############## + + +def create_ocsp_variants() -> list[BuildVariant]: + variants = [] + batchtime = BATCHTIME_WEEK * 2 + expansions = dict(AUTH="noauth", SSL="ssl", TOPOLOGY="server") + base_display = "OCSP test" + + # OCSP tests on rhel8 with all servers v4.4+ and all python versions. + versions = [v for v in ALL_VERSIONS if v != "4.0"] + for version, python in zip_cycle(versions, ALL_PYTHONS): + host = "rhel8" + variant = create_variant( + [".ocsp"], + get_display_name(base_display, host, version, python), + python=python, + version=version, + host=host, + expansions=expansions, + batchtime=batchtime, + ) + variants.append(variant) + + # OCSP tests on Windows and MacOS. + # MongoDB servers on these hosts do not staple OCSP responses and only support RSA. + for host, version in product(["win64", "macos"], ["4.4", "8.0"]): + python = CPYTHONS[0] if version == "4.4" else CPYTHONS[-1] + variant = create_variant( + [".ocsp-rsa !.ocsp-staple"], + get_display_name(base_display, host, version, python), + python=python, + version=version, + host=host, + expansions=expansions, + batchtime=batchtime, + ) + variants.append(variant) + + return variants + + +################## +# Generate Config +################## + +project = EvgProject(tasks=None, buildvariants=create_ocsp_variants()) +print(ShrubService.generate_yaml(project)) # noqa: T201 From 389127a66ab5eae854637f77e523d9be5af7a8bb Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 14 Oct 2024 16:51:28 -0400 Subject: [PATCH 12/19] PYTHON-4843 - Async test suite should use a single event loop per test --- test/__init__.py | 81 ++---------- test/asynchronous/__init__.py | 85 +++--------- test/asynchronous/test_bulk.py | 36 ++---- test/asynchronous/test_change_stream.py | 43 ++---- test/asynchronous/test_client.py | 9 +- test/asynchronous/test_collation.py | 30 ++--- test/asynchronous/test_collection.py | 33 ++--- ...nnections_survive_primary_stepdown_spec.py | 28 ++-- test/asynchronous/test_cursor.py | 4 - test/asynchronous/test_database.py | 3 +- test/asynchronous/test_encryption.py | 122 +++++++----------- test/asynchronous/test_grid_file.py | 1 + test/asynchronous/test_monitoring.py | 47 ++++--- test/asynchronous/test_retryable_writes.py | 65 ++++------ test/asynchronous/test_session.py | 32 ++--- test/asynchronous/test_transactions.py | 17 +-- test/asynchronous/utils_spec_runner.py | 26 ++-- test/test_bulk.py | 32 ++--- test/test_change_stream.py | 39 ++---- test/test_client.py | 9 +- test/test_collation.py | 28 ++-- test/test_collection.py | 33 ++--- ...nnections_survive_primary_stepdown_spec.py | 28 ++-- test/test_cursor.py | 4 - test/test_custom_types.py | 23 ++-- test/test_database.py | 1 + test/test_encryption.py | 120 +++++++---------- test/test_grid_file.py | 1 + test/test_monitoring.py | 45 +++---- test/test_retryable_writes.py | 65 ++++------ test/test_session.py | 32 ++--- test/test_transactions.py | 15 +-- test/utils_spec_runner.py | 26 ++-- 33 files changed, 389 insertions(+), 774 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index 6be3b49ce6..9858ab2d6e 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1114,26 +1114,10 @@ def enable_replication(self, client): class UnitTest(PyMongoTestCase): """Async base 
class for TestCases that don't require a connection to MongoDB.""" - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod - def _setup_class(cls): + def setUp(self) -> None: pass - @classmethod - def _tearDown_class(cls): + def tearDown(self) -> None: pass @@ -1144,37 +1128,20 @@ class IntegrationTest(PyMongoTestCase): db: Database credentials: Dict[str, str] - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @client_context.require_connection - def _setup_class(cls): - if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): + def setUp(self) -> None: + if not _IS_SYNC: + reset_client_context() + if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") - if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): + if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): raise SkipTest("this test does not support serverless") - cls.client = client_context.client - cls.db = cls.client.pymongo_test + self.client = client_context.client + self.db = self.client.pymongo_test if client_context.auth_enabled: - cls.credentials = {"username": db_user, "password": db_pwd} + self.credentials = {"username": db_user, "password": db_pwd} else: - cls.credentials = {} - - @classmethod - def _tearDown_class(cls): - pass + self.credentials = {} def cleanup_colls(self, *collections): """Cleanup collections faster than drop_collection.""" @@ -1200,37 +1167,15 @@ class MockClientTest(UnitTest): # MockClients tests that use replicaSet, directConnection=True, pass # multiple seed addresses, or wait for heartbeat events are incompatible # with loadBalanced=True. 
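
The classmethod blocks being deleted here and throughout this patch funneled async setup through asyncio.run(), which creates and closes a fresh event loop on every call; a client created during setUpClass was therefore used from a different loop once the test body ran. The per-test setUp/asyncSetUp replacements keep setup, test, and teardown on a single loop. A minimal illustration of the two shapes (assuming IsolatedAsyncioTestCase-style semantics for the async variant):

    import asyncio
    import unittest

    class Old(unittest.TestCase):
        @classmethod
        def setUpClass(cls):
            asyncio.run(cls._setup_class())  # loop A, closed when this returns

        @classmethod
        async def _setup_class(cls):
            ...  # anything bound to loop A is stale once tests run on loop B

    class New(unittest.IsolatedAsyncioTestCase):
        async def asyncSetUp(self):
            ...  # runs on the same loop as the test coroutine itself
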
- @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @client_context.require_no_load_balancer - def _setup_class(cls): - pass - - @classmethod - def _tearDown_class(cls): - pass - - def setUp(self): + def setUp(self) -> None: super().setUp() self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) self.client_knobs.enable() - def tearDown(self): + def tearDown(self) -> None: self.client_knobs.disable() super().tearDown() diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 1a386fe766..cc0fda3e6a 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -1132,26 +1132,10 @@ async def enable_replication(self, client): class AsyncUnitTest(AsyncPyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod - async def _setup_class(cls): + async def asyncSetUp(self) -> None: pass - @classmethod - async def _tearDown_class(cls): + async def asyncTearDown(self) -> None: pass @@ -1162,37 +1146,20 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): db: AsyncDatabase credentials: Dict[str, str] - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @async_client_context.require_connection - async def _setup_class(cls): - if async_client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): + async def asyncSetUp(self) -> None: + if not _IS_SYNC: + await reset_client_context() + if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") - if async_client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): + if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): raise SkipTest("this test does not support serverless") - cls.client = async_client_context.client - cls.db = cls.client.pymongo_test + self.client = async_client_context.client + self.db = self.client.pymongo_test if async_client_context.auth_enabled: - cls.credentials = {"username": db_user, "password": db_pwd} + self.credentials = {"username": db_user, "password": db_pwd} else: - cls.credentials = {} - - @classmethod - async def _tearDown_class(cls): - pass + self.credentials = {} async def cleanup_colls(self, *collections): """Cleanup collections faster than drop_collection.""" @@ -1218,39 +1185,17 @@ class AsyncMockClientTest(AsyncUnitTest): # MockClients tests that use replicaSet, directConnection=True, pass # multiple seed addresses, or wait for heartbeat events are incompatible # with loadBalanced=True. 
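
Note the convention the per-test fixtures adopt across the rest of this patch: subclasses await super().asyncSetUp() first, so self.client and self.db exist before their own setup runs, and they do their own cleanup before handing off to super's teardown. The typical subclass shape after the refactor (condensed from the test_bulk and test_collection hunks below; the class name is illustrative):

    class MyCollectionTest(AsyncIntegrationTest):
        async def asyncSetUp(self):
            await super().asyncSetUp()  # connects and sets self.client / self.db
            self.coll = self.db.test
            await self.coll.drop()

        async def asyncTearDown(self):
            await self.coll.drop()  # own cleanup first
            await super().asyncTearDown()
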
- @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @async_client_context.require_no_load_balancer - async def _setup_class(cls): - pass - - @classmethod - async def _tearDown_class(cls): - pass - - def setUp(self): - super().setUp() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) self.client_knobs.enable() - def tearDown(self): + async def asyncTearDown(self) -> None: self.client_knobs.disable() - super().tearDown() + await super().asyncTearDown() async def async_setup(): diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 42a3311072..e01dd53d7e 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -42,15 +42,11 @@ class AsyncBulkTestBase(AsyncIntegrationTest): coll: AsyncCollection coll_w0: AsyncCollection - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.coll = cls.db.test - cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0)) - async def asyncSetUp(self): - super().setUp() + await super().asyncSetUp() + self.coll = self.db.test await self.coll.drop() + self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0)) def assertEqualResponse(self, expected, actual): """Compare response from bulk.execute() to expected response.""" @@ -787,14 +783,10 @@ async def test_large_inserts_unordered(self): class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase): - @classmethod @async_client_context.require_auth @async_client_context.require_no_api_version - async def _setup_class(cls): - await super()._setup_class() - async def asyncSetUp(self): - super().setUp() + await super().asyncSetUp() await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"]) await self.db.command( "createRole", @@ -937,21 +929,19 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase): w: Optional[int] secondary: AsyncMongoClient - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.w = async_client_context.w - cls.secondary = None - if cls.w is not None and cls.w > 1: + async def asyncSetUp(self): + await super().asyncSetUp() + self.w = async_client_context.w + self.secondary = None + if self.w is not None and self.w > 1: for member in (await async_client_context.hello)["hosts"]: if member != (await async_client_context.hello)["primary"]: - cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member)) + self.secondary = await self.async_single_client(*partition_node(member)) break - @classmethod - async def async_tearDownClass(cls): - if cls.secondary: - await cls.secondary.close() + async def asyncTearDown(self): + if self.secondary: + await self.secondary.close() async def cause_wtimeout(self, requests, ordered): if not async_client_context.test_commands_enabled: diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 883ed72c4c..db8a74f55a 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -835,18 +835,16 @@ async def test_split_large_change(self): class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin): dbs: list - @classmethod @async_client_context.require_version_min(4, 0, 0, -1) 
@async_client_context.require_change_streams - async def _setup_class(cls): - await super()._setup_class() - cls.dbs = [cls.db, cls.client.pymongo_test_2] + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.dbs = [self.db, self.client.pymongo_test_2] - @classmethod - async def _tearDown_class(cls): - for db in cls.dbs: - await cls.client.drop_database(db) - await super()._tearDown_class() + async def asyncTearDown(self): + for db in self.dbs: + await self.client.drop_database(db) + await super().asyncTearDown() async def change_stream_with_client(self, client, *args, **kwargs): return await client.watch(*args, **kwargs) @@ -897,11 +895,10 @@ async def test_full_pipeline(self): class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin): - @classmethod @async_client_context.require_version_min(4, 0, 0, -1) @async_client_context.require_change_streams - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() async def change_stream_with_client(self, client, *args, **kwargs): return await client[self.db.name].watch(*args, **kwargs) @@ -987,12 +984,9 @@ async def test_isolation(self): class TestAsyncCollectionAsyncChangeStream( TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin ): - @classmethod @async_client_context.require_change_streams - async def _setup_class(cls): - await super()._setup_class() - async def asyncSetUp(self): + await super().asyncSetUp() # Use a new collection for each test. await self.watched_collection().drop() await self.watched_collection().insert_one({}) @@ -1132,20 +1126,11 @@ class TestAllLegacyScenarios(AsyncIntegrationTest): RUN_ON_LOAD_BALANCER = True listener: AllowListEventListener - @classmethod @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() - cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - await super()._tearDown_class() - - def asyncSetUp(self): - super().asyncSetUp() + async def asyncSetUp(self): + await super().asyncSetUp() + self.listener = AllowListEventListener("aggregate", "getMore") + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) self.listener.reset() async def asyncSetUpCluster(self, scenario_dict): diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index c4d71cdbe6..78bba68fe8 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -130,16 +130,11 @@ class AsyncClientUnitTest(AsyncUnitTest): client: AsyncMongoClient - @classmethod - async def _setup_class(cls): - cls.client = await cls.unmanaged_async_rs_or_single_client( + async def asyncSetUp(self) -> None: + self.client = await self.async_rs_or_single_client( connect=False, serverSelectionTimeoutMS=100 ) - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): self._caplog = caplog diff --git a/test/asynchronous/test_collation.py b/test/asynchronous/test_collation.py index be3ea22e42..abbca1aff9 100644 --- a/test/asynchronous/test_collation.py +++ b/test/asynchronous/test_collation.py @@ -97,28 +97,22 @@ class TestCollation(AsyncIntegrationTest): warn_context: Any collation: Collation - @classmethod @async_client_context.require_connection - async def 
_setup_class(cls): - await super()._setup_class() - cls.listener = EventListener() - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - cls.db = cls.client.pymongo_test - cls.collation = Collation("en_US") - cls.warn_context = warnings.catch_warnings() - cls.warn_context.__enter__() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.listener = EventListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test + self.collation = Collation("en_US") + self.warn_context = warnings.catch_warnings() + self.warn_context.__enter__() warnings.simplefilter("ignore", DeprecationWarning) - @classmethod - async def _tearDown_class(cls): - cls.warn_context.__exit__() - cls.warn_context = None - await cls.client.close() - await super()._tearDown_class() - - def tearDown(self): + async def asyncTearDown(self) -> None: + self.warn_context.__exit__() + self.warn_context = None self.listener.reset() - super().tearDown() + await super().asyncTearDown() def last_command_started(self): return self.listener.started_events[-1].command diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 470425f4ce..a2ed4de388 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -86,14 +86,10 @@ class TestCollectionNoConnect(AsyncUnitTest): db: AsyncDatabase client: AsyncMongoClient - @classmethod - async def _setup_class(cls): - cls.client = AsyncMongoClient(connect=False) - cls.db = cls.client.pymongo_test - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.client = self.simple_client(connect=False) + self.db = self.client.pymongo_test def test_collection(self): self.assertRaises(TypeError, AsyncCollection, self.db, 5) @@ -163,27 +159,14 @@ def test_iteration(self): class AsyncTestCollection(AsyncIntegrationTest): w: int - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.w = async_client_context.w # type: ignore - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine] - else: - asyncio.run(cls.async_tearDownClass()) - - @classmethod - async def async_tearDownClass(cls): - await cls.db.drop_collection("test_large_limit") - async def asyncSetUp(self): - await self.db.test.drop() + await super().asyncSetUp() + self.w = async_client_context.w # type: ignore async def asyncTearDown(self): await self.db.test.drop() + await self.db.drop_collection("test_large_limit") + await super().asyncTearDown() @contextlib.contextmanager def write_concern_collection(self): diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py index dc04cb28a7..ffff428379 100644 --- a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -44,30 +44,22 @@ class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest): listener: CMAPListener coll: AsyncCollection - @classmethod + async def asyncTearDown(self): + await reset_client_context() + @async_client_context.require_replica_set - async def _setup_class(cls): - await super()._setup_class() - cls.listener = CMAPListener() - cls.client = await cls.unmanaged_async_rs_or_single_client( - 
event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 + async def asyncSetUp(self): + self.listener = CMAPListener() + self.client = await self.async_rs_or_single_client( + event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500 ) # Ensure connections to all servers in replica set. This is to test # that the is_writable flag is properly updated for connections that # survive a replica set election. - await async_ensure_all_connected(cls.client) - cls.listener.reset() - - cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority")) - cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - await reset_client_context() - - async def asyncSetUp(self): + await async_ensure_all_connected(self.client) + self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority")) + self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority")) # Note that all ops use same write-concern as self.db (majority). await self.db.drop_collection("step-down") await self.db.create_collection("step-down") diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index f7b795cdae..8ae364d31c 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1649,10 +1649,6 @@ async def test_monitoring(self): class TestRawBatchCommandCursor(AsyncIntegrationTest): - @classmethod - async def _setup_class(cls): - await super()._setup_class() - async def test_aggregate_raw(self): c = self.db.test await c.drop() diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index 61369c8542..b5a5960420 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -717,7 +717,8 @@ def test_with_options(self): class TestDatabaseAggregation(AsyncIntegrationTest): - def setUp(self): + async def asyncSetUp(self): + await super().asyncSetUp() self.pipeline: List[Mapping[str, Any]] = [ {"$listLocalSessions": {}}, {"$limit": 1}, diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index c3f6223384..2db4845c72 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -29,7 +29,6 @@ import uuid import warnings from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context -from test.asynchronous.test_bulk import AsyncBulkTestBase from threading import Thread from typing import Any, Dict, Mapping, Optional @@ -211,11 +210,10 @@ async def test_kwargs(self): class AsyncEncryptionIntegrationTest(AsyncIntegrationTest): """Base class for encryption integration tests.""" - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @async_client_context.require_version_min(4, 2, -1) - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() def assertEncrypted(self, val): self.assertIsInstance(val, Binary) @@ -430,10 +428,9 @@ async def test_upsert_uuid_standard_encrypt(self): class TestClientMaxWireVersion(AsyncIntegrationTest): - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() @async_client_context.require_version_max(4, 0, 99) async def 
test_raise_max_wire_version_error(self): @@ -818,17 +815,16 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest): "local": None, } - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - async def _setup_class(cls): - await super()._setup_class() - cls.listener = OvertCommandListener() - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - await cls.client.db.coll.drop() - cls.vault = await create_key_vault(cls.client.keyvault.datakeys) + async def asyncSetUp(self): + await super().asyncSetUp() + self.listener = OvertCommandListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) + await self.client.db.coll.drop() + self.vault = await create_key_vault(self.client.keyvault.datakeys) # Configure the encrypted field via the local schema_map option. schemas = { @@ -846,25 +842,22 @@ async def _setup_class(cls): } } opts = AutoEncryptionOpts( - cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS + self.KMS_PROVIDERS, + "keyvault.datakeys", + schema_map=schemas, + kms_tls_options=KMS_TLS_OPTS, ) - cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( + self.client_encrypted = await self.async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = cls.unmanaged_create_client_encryption( - cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS + self.client_encryption = self.create_client_encryption( + self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) - - @classmethod - async def _tearDown_class(cls): - await cls.vault.drop() - await cls.client.close() - await cls.client_encrypted.close() - await cls.client_encryption.close() - - def setUp(self): self.listener.reset() + async def asyncTearDown(self) -> None: + await self.vault.drop() + async def run_test(self, provider_name): # Create data key. master_key: Any = self.MASTER_KEYS[provider_name] @@ -1011,10 +1004,9 @@ async def test_views_are_prohibited(self): class TestCorpus(AsyncEncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() @staticmethod def kms_providers(): @@ -1188,12 +1180,11 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest): client_encrypted: AsyncMongoClient listener: OvertCommandListener - @classmethod - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() db = async_client_context.client.db - cls.coll = db.coll - await cls.coll.drop() + self.coll = db.coll + await self.coll.drop() # Configure the encrypted 'db.coll' collection via jsonSchema. 
json_schema = json_data("limits", "limits-schema.json") await db.create_collection( @@ -1211,17 +1202,14 @@ async def _setup_class(cls): await coll.insert_one(json_data("limits", "limits-key.json")) opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") - cls.listener = OvertCommandListener() - cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( - auto_encryption_opts=opts, event_listeners=[cls.listener] + self.listener = OvertCommandListener() + self.client_encrypted = await self.async_rs_or_single_client( + auto_encryption_opts=opts, event_listeners=[self.listener] ) - cls.coll_encrypted = cls.client_encrypted.db.coll + self.coll_encrypted = self.client_encrypted.db.coll - @classmethod - async def _tearDown_class(cls): - await cls.coll_encrypted.drop() - await cls.client_encrypted.close() - await super()._tearDown_class() + async def asyncTearDown(self) -> None: + await self.coll_encrypted.drop() async def test_01_insert_succeeds_under_2MiB(self): doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB} @@ -1285,15 +1273,12 @@ async def test_06_insert_fails_over_16MiB(self): class TestCustomEndpoint(AsyncEncryptionIntegrationTest): """Prose tests for creating data keys with a custom endpoint.""" - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - async def _setup_class(cls): - await super()._setup_class() - - def setUp(self): + async def asyncSetUp(self): + await super().asyncSetUp() kms_providers = { "aws": AWS_CREDS, "azure": AZURE_CREDS, @@ -1322,10 +1307,6 @@ def setUp(self): self._kmip_host_error = None self._invalid_host_error = None - async def asyncTearDown(self): - await self.client_encryption.close() - await self.client_encryption_invalid.close() - async def run_test_expected_success(self, provider_name, master_key): data_key_id = await self.client_encryption.create_data_key( provider_name, master_key=master_key @@ -1501,6 +1482,7 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest): client: AsyncMongoClient async def asyncSetUp(self): + self.client = self.simple_client() keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL) await create_key_vault(keyvault, self.DEK) @@ -1559,13 +1541,12 @@ async def _test_automatic(self, expectation_extjson, payload): class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") - async def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} - cls.DEK = json_data(BASE, "custom", "azure-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - await super()._setup_class() + async def asyncSetUp(self): + self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} + self.DEK = json_data(BASE, "custom", "azure-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + await super().asyncSetUp() async def test_explicit(self): return await self._test_explicit( @@ -1585,13 +1566,12 @@ async def test_automatic(self): class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") - async def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} - cls.DEK = json_data(BASE, "custom", "gcp-dek.json") - cls.SCHEMA_MAP = 
json_data(BASE, "custom", "azure-gcp-schema.json") - await super()._setup_class() + async def asyncSetUp(self): + self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} + self.DEK = json_data(BASE, "custom", "gcp-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + await super().asyncSetUp() async def test_explicit(self): return await self._test_explicit( @@ -3089,17 +3069,11 @@ class TestNoSessionsSupport(AsyncEncryptionIntegrationTest): mongocryptd_client: AsyncMongoClient MONGOCRYPTD_PORT = 27020 - @classmethod @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") - async def _setup_class(cls): - await super()._setup_class() - start_mongocryptd(cls.MONGOCRYPTD_PORT) - - @classmethod - async def _tearDown_class(cls): - await super()._tearDown_class() - async def asyncSetUp(self) -> None: + await super().asyncSetUp() + start_mongocryptd(self.MONGOCRYPTD_PORT) + self.listener = OvertCommandListener() self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py index 9c57c15c5a..14446106e0 100644 --- a/test/asynchronous/test_grid_file.py +++ b/test/asynchronous/test_grid_file.py @@ -97,6 +97,7 @@ def test_grid_in_custom_opts(self): class AsyncTestGridFile(AsyncIntegrationTest): async def asyncSetUp(self): + await super().asyncSetUp() await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks) async def test_basic(self): diff --git a/test/asynchronous/test_monitoring.py b/test/asynchronous/test_monitoring.py index b5d8708dc3..a5f991b2f0 100644 --- a/test/asynchronous/test_monitoring.py +++ b/test/asynchronous/test_monitoring.py @@ -51,22 +51,16 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest): listener: EventListener @classmethod - @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() + def setUpClass(cls) -> None: cls.listener = EventListener() - cls.client = await cls.unmanaged_async_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False - ) - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - await super()._tearDown_class() - - async def asyncTearDown(self): + @async_client_context.require_connection + async def asyncSetUp(self) -> None: + await super().asyncSetUp() self.listener.reset() - await super().asyncTearDown() + self.client = await self.async_rs_or_single_client( + event_listeners=[self.listener], retryWrites=False + ) async def test_started_simple(self): await self.client.pymongo_test.command("ping") @@ -1137,27 +1131,30 @@ class AsyncTestGlobalListener(AsyncIntegrationTest): saved_listeners: Any @classmethod - @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() + def setUpClass(cls) -> None: cls.listener = EventListener() # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = await cls.unmanaged_async_single_client() - # Get one (authenticated) socket in the pool. 
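
The save/register/restore dance around monitoring._LISTENERS in the global-listener tests below is what keeps a globally registered listener from leaking into every client that later tests create. Condensed to its essentials (a sketch, not the test's literal code; EventListener is the test-suite helper used above):

    import copy
    from pymongo import monitoring

    listener = EventListener()
    saved = copy.deepcopy(monitoring._LISTENERS)
    monitoring.register(listener)  # every client created from here on sees it
    try:
        ...  # exercise the client and make assertions
    finally:
        monitoring._LISTENERS = saved  # undo the global registration
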
- await cls.client.pymongo_test.command("ping") - - @classmethod - async def _tearDown_class(cls): - monitoring._LISTENERS = cls.saved_listeners - await cls.client.close() - await super()._tearDown_class() + @async_client_context.require_connection async def asyncSetUp(self): await super().asyncSetUp() + self.listener = EventListener() + # We plan to call register(), which internally modifies _LISTENERS. + self.saved_listeners = copy.deepcopy(monitoring._LISTENERS) + monitoring.register(self.listener) + self.client = await self.async_single_client() + # Get one (authenticated) socket in the pool. + await self.client.pymongo_test.command("ping") + + async def asyncTearDown(self) -> None: self.listener.reset() + @classmethod + def tearDownClass(cls): + monitoring._LISTENERS = cls.saved_listeners + async def test_simple(self): await self.client.pymongo_test.command("ping") started = self.listener.started_events[0] diff --git a/test/asynchronous/test_retryable_writes.py b/test/asynchronous/test_retryable_writes.py index accbbd003f..746f23ea48 100644 --- a/test/asynchronous/test_retryable_writes.py +++ b/test/asynchronous/test_retryable_writes.py @@ -133,34 +133,27 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest): RUN_ON_SERVERLESS = True deprecation_filter: DeprecationFilter - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.deprecation_filter = DeprecationFilter() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.deprecation_filter = DeprecationFilter() - @classmethod - async def _tearDown_class(cls): - cls.deprecation_filter.stop() - await super()._tearDown_class() + async def asyncTearDown(self) -> None: + self.deprecation_filter.stop() class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): knobs: client_knobs - @classmethod - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True) - cls.db = cls.client.pymongo_test + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.client = await self.async_rs_or_single_client(retryWrites=True) + self.db = self.client.pymongo_test - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - await cls.client.close() - await super()._tearDown_class() + async def asyncTearDown(self) -> None: + self.knobs.disable() @async_client_context.require_no_standalone async def test_actionable_error_message(self): @@ -181,26 +174,18 @@ class TestRetryableWrites(IgnoreDeprecationsTest): listener: OvertCommandListener knobs: client_knobs - @classmethod @async_client_context.require_no_mmap - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() # Speed up the tests by decreasing the heartbeat frequency. 
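
client_knobs keeps the same enable-in-setup / disable-in-teardown pairing everywhere it appears; a knob left enabled would bleed its 0.1s heartbeat into unrelated tests. A sketch of the pairing, using addCleanup as one alternative to an explicit tearDown (illustrative class, not from this patch):

    import unittest
    from test import client_knobs  # test-suite helper used throughout these diffs

    class FastHeartbeatTest(unittest.TestCase):
        def setUp(self):
            self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1)
            self.knobs.enable()
            # addCleanup runs even if a later setUp step fails, unlike tearDown.
            self.addCleanup(self.knobs.disable)
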
- cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.listener = OvertCommandListener() - cls.client = await cls.unmanaged_async_rs_or_single_client( - retryWrites=True, event_listeners=[cls.listener] + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.listener = OvertCommandListener() + self.client = await self.async_rs_or_single_client( + retryWrites=True, event_listeners=[self.listener] ) - cls.db = cls.client.pymongo_test + self.db = self.client.pymongo_test - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - await cls.client.close() - await super()._tearDown_class() - - async def asyncSetUp(self): if async_client_context.is_rs and async_client_context.test_commands_enabled: await self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")]) @@ -211,6 +196,7 @@ async def asyncTearDown(self): await self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) ) + self.knobs.disable() async def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() @@ -480,13 +466,12 @@ class TestWriteConcernError(AsyncIntegrationTest): RUN_ON_SERVERLESS = True fail_insert: dict - @classmethod @async_client_context.require_replica_set @async_client_context.require_no_mmap @async_client_context.require_failCommand_fail_point - async def _setup_class(cls): - await super()._setup_class() - cls.fail_insert = { + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.fail_insert = { "configureFailPoint": "failCommand", "mode": {"times": 2}, "data": { diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index c1dac6f56d..e424796ce0 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -81,36 +81,27 @@ class TestSession(AsyncIntegrationTest): client2: AsyncMongoClient sensitive_commands: Set[str] - @classmethod @async_client_context.require_sessions - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = await cls.unmanaged_async_rs_or_single_client() + self.client2 = await self.async_rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". 
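
These listeners feed the session bookkeeping in asyncTearDown below, which asserts session-pool hygiene: every lsid that appeared in a command during the test must be back in the client's server-session pool afterwards. Reduced to its core (a sketch; session_ids is the existing test helper):

    def assert_sessions_returned(listener, client, initial_lsids):
        used = set(initial_lsids)
        for event in listener.started_events:
            if "lsid" in event.command:
                used.add(event.command["lsid"]["id"])
        pooled = {s["id"] for s in session_ids(client)}
        assert used <= pooled  # no server session was leaked
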
- cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() + self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() monitoring._SENSITIVE_COMMANDS.clear() - @classmethod - async def _tearDown_class(cls): - monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) - await cls.client2.close() - await super()._tearDown_class() - - async def asyncSetUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() self.client = await self.async_rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) - self.addAsyncCleanup(self.client.close) self.db = self.client.pymongo_test self.initial_lsids = {s["id"] for s in session_ids(self.client)} async def asyncTearDown(self): - """All sessions used in the test must be returned to the pool.""" + monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands) await self.client.drop_database("pymongo_test") used_lsids = self.initial_lsids.copy() for event in self.session_checker_listener.started_events: @@ -120,6 +111,8 @@ async def asyncTearDown(self): current_lsids = {s["id"] for s in session_ids(self.client)} self.assertLessEqual(used_lsids, current_lsids) + await super().asyncTearDown() + async def _test_ops(self, client, *ops): listener = client.options.event_listeners[0] @@ -831,18 +824,11 @@ class TestCausalConsistency(AsyncUnitTest): listener: SessionTestListener client: AsyncMongoClient - @classmethod - async def _setup_class(cls): - cls.listener = SessionTestListener() - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - @async_client_context.require_sessions async def asyncSetUp(self): await super().asyncSetUp() + self.listener = SessionTestListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) @async_client_context.require_no_standalone async def test_core(self): diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 229046e79b..d11d0a9776 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -403,21 +403,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): class TestTransactionsConvenientAPI(AsyncTransactionsBase): - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.mongos_clients = [] + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.mongos_clients = [] if async_client_context.supports_transactions(): for address in async_client_context.mongoses: - cls.mongos_clients.append( - await cls.unmanaged_async_single_client("{}:{}".format(*address)) - ) - - @classmethod - async def _tearDown_class(cls): - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() + self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address))) async def _set_fail_point(self, client, command_args): cmd = {"configureFailPoint": "failCommand"} diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 12cb13c2cd..144647f16f 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -88,30 +88,24 @@ class AsyncSpecRunner(AsyncIntegrationTest): knobs: client_knobs listener: EventListener - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.mongos_clients = [] + async def asyncSetUp(self) -> None: + await 
super().asyncSetUp() + self.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() - - def setUp(self): - super().setUp() + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() self.targets = {} self.listener = None # type: ignore self.pool_listener = None self.server_listener = None self.maxDiff = None + async def asyncTearDown(self) -> None: + self.knobs.disable() + for client in self.mongos_clients: + await client.close() + async def _set_fail_point(self, client, command_args): cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) diff --git a/test/test_bulk.py b/test/test_bulk.py index 64fd48e8cd..ad22c1ce9a 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -42,15 +42,11 @@ class BulkTestBase(IntegrationTest): coll: Collection coll_w0: Collection - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.coll = cls.db.test - cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0)) - def setUp(self): super().setUp() + self.coll = self.db.test self.coll.drop() + self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0)) def assertEqualResponse(self, expected, actual): """Compare response from bulk.execute() to expected response.""" @@ -785,12 +781,8 @@ def test_large_inserts_unordered(self): class BulkAuthorizationTestBase(BulkTestBase): - @classmethod @client_context.require_auth @client_context.require_no_api_version - def _setup_class(cls): - super()._setup_class() - def setUp(self): super().setUp() client_context.create_user(self.db.name, "readonly", "pw", ["read"]) @@ -935,21 +927,19 @@ class TestBulkWriteConcern(BulkTestBase): w: Optional[int] secondary: MongoClient - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.w = client_context.w - cls.secondary = None - if cls.w is not None and cls.w > 1: + def setUp(self): + super().setUp() + self.w = client_context.w + self.secondary = None + if self.w is not None and self.w > 1: for member in (client_context.hello)["hosts"]: if member != (client_context.hello)["primary"]: - cls.secondary = cls.unmanaged_single_client(*partition_node(member)) + self.secondary = self.single_client(*partition_node(member)) break - @classmethod - def async_tearDownClass(cls): - if cls.secondary: - cls.secondary.close() + def tearDown(self): + if self.secondary: + self.secondary.close() def cause_wtimeout(self, requests, ordered): if not client_context.test_commands_enabled: diff --git a/test/test_change_stream.py b/test/test_change_stream.py index dae224c5e0..0742384184 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -819,18 +819,16 @@ def test_split_large_change(self): class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin): dbs: list - @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_change_streams - def _setup_class(cls): - super()._setup_class() - cls.dbs = [cls.db, cls.client.pymongo_test_2] + def setUp(self) -> None: + super().setUp() + self.dbs = [self.db, self.client.pymongo_test_2] - @classmethod - def _tearDown_class(cls): - for db in cls.dbs: - cls.client.drop_database(db) - super()._tearDown_class() + def tearDown(self): + for db in self.dbs: + 
self.client.drop_database(db) + super().tearDown() def change_stream_with_client(self, client, *args, **kwargs): return client.watch(*args, **kwargs) @@ -881,11 +879,10 @@ def test_full_pipeline(self): class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin): - @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_change_streams - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() def change_stream_with_client(self, client, *args, **kwargs): return client[self.db.name].watch(*args, **kwargs) @@ -967,12 +964,9 @@ def test_isolation(self): class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin): - @classmethod @client_context.require_change_streams - def _setup_class(cls): - super()._setup_class() - def setUp(self): + super().setUp() # Use a new collection for each test. self.watched_collection().drop() self.watched_collection().insert_one({}) @@ -1110,20 +1104,11 @@ class TestAllLegacyScenarios(IntegrationTest): RUN_ON_LOAD_BALANCER = True listener: AllowListEventListener - @classmethod @client_context.require_connection - def _setup_class(cls): - super()._setup_class() - cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - def _tearDown_class(cls): - cls.client.close() - super()._tearDown_class() - def setUp(self): super().setUp() + self.listener = AllowListEventListener("aggregate", "getMore") + self.client = self.rs_or_single_client(event_listeners=[self.listener]) self.listener.reset() def setUpCluster(self, scenario_dict): diff --git a/test/test_client.py b/test/test_client.py index a4c521157b..21885b5c51 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -129,13 +129,8 @@ class ClientUnitTest(UnitTest): client: MongoClient - @classmethod - def _setup_class(cls): - cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) - - @classmethod - def _tearDown_class(cls): - cls.client.close() + def setUp(self) -> None: + self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): diff --git a/test/test_collation.py b/test/test_collation.py index e5c1c7eb11..6d4e958a1f 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -97,26 +97,20 @@ class TestCollation(IntegrationTest): warn_context: Any collation: Collation - @classmethod @client_context.require_connection - def _setup_class(cls): - super()._setup_class() - cls.listener = EventListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - cls.db = cls.client.pymongo_test - cls.collation = Collation("en_US") - cls.warn_context = warnings.catch_warnings() - cls.warn_context.__enter__() + def setUp(self) -> None: + super().setUp() + self.listener = EventListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test + self.collation = Collation("en_US") + self.warn_context = warnings.catch_warnings() + self.warn_context.__enter__() warnings.simplefilter("ignore", DeprecationWarning) - @classmethod - def _tearDown_class(cls): - cls.warn_context.__exit__() - cls.warn_context = None - cls.client.close() - super()._tearDown_class() - - def tearDown(self): + def tearDown(self) -> None: + self.warn_context.__exit__() + self.warn_context = None self.listener.reset() super().tearDown() 
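The hunks in this patch all repeat one migration: state built once in _setup_class/setUpClass on unmanaged clients moves into per-test setUp/tearDown built on the managed client helpers, so every test owns its clients and the base class closes them. A minimal sketch of the target shape, assuming the suite's IntegrationTest base, the client_context decorators, and the managed rs_or_single_client helper (the class name, the listener wiring, and the utils import path are illustrative, not taken from any one hunk):

    from test import IntegrationTest, client_context
    from test.utils import OvertCommandListener  # assumed location of the listener


    class TestExample(IntegrationTest):
        @client_context.require_connection
        def setUp(self) -> None:
            super().setUp()
            # Per-test attributes replace the old per-class ones.
            self.listener = OvertCommandListener()
            # Managed helper: the base class tracks this client and closes it
            # during teardown, so no explicit close() is needed here.
            self.client = self.rs_or_single_client(event_listeners=[self.listener])

        def tearDown(self) -> None:
            # Only non-client state still needs explicit cleanup.
            self.listener.reset()
            super().tearDown()

Requirement decorators such as client_context.require_connection now sit directly on setUp, so a skipped test bails out before any per-test resources are created.
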
diff --git a/test/test_collection.py b/test/test_collection.py index f2f01ac686..9364d34e34 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -86,14 +86,10 @@ class TestCollectionNoConnect(UnitTest): db: Database client: MongoClient - @classmethod - def _setup_class(cls): - cls.client = MongoClient(connect=False) - cls.db = cls.client.pymongo_test - - @classmethod - def _tearDown_class(cls): - cls.client.close() + def setUp(self) -> None: + super().setUp() + self.client = self.simple_client(connect=False) + self.db = self.client.pymongo_test def test_collection(self): self.assertRaises(TypeError, Collection, self.db, 5) @@ -163,27 +159,14 @@ def test_iteration(self): class TestCollection(IntegrationTest): w: int - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.w = client_context.w # type: ignore - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine] - else: - asyncio.run(cls.async_tearDownClass()) - - @classmethod - def async_tearDownClass(cls): - cls.db.drop_collection("test_large_limit") - def setUp(self): - self.db.test.drop() + super().setUp() + self.w = client_context.w # type: ignore def tearDown(self): self.db.test.drop() + self.db.drop_collection("test_large_limit") + super().tearDown() @contextlib.contextmanager def write_concern_collection(self): diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 984d700fb3..4387850a00 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -44,30 +44,22 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): listener: CMAPListener coll: Collection - @classmethod + def tearDown(self): + reset_client_context() + @client_context.require_replica_set - def _setup_class(cls): - super()._setup_class() - cls.listener = CMAPListener() - cls.client = cls.unmanaged_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 + def setUp(self): + self.listener = CMAPListener() + self.client = self.rs_or_single_client( + event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500 ) # Ensure connections to all servers in replica set. This is to test # that the is_writable flag is properly updated for connections that # survive a replica set election. - ensure_all_connected(cls.client) - cls.listener.reset() - - cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority")) - cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) - - @classmethod - def _tearDown_class(cls): - cls.client.close() - reset_client_context() - - def setUp(self): + ensure_all_connected(self.client) + self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority")) + self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority")) # Note that all ops use same write-concern as self.db (majority). 
self.db.drop_collection("step-down") self.db.create_collection("step-down") diff --git a/test/test_cursor.py b/test/test_cursor.py index 7c073bf351..f279dc89a3 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1638,10 +1638,6 @@ def test_monitoring(self): class TestRawBatchCommandCursor(IntegrationTest): - @classmethod - def _setup_class(cls): - super()._setup_class() - def test_aggregate_raw(self): c = self.db.test c.drop() diff --git a/test/test_custom_types.py b/test/test_custom_types.py index abaa820cb7..6771ea25f9 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -633,6 +633,7 @@ class MyType(pytype): # type: ignore class TestCollectionWCustomType(IntegrationTest): def setUp(self): + super().setUp() self.db.test.drop() def tearDown(self): @@ -754,6 +755,7 @@ def test_find_one_and__w_custom_type_decoder(self): class TestGridFileCustomType(IntegrationTest): def setUp(self): + super().setUp() self.db.drop_collection("fs.files") self.db.drop_collection("fs.chunks") @@ -917,11 +919,10 @@ def run_test(doc_cls): class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): - @classmethod @client_context.require_change_streams - def setUpClass(cls): - super().setUpClass() - cls.db.test.delete_many({}) + def setUp(self): + super().setUp() + self.db.test.delete_many({}) def tearDown(self): self.input_target.drop() @@ -935,12 +936,11 @@ def create_targets(self, *args, **kwargs): class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): - @classmethod @client_context.require_version_min(4, 0, 0) @client_context.require_change_streams - def setUpClass(cls): - super().setUpClass() - cls.db.test.delete_many({}) + def setUp(self): + super().setUp() + self.db.test.delete_many({}) def tearDown(self): self.input_target.drop() @@ -954,12 +954,11 @@ def create_targets(self, *args, **kwargs): class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): - @classmethod @client_context.require_version_min(4, 0, 0) @client_context.require_change_streams - def setUpClass(cls): - super().setUpClass() - cls.db.test.delete_many({}) + def setUp(self): + super().setUp() + self.db.test.delete_many({}) def tearDown(self): self.input_target.drop() diff --git a/test/test_database.py b/test/test_database.py index 4973ed0134..5e854c941d 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -709,6 +709,7 @@ def test_with_options(self): class TestDatabaseAggregation(IntegrationTest): def setUp(self): + super().setUp() self.pipeline: List[Mapping[str, Any]] = [ {"$listLocalSessions": {}}, {"$limit": 1}, diff --git a/test/test_encryption.py b/test/test_encryption.py index 43c85e2c5b..428b787f0a 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -29,7 +29,6 @@ import uuid import warnings from test import IntegrationTest, PyMongoTestCase, client_context -from test.test_bulk import BulkTestBase from threading import Thread from typing import Any, Dict, Mapping, Optional @@ -211,11 +210,10 @@ def test_kwargs(self): class EncryptionIntegrationTest(IntegrationTest): """Base class for encryption integration tests.""" - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @client_context.require_version_min(4, 2, -1) - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() def assertEncrypted(self, val): self.assertIsInstance(val, Binary) @@ -430,10 +428,9 @@ def 
test_upsert_uuid_standard_encrypt(self): class TestClientMaxWireVersion(IntegrationTest): - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() @client_context.require_version_max(4, 0, 99) def test_raise_max_wire_version_error(self): @@ -814,17 +811,16 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): "local": None, } - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - def _setup_class(cls): - super()._setup_class() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - cls.client.db.coll.drop() - cls.vault = create_key_vault(cls.client.keyvault.datakeys) + def setUp(self): + super().setUp() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) + self.client.db.coll.drop() + self.vault = create_key_vault(self.client.keyvault.datakeys) # Configure the encrypted field via the local schema_map option. schemas = { @@ -842,25 +838,22 @@ def _setup_class(cls): } } opts = AutoEncryptionOpts( - cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS + self.KMS_PROVIDERS, + "keyvault.datakeys", + schema_map=schemas, + kms_tls_options=KMS_TLS_OPTS, ) - cls.client_encrypted = cls.unmanaged_rs_or_single_client( + self.client_encrypted = self.rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = cls.unmanaged_create_client_encryption( - cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS + self.client_encryption = self.create_client_encryption( + self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) - - @classmethod - def _tearDown_class(cls): - cls.vault.drop() - cls.client.close() - cls.client_encrypted.close() - cls.client_encryption.close() - - def setUp(self): self.listener.reset() + def tearDown(self) -> None: + self.vault.drop() + def run_test(self, provider_name): # Create data key. master_key: Any = self.MASTER_KEYS[provider_name] @@ -1005,10 +998,9 @@ def test_views_are_prohibited(self): class TestCorpus(EncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() @staticmethod def kms_providers(): @@ -1182,12 +1174,11 @@ class TestBsonSizeBatches(EncryptionIntegrationTest): client_encrypted: MongoClient listener: OvertCommandListener - @classmethod - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() db = client_context.client.db - cls.coll = db.coll - cls.coll.drop() + self.coll = db.coll + self.coll.drop() # Configure the encrypted 'db.coll' collection via jsonSchema. 
json_schema = json_data("limits", "limits-schema.json") db.create_collection( @@ -1205,17 +1196,14 @@ def _setup_class(cls): coll.insert_one(json_data("limits", "limits-key.json")) opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") - cls.listener = OvertCommandListener() - cls.client_encrypted = cls.unmanaged_rs_or_single_client( - auto_encryption_opts=opts, event_listeners=[cls.listener] + self.listener = OvertCommandListener() + self.client_encrypted = self.rs_or_single_client( + auto_encryption_opts=opts, event_listeners=[self.listener] ) - cls.coll_encrypted = cls.client_encrypted.db.coll + self.coll_encrypted = self.client_encrypted.db.coll - @classmethod - def _tearDown_class(cls): - cls.coll_encrypted.drop() - cls.client_encrypted.close() - super()._tearDown_class() + def tearDown(self) -> None: + self.coll_encrypted.drop() def test_01_insert_succeeds_under_2MiB(self): doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB} @@ -1279,15 +1267,12 @@ def test_06_insert_fails_over_16MiB(self): class TestCustomEndpoint(EncryptionIntegrationTest): """Prose tests for creating data keys with a custom endpoint.""" - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - def _setup_class(cls): - super()._setup_class() - def setUp(self): + super().setUp() kms_providers = { "aws": AWS_CREDS, "azure": AZURE_CREDS, @@ -1316,10 +1301,6 @@ def setUp(self): self._kmip_host_error = None self._invalid_host_error = None - def tearDown(self): - self.client_encryption.close() - self.client_encryption_invalid.close() - def run_test_expected_success(self, provider_name, master_key): data_key_id = self.client_encryption.create_data_key(provider_name, master_key=master_key) encrypted = self.client_encryption.encrypt( @@ -1493,6 +1474,7 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest): client: MongoClient def setUp(self): + self.client = self.simple_client() keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL) create_key_vault(keyvault, self.DEK) @@ -1551,13 +1533,12 @@ def _test_automatic(self, expectation_extjson, payload): class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") - def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} - cls.DEK = json_data(BASE, "custom", "azure-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - super()._setup_class() + def setUp(self): + self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} + self.DEK = json_data(BASE, "custom", "azure-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + super().setUp() def test_explicit(self): return self._test_explicit( @@ -1577,13 +1558,12 @@ def test_automatic(self): class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") - def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} - cls.DEK = json_data(BASE, "custom", "gcp-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - super()._setup_class() + def setUp(self): + self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} + self.DEK = json_data(BASE, "custom", "gcp-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", 
"azure-gcp-schema.json") + super().setUp() def test_explicit(self): return self._test_explicit( @@ -3069,17 +3049,11 @@ class TestNoSessionsSupport(EncryptionIntegrationTest): mongocryptd_client: MongoClient MONGOCRYPTD_PORT = 27020 - @classmethod @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") - def _setup_class(cls): - super()._setup_class() - start_mongocryptd(cls.MONGOCRYPTD_PORT) - - @classmethod - def _tearDown_class(cls): - super()._tearDown_class() - def setUp(self) -> None: + super().setUp() + start_mongocryptd(self.MONGOCRYPTD_PORT) + self.listener = OvertCommandListener() self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] diff --git a/test/test_grid_file.py b/test/test_grid_file.py index fe88aec5ff..0a5b1ad40a 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -97,6 +97,7 @@ def test_grid_in_custom_opts(self): class TestGridFile(IntegrationTest): def setUp(self): + super().setUp() self.cleanup_colls(self.db.fs.files, self.db.fs.chunks) def test_basic(self): diff --git a/test/test_monitoring.py b/test/test_monitoring.py index a0c520ed27..31f546fe54 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -51,22 +51,14 @@ class TestCommandMonitoring(IntegrationTest): listener: EventListener @classmethod - @client_context.require_connection - def _setup_class(cls): - super()._setup_class() + def setUpClass(cls) -> None: cls.listener = EventListener() - cls.client = cls.unmanaged_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False - ) - @classmethod - def _tearDown_class(cls): - cls.client.close() - super()._tearDown_class() - - def tearDown(self): + @client_context.require_connection + def setUp(self) -> None: + super().setUp() self.listener.reset() - super().tearDown() + self.client = self.rs_or_single_client(event_listeners=[self.listener], retryWrites=False) def test_started_simple(self): self.client.pymongo_test.command("ping") @@ -1137,27 +1129,30 @@ class TestGlobalListener(IntegrationTest): saved_listeners: Any @classmethod - @client_context.require_connection - def _setup_class(cls): - super()._setup_class() + def setUpClass(cls) -> None: cls.listener = EventListener() # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = cls.unmanaged_single_client() - # Get one (authenticated) socket in the pool. - cls.client.pymongo_test.command("ping") - - @classmethod - def _tearDown_class(cls): - monitoring._LISTENERS = cls.saved_listeners - cls.client.close() - super()._tearDown_class() + @client_context.require_connection def setUp(self): super().setUp() + self.listener = EventListener() + # We plan to call register(), which internally modifies _LISTENERS. + self.saved_listeners = copy.deepcopy(monitoring._LISTENERS) + monitoring.register(self.listener) + self.client = self.single_client() + # Get one (authenticated) socket in the pool. 
+ self.client.pymongo_test.command("ping") + + def tearDown(self) -> None: self.listener.reset() + @classmethod + def tearDownClass(cls): + monitoring._LISTENERS = cls.saved_listeners + def test_simple(self): self.client.pymongo_test.command("ping") started = self.listener.started_events[0] diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 5df6c41f7a..eb814c4ef9 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -133,34 +133,27 @@ class IgnoreDeprecationsTest(IntegrationTest): RUN_ON_SERVERLESS = True deprecation_filter: DeprecationFilter - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.deprecation_filter = DeprecationFilter() + def setUp(self) -> None: + super().setUp() + self.deprecation_filter = DeprecationFilter() - @classmethod - def _tearDown_class(cls): - cls.deprecation_filter.stop() - super()._tearDown_class() + def tearDown(self) -> None: + self.deprecation_filter.stop() class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): knobs: client_knobs - @classmethod - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True) - cls.db = cls.client.pymongo_test + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.client = self.rs_or_single_client(retryWrites=True) + self.db = self.client.pymongo_test - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - cls.client.close() - super()._tearDown_class() + def tearDown(self) -> None: + self.knobs.disable() @client_context.require_no_standalone def test_actionable_error_message(self): @@ -181,26 +174,16 @@ class TestRetryableWrites(IgnoreDeprecationsTest): listener: OvertCommandListener knobs: client_knobs - @classmethod @client_context.require_no_mmap - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() # Speed up the tests by decreasing the heartbeat frequency. 
- cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client( - retryWrites=True, event_listeners=[cls.listener] - ) - cls.db = cls.client.pymongo_test - - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - cls.client.close() - super()._tearDown_class() + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(retryWrites=True, event_listeners=[self.listener]) + self.db = self.client.pymongo_test - def setUp(self): if client_context.is_rs and client_context.test_commands_enabled: self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")]) @@ -211,6 +194,7 @@ def tearDown(self): self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) ) + self.knobs.disable() def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() @@ -480,13 +464,12 @@ class TestWriteConcernError(IntegrationTest): RUN_ON_SERVERLESS = True fail_insert: dict - @classmethod @client_context.require_replica_set @client_context.require_no_mmap @client_context.require_failCommand_fail_point - def _setup_class(cls): - super()._setup_class() - cls.fail_insert = { + def setUp(self) -> None: + super().setUp() + self.fail_insert = { "configureFailPoint": "failCommand", "mode": {"times": 2}, "data": { diff --git a/test/test_session.py b/test/test_session.py index 9f94ded927..980d9df688 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -81,36 +81,27 @@ class TestSession(IntegrationTest): client2: MongoClient sensitive_commands: Set[str] - @classmethod @client_context.require_sessions - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = cls.unmanaged_rs_or_single_client() + self.client2 = self.rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". 
- cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() + self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() monitoring._SENSITIVE_COMMANDS.clear() - @classmethod - def _tearDown_class(cls): - monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) - cls.client2.close() - super()._tearDown_class() - - def setUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() self.client = self.rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) - self.addCleanup(self.client.close) self.db = self.client.pymongo_test self.initial_lsids = {s["id"] for s in session_ids(self.client)} def tearDown(self): - """All sessions used in the test must be returned to the pool.""" + monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands) self.client.drop_database("pymongo_test") used_lsids = self.initial_lsids.copy() for event in self.session_checker_listener.started_events: @@ -120,6 +111,8 @@ def tearDown(self): current_lsids = {s["id"] for s in session_ids(self.client)} self.assertLessEqual(used_lsids, current_lsids) + super().tearDown() + def _test_ops(self, client, *ops): listener = client.options.event_listeners[0] @@ -831,18 +824,11 @@ class TestCausalConsistency(UnitTest): listener: SessionTestListener client: MongoClient - @classmethod - def _setup_class(cls): - cls.listener = SessionTestListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - def _tearDown_class(cls): - cls.client.close() - @client_context.require_sessions def setUp(self): super().setUp() + self.listener = SessionTestListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) @client_context.require_no_standalone def test_core(self): diff --git a/test/test_transactions.py b/test/test_transactions.py index 3cecbe9d38..949b88e60b 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -395,19 +395,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): class TestTransactionsConvenientAPI(TransactionsBase): - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.mongos_clients = [] + def setUp(self) -> None: + super().setUp() + self.mongos_clients = [] if client_context.supports_transactions(): for address in client_context.mongoses: - cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) - - @classmethod - def _tearDown_class(cls): - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() + self.mongos_clients.append(self.single_client("{}:{}".format(*address))) def _set_fail_point(self, client, command_args): cmd = {"configureFailPoint": "failCommand"} diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 06a40351cd..7cc30ba017 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -88,30 +88,24 @@ class SpecRunner(IntegrationTest): knobs: client_knobs listener: EventListener - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.mongos_clients = [] + def setUp(self) -> None: + super().setUp() + self.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. 
- cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() - - def setUp(self): - super().setUp() + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() self.targets = {} self.listener = None # type: ignore self.pool_listener = None self.server_listener = None self.maxDiff = None + def tearDown(self) -> None: + self.knobs.disable() + for client in self.mongos_clients: + client.close() + def _set_fail_point(self, client, command_args): cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) From a911245bde1377c485f06dfd5373d159b7e8aff7 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 14 Oct 2024 15:06:42 -0700 Subject: [PATCH 13/19] PYTHON-4866 Fix test_command_cursor_to_list_csot_applied (#1926) --- test/asynchronous/test_cursor.py | 14 ++++++-------- test/test_cursor.py | 14 ++++++-------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index e79ad00641..ee0a757ed3 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1412,12 +1412,11 @@ async def test_to_list_length(self): self.assertEqual(len(docs), 2) async def test_to_list_csot_applied(self): - client = await self.async_single_client(timeoutMS=500) + client = await self.async_single_client(timeoutMS=500, w=1) + coll = client.pymongo.test # Initialize the client with a larger timeout to help make test less flakey with pymongo.timeout(10): - await client.admin.command("ping") - coll = client.pymongo.test - await coll.insert_many([{} for _ in range(5)]) + await coll.insert_many([{} for _ in range(5)]) cursor = coll.find({"$where": delay(1)}) with self.assertRaises(PyMongoError) as ctx: await cursor.to_list() @@ -1454,12 +1453,11 @@ async def test_command_cursor_to_list_length(self): @async_client_context.require_failCommand_blockConnection async def test_command_cursor_to_list_csot_applied(self): - client = await self.async_single_client(timeoutMS=500) + client = await self.async_single_client(timeoutMS=500, w=1) + coll = client.pymongo.test # Initialize the client with a larger timeout to help make test less flakey with pymongo.timeout(10): - await client.admin.command("ping") - coll = client.pymongo.test - await coll.insert_many([{} for _ in range(5)]) + await coll.insert_many([{} for _ in range(5)]) fail_command = { "configureFailPoint": "failCommand", "mode": {"times": 5}, diff --git a/test/test_cursor.py b/test/test_cursor.py index 7c073bf351..7a6dfc9429 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1403,12 +1403,11 @@ def test_to_list_length(self): self.assertEqual(len(docs), 2) def test_to_list_csot_applied(self): - client = self.single_client(timeoutMS=500) + client = self.single_client(timeoutMS=500, w=1) + coll = client.pymongo.test # Initialize the client with a larger timeout to help make test less flakey with pymongo.timeout(10): - client.admin.command("ping") - coll = client.pymongo.test - coll.insert_many([{} for _ in range(5)]) + coll.insert_many([{} for _ in range(5)]) cursor = coll.find({"$where": delay(1)}) with self.assertRaises(PyMongoError) as ctx: cursor.to_list() @@ -1445,12 +1444,11 @@ def test_command_cursor_to_list_length(self): @client_context.require_failCommand_blockConnection def 
test_command_cursor_to_list_csot_applied(self): - client = self.single_client(timeoutMS=500) + client = self.single_client(timeoutMS=500, w=1) + coll = client.pymongo.test # Initialize the client with a larger timeout to help make test less flakey with pymongo.timeout(10): - client.admin.command("ping") - coll = client.pymongo.test - coll.insert_many([{} for _ in range(5)]) + coll.insert_many([{} for _ in range(5)]) fail_command = { "configureFailPoint": "failCommand", "mode": {"times": 5}, From 9e38c54fa03d0f719a43ff023894c2a1ad9b5480 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Mon, 14 Oct 2024 15:25:21 -0700 Subject: [PATCH 14/19] PYTHON-4861 Fix HATCH_CONFIG on cygwin (#1927) --- .evergreen/hatch.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.evergreen/hatch.sh b/.evergreen/hatch.sh index 8f862c39d2..6f3d36b389 100644 --- a/.evergreen/hatch.sh +++ b/.evergreen/hatch.sh @@ -29,6 +29,9 @@ else # Set up virtualenv before installing hatch # Ensure hatch does not write to user or global locations. touch hatch_config.toml HATCH_CONFIG=$(pwd)/hatch_config.toml + if [ "Windows_NT" = "$OS" ]; then # Magic variable in cygwin + HATCH_CONFIG=$(cygpath -m "$HATCH_CONFIG") + fi export HATCH_CONFIG hatch config restore hatch config set dirs.data ".hatch/data" From 872fda179e247fb8e1bcc3cf2af3d892788a2e2f Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 15 Oct 2024 08:54:42 -0400 Subject: [PATCH 15/19] PYTHON-4574 - FaaS detection logic mistakenly identifies EKS as AWS Lambda (#1908) --- test/asynchronous/test_client.py | 16 ++++++++++++++++ test/test_client.py | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index faa23348c9..c6b6416c16 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -2019,6 +2019,22 @@ async def test_handshake_08_invalid_aws_ec2(self): None, ) + async def test_handshake_09_container_with_provider(self): + await self._test_handshake( + { + ENV_VAR_K8S: "1", + "AWS_LAMBDA_RUNTIME_API": "1", + "AWS_REGION": "us-east-1", + "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "256", + }, + { + "container": {"orchestrator": "kubernetes"}, + "name": "aws.lambda", + "region": "us-east-1", + "memory_mb": 256, + }, + ) + def test_dict_hints(self): self.db.t.find(hint={"x": 1}) diff --git a/test/test_client.py b/test/test_client.py index be1994dd93..8e3d9c8b8b 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1977,6 +1977,22 @@ def test_handshake_08_invalid_aws_ec2(self): None, ) + def test_handshake_09_container_with_provider(self): + self._test_handshake( + { + ENV_VAR_K8S: "1", + "AWS_LAMBDA_RUNTIME_API": "1", + "AWS_REGION": "us-east-1", + "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "256", + }, + { + "container": {"orchestrator": "kubernetes"}, + "name": "aws.lambda", + "region": "us-east-1", + "memory_mb": 256, + }, + ) + def test_dict_hints(self): self.db.t.find(hint={"x": 1}) From 8cfd49ae48298a3247e6c37b8d4ec906d8a52106 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 15 Oct 2024 11:01:12 -0400 Subject: [PATCH 16/19] All tests passing --- test/test_examples.py | 13 ++++--------- test/test_gridfs.py | 20 ++++++++------------ test/test_gridfs_bucket.py | 14 +++++--------- test/test_read_concern.py | 20 ++++++-------------- test/test_sdam_monitoring_spec.py | 2 +- test/test_threads.py | 1 + test/test_typing.py | 7 +++---- 7 files changed, 28 insertions(+), 49 deletions(-) diff --git a/test/test_examples.py b/test/test_examples.py index 
ebf1d784a3..7f98226e7a 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -33,19 +33,14 @@ class TestSampleShellCommands(IntegrationTest): - @classmethod - def setUpClass(cls): - super().setUpClass() - # Run once before any tests run. - cls.db.inventory.drop() - - @classmethod - def tearDownClass(cls): - cls.client.drop_database("pymongo_test") + def setUp(self): + super().setUp() + self.db.inventory.drop() def tearDown(self): # Run after every test. self.db.inventory.drop() + self.client.drop_database("pymongo_test") def test_first_three_examples(self): db = self.db diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 549dc0b204..a36109f399 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -75,9 +75,9 @@ def run(self): class TestGridfsNoConnect(unittest.TestCase): db: Database - @classmethod - def setUpClass(cls): - cls.db = MongoClient(connect=False).pymongo_test + def setUp(self): + super().setUp() + self.db = MongoClient(connect=False).pymongo_test def test_gridfs(self): self.assertRaises(TypeError, gridfs.GridFS, "foo") @@ -88,13 +88,10 @@ class TestGridfs(IntegrationTest): fs: gridfs.GridFS alt: gridfs.GridFS - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.fs = gridfs.GridFS(cls.db) - cls.alt = gridfs.GridFS(cls.db, "alt") - def setUp(self): + super().setUp() + self.fs = gridfs.GridFS(self.db) + self.alt = gridfs.GridFS(self.db, "alt") self.cleanup_colls( self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks ) @@ -509,10 +506,9 @@ def test_md5(self): class TestGridfsReplicaSet(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() + def setUp(self): + super().setUp() @classmethod def tearDownClass(cls): diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 28adb7051a..04c7427350 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -79,13 +79,10 @@ class TestGridfs(IntegrationTest): fs: gridfs.GridFSBucket alt: gridfs.GridFSBucket - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.fs = gridfs.GridFSBucket(cls.db) - cls.alt = gridfs.GridFSBucket(cls.db, bucket_name="alt") - def setUp(self): + super().setUp() + self.fs = gridfs.GridFSBucket(self.db) + self.alt = gridfs.GridFSBucket(self.db, bucket_name="alt") self.cleanup_colls( self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks ) @@ -479,10 +476,9 @@ def test_md5(self): class TestGridfsBucketReplicaSet(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() + def setUp(self): + super().setUp() @classmethod def tearDownClass(cls): diff --git a/test/test_read_concern.py b/test/test_read_concern.py index ea9ce49a30..f7c0901422 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -31,24 +31,16 @@ class TestReadConcern(IntegrationTest): listener: OvertCommandListener - @classmethod @client_context.require_connection - def setUpClass(cls): - super().setUpClass() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - cls.db = cls.client.pymongo_test + def setUp(self): + super().setUp() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test client_context.client.pymongo_test.create_collection("coll") - @classmethod - def tearDownClass(cls): - cls.client.close() - 
client_context.client.pymongo_test.drop_collection("coll") - super().tearDownClass() - def tearDown(self): - self.listener.reset() - super().tearDown() + client_context.client.pymongo_test.drop_collection("coll") def test_read_concern(self): rc = ReadConcern() diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index 81b208d511..6b808b159d 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -270,7 +270,7 @@ class TestSdamMonitoring(IntegrationTest): @classmethod @client_context.require_failCommand_fail_point def setUpClass(cls): - super().setUpClass() + super().setUp(cls) # Speed up the tests by decreasing the event publish frequency. cls.knobs = client_knobs( events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1 diff --git a/test/test_threads.py b/test/test_threads.py index b3dadbb1a3..3e469e28fe 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -105,6 +105,7 @@ def run(self): class TestThreads(IntegrationTest): def setUp(self): + super().setUp() self.db = self.client.pymongo_test def test_threading(self): diff --git a/test/test_typing.py b/test/test_typing.py index 441707616e..bfe4d032c1 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -114,10 +114,9 @@ def test_mypy_failures(self) -> None: class TestPymongo(IntegrationTest): coll: Collection - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.coll = cls.client.test.test + def setUp(self): + super().setUp() + self.coll = self.client.test.test def test_insert_find(self) -> None: doc = {"my": "doc"} From 710bc40c730d2fd982e1cb7a41fd91ac7b5d4498 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 15 Oct 2024 12:12:18 -0400 Subject: [PATCH 17/19] =?UTF-8?q?PYTHON-4870=20-=20MongoClient.address=20s?= =?UTF-8?q?hould=20block=20until=20a=20connection=20suc=E2=80=A6=20(#1929)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pymongo/asynchronous/mongo_client.py | 7 ------- pymongo/synchronous/mongo_client.py | 7 ------- test/asynchronous/test_client.py | 2 -- test/test_client.py | 2 -- test/test_replica_set_reconfig.py | 3 ++- 5 files changed, 2 insertions(+), 19 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index bfae302dac..4e09efe401 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1453,13 +1453,6 @@ async def address(self) -> Optional[tuple[str, int]]: 'Cannot use "address" property when load balancing among' ' mongoses, use "nodes" instead.' ) - if topology_type not in ( - TOPOLOGY_TYPE.ReplicaSetWithPrimary, - TOPOLOGY_TYPE.Single, - TOPOLOGY_TYPE.LoadBalanced, - TOPOLOGY_TYPE.Sharded, - ): - return None return await self._server_property("address") @property diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 1351cb200f..815446bb2c 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1447,13 +1447,6 @@ def address(self) -> Optional[tuple[str, int]]: 'Cannot use "address" property when load balancing among' ' mongoses, use "nodes" instead.' 
) - if topology_type not in ( - TOPOLOGY_TYPE.ReplicaSetWithPrimary, - TOPOLOGY_TYPE.Single, - TOPOLOGY_TYPE.LoadBalanced, - TOPOLOGY_TYPE.Sharded, - ): - return None return self._server_property("address") @property diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index c6b6416c16..590154b857 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -838,8 +838,6 @@ async def test_init_disconnected(self): c = await self.async_rs_or_single_client(connect=False) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) - self.assertIsNone(await c.address) # PYTHON-2981 - await c.admin.command("ping") # connect if async_client_context.is_rs: # The primary's host and port are from the replica set config. self.assertIsNotNone(await c.address) diff --git a/test/test_client.py b/test/test_client.py index 8e3d9c8b8b..5bbb5bd751 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -812,8 +812,6 @@ def test_init_disconnected(self): c = self.rs_or_single_client(connect=False) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) - self.assertIsNone(c.address) # PYTHON-2981 - c.admin.command("ping") # connect if client_context.is_rs: # The primary's host and port are from the replica set config. self.assertIsNotNone(c.address) diff --git a/test/test_replica_set_reconfig.py b/test/test_replica_set_reconfig.py index 1dae0aea86..4c23d71b69 100644 --- a/test/test_replica_set_reconfig.py +++ b/test/test_replica_set_reconfig.py @@ -59,7 +59,8 @@ def test_client(self): with self.assertRaises(ServerSelectionTimeoutError): c.db.command("ping") - self.assertEqual(c.address, None) + with self.assertRaises(ServerSelectionTimeoutError): + _ = c.address # Client can still discover the primary node c.revive_host("a:1") From 7d5688e36080e6cec2ad70c5d4ee046b35b3e507 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 15 Oct 2024 17:05:24 -0400 Subject: [PATCH 18/19] Unified test runner setup fixes --- test/__init__.py | 1 - test/asynchronous/__init__.py | 1 - test/asynchronous/unified_format.py | 41 ++++++++++++----------------- test/unified_format.py | 39 ++++++++++++--------------- 4 files changed, 34 insertions(+), 48 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index e8dc4bcd8c..940518c2c5 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1173,7 +1173,6 @@ def setUp(self) -> None: super().setUp() self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) - self.client_knobs.enable() def tearDown(self) -> None: diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index ce786966f9..8d1e3e1911 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -1191,7 +1191,6 @@ async def asyncSetUp(self) -> None: await super().asyncSetUp() self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) - self.client_knobs.enable() async def asyncTearDown(self) -> None: diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 2ff38f06e9..9b9282a902 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -478,55 +478,42 @@ async def insert_initial_data(self, initial_data): # Ensure collection exists await db.create_collection(coll_name, write_concern=wc, **opts) - @classmethod - async def 
_setup_class(cls): + async def asyncSetUp(self): # super call creates internal client cls.client - await super()._setup_class() + await super().asyncSetUp() # process file-level runOnRequirements - run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) - if not await cls.should_run_on(run_on_spec): - raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied") + run_on_spec = self.TEST_SPEC.get("runOnRequirements", []) + if not await self.should_run_on(run_on_spec): + raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied") # add any special-casing for skipping tests here if async_client_context.storage_engine == "mmapv1": - if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str( - cls.TEST_PATH + if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str( + self.TEST_PATH ): raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") # Handle mongos_clients for transactions tests. - cls.mongos_clients = [] + self.mongos_clients = [] if ( async_client_context.supports_transactions() and not async_client_context.load_balancer and not async_client_context.serverless ): for address in async_client_context.mongoses: - cls.mongos_clients.append( - await cls.unmanaged_async_single_client("{}:{}".format(*address)) - ) + self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address))) # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs( + self.knobs = client_knobs( heartbeat_frequency=0.1, min_heartbeat_interval=0.1, kill_cursor_frequency=0.1, events_queue_frequency=0.1, ) - cls.knobs.enable() - - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() + self.knobs.enable() - async def asyncSetUp(self): - await super().asyncSetUp() # process schemaVersion # note: we check major schema version during class generation - # note: we do this here because we cannot run assertions in setUpClass version = Version.from_string(self.TEST_SPEC["schemaVersion"]) self.assertLessEqual( version, @@ -537,6 +524,12 @@ async def asyncSetUp(self): # initialize internals self.match_evaluator = MatchEvaluatorUtil(self) + async def asyncTearDown(self): + self.knobs.disable() + for client in self.mongos_clients: + await client.close() + await super().asyncTearDown() + def maybe_skip_test(self, spec): # add any special-casing for skipping tests here if async_client_context.storage_engine == "mmapv1": diff --git a/test/unified_format.py b/test/unified_format.py index 13ab0af69b..51630ab16a 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -477,53 +477,42 @@ def insert_initial_data(self, initial_data): # Ensure collection exists db.create_collection(coll_name, write_concern=wc, **opts) - @classmethod - def _setup_class(cls): + def setUp(self): # super call creates internal client cls.client - super()._setup_class() + super().setUp() # process file-level runOnRequirements - run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) - if not cls.should_run_on(run_on_spec): - raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied") + run_on_spec = self.TEST_SPEC.get("runOnRequirements", []) + if not self.should_run_on(run_on_spec): + raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied") # add any special-casing for skipping tests here if client_context.storage_engine == "mmapv1": - if 
"retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str( - cls.TEST_PATH + if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str( + self.TEST_PATH ): raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") # Handle mongos_clients for transactions tests. - cls.mongos_clients = [] + self.mongos_clients = [] if ( client_context.supports_transactions() and not client_context.load_balancer and not client_context.serverless ): for address in client_context.mongoses: - cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) + self.mongos_clients.append(self.single_client("{}:{}".format(*address))) # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs( + self.knobs = client_knobs( heartbeat_frequency=0.1, min_heartbeat_interval=0.1, kill_cursor_frequency=0.1, events_queue_frequency=0.1, ) - cls.knobs.enable() + self.knobs.enable() - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() - - def setUp(self): - super().setUp() # process schemaVersion # note: we check major schema version during class generation - # note: we do this here because we cannot run assertions in setUpClass version = Version.from_string(self.TEST_SPEC["schemaVersion"]) self.assertLessEqual( version, @@ -534,6 +523,12 @@ def setUp(self): # initialize internals self.match_evaluator = MatchEvaluatorUtil(self) + def tearDown(self): + self.knobs.disable() + for client in self.mongos_clients: + client.close() + super().tearDown() + def maybe_skip_test(self, spec): # add any special-casing for skipping tests here if client_context.storage_engine == "mmapv1": From 85df5b675f7fb09110b2857b8ddf0bbcb8c574f5 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 16 Oct 2024 10:19:08 -0400 Subject: [PATCH 19/19] Fix unified test client_knobs --- test/asynchronous/unified_format.py | 25 +++++++++++++++---------- test/unified_format.py | 25 +++++++++++++++---------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 9b9282a902..f25e96e04d 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -478,6 +478,21 @@ async def insert_initial_data(self, initial_data): # Ensure collection exists await db.create_collection(coll_name, write_concern=wc, **opts) + @classmethod + def setUpClass(cls) -> None: + # Speed up the tests by decreasing the heartbeat frequency. + cls.knobs = client_knobs( + heartbeat_frequency=0.1, + min_heartbeat_interval=0.1, + kill_cursor_frequency=0.1, + events_queue_frequency=0.1, + ) + cls.knobs.enable() + + @classmethod + def tearDownClass(cls) -> None: + cls.knobs.disable() + async def asyncSetUp(self): # super call creates internal client cls.client await super().asyncSetUp() @@ -503,15 +518,6 @@ async def asyncSetUp(self): for address in async_client_context.mongoses: self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address))) - # Speed up the tests by decreasing the heartbeat frequency. 
- self.knobs = client_knobs( - heartbeat_frequency=0.1, - min_heartbeat_interval=0.1, - kill_cursor_frequency=0.1, - events_queue_frequency=0.1, - ) - self.knobs.enable() - # process schemaVersion # note: we check major schema version during class generation version = Version.from_string(self.TEST_SPEC["schemaVersion"]) @@ -525,7 +531,6 @@ async def asyncSetUp(self): self.match_evaluator = MatchEvaluatorUtil(self) async def asyncTearDown(self): - self.knobs.disable() for client in self.mongos_clients: await client.close() await super().asyncTearDown() diff --git a/test/unified_format.py b/test/unified_format.py index 51630ab16a..7d5c4e4e03 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -477,6 +477,21 @@ def insert_initial_data(self, initial_data): # Ensure collection exists db.create_collection(coll_name, write_concern=wc, **opts) + @classmethod + def setUpClass(cls) -> None: + # Speed up the tests by decreasing the heartbeat frequency. + cls.knobs = client_knobs( + heartbeat_frequency=0.1, + min_heartbeat_interval=0.1, + kill_cursor_frequency=0.1, + events_queue_frequency=0.1, + ) + cls.knobs.enable() + + @classmethod + def tearDownClass(cls) -> None: + cls.knobs.disable() + def setUp(self): # super call creates internal client cls.client super().setUp() @@ -502,15 +517,6 @@ def setUp(self): for address in client_context.mongoses: self.mongos_clients.append(self.single_client("{}:{}".format(*address))) - # Speed up the tests by decreasing the heartbeat frequency. - self.knobs = client_knobs( - heartbeat_frequency=0.1, - min_heartbeat_interval=0.1, - kill_cursor_frequency=0.1, - events_queue_frequency=0.1, - ) - self.knobs.enable() - # process schemaVersion # note: we check major schema version during class generation version = Version.from_string(self.TEST_SPEC["schemaVersion"]) @@ -524,7 +530,6 @@ def setUp(self): self.match_evaluator = MatchEvaluatorUtil(self) def tearDown(self): - self.knobs.disable() for client in self.mongos_clients: client.close() super().tearDown()
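
For reference, client_knobs is an enable()/disable() pair that swaps in faster timing values process-wide, which is why the final commit hoists it from asyncSetUp/setUp into setUpClass/tearDownClass: the knobs have to outlive every per-test client yet still be unwound before the next test class runs. A minimal standalone sketch, assuming the suite's client_knobs helper (the test class itself is illustrative; the frequency values are the ones used above):

    import unittest

    from test import client_knobs


    class TestWithFastTimings(unittest.TestCase):
        @classmethod
        def setUpClass(cls) -> None:
            # Enabled once for the whole class: every client created by any
            # test in this class sees the faster frequencies.
            cls.knobs = client_knobs(
                heartbeat_frequency=0.1,
                min_heartbeat_interval=0.1,
                kill_cursor_frequency=0.1,
                events_queue_frequency=0.1,
            )
            cls.knobs.enable()

        @classmethod
        def tearDownClass(cls) -> None:
            # Mirror enable(); without this the patched timings would leak
            # into whatever test class runs next.
            cls.knobs.disable()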