diff --git a/test/__init__.py b/test/__init__.py index c55eb74c9d..940518c2c5 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1115,26 +1115,10 @@ def enable_replication(self, client): class UnitTest(PyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod - def _setup_class(cls): + def setUp(self) -> None: pass - @classmethod - def _tearDown_class(cls): + def tearDown(self) -> None: pass @@ -1145,37 +1129,20 @@ class IntegrationTest(PyMongoTestCase): db: Database credentials: Dict[str, str] - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @client_context.require_connection - def _setup_class(cls): - if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): + def setUp(self) -> None: + if not _IS_SYNC: + reset_client_context() + if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") - if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): + if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): raise SkipTest("this test does not support serverless") - cls.client = client_context.client - cls.db = cls.client.pymongo_test + self.client = client_context.client + self.db = self.client.pymongo_test if client_context.auth_enabled: - cls.credentials = {"username": db_user, "password": db_pwd} + self.credentials = {"username": db_user, "password": db_pwd} else: - cls.credentials = {} - - @classmethod - def _tearDown_class(cls): - pass + self.credentials = {} def cleanup_colls(self, *collections): """Cleanup collections faster than drop_collection.""" @@ -1201,37 +1168,14 @@ class MockClientTest(UnitTest): # MockClients tests that use replicaSet, directConnection=True, pass # multiple seed addresses, or wait for heartbeat events are incompatible # with loadBalanced=True. - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @client_context.require_no_load_balancer - def _setup_class(cls): - pass - - @classmethod - def _tearDown_class(cls): - pass - - def setUp(self): + def setUp(self) -> None: super().setUp() self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) - self.client_knobs.enable() - def tearDown(self): + def tearDown(self) -> None: self.client_knobs.disable() super().tearDown() diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 58e69c7c58..8d1e3e1911 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -1133,26 +1133,10 @@ async def enable_replication(self, client): class AsyncUnitTest(AsyncPyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod - async def _setup_class(cls): + async def asyncSetUp(self) -> None: pass - @classmethod - async def _tearDown_class(cls): + async def asyncTearDown(self) -> None: pass @@ -1163,37 +1147,20 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): db: AsyncDatabase credentials: Dict[str, str] - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @async_client_context.require_connection - async def _setup_class(cls): - if async_client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): + async def asyncSetUp(self) -> None: + if not _IS_SYNC: + await reset_client_context() + if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") - if async_client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): + if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): raise SkipTest("this test does not support serverless") - cls.client = async_client_context.client - cls.db = cls.client.pymongo_test + self.client = async_client_context.client + self.db = self.client.pymongo_test if async_client_context.auth_enabled: - cls.credentials = {"username": db_user, "password": db_pwd} + self.credentials = {"username": db_user, "password": db_pwd} else: - cls.credentials = {} - - @classmethod - async def _tearDown_class(cls): - pass + self.credentials = {} async def cleanup_colls(self, *collections): """Cleanup collections faster than drop_collection.""" @@ -1219,39 +1186,16 @@ class AsyncMockClientTest(AsyncUnitTest): # MockClients tests that use replicaSet, directConnection=True, pass # multiple seed addresses, or wait for heartbeat events are incompatible # with loadBalanced=True. - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @async_client_context.require_no_load_balancer - async def _setup_class(cls): - pass - - @classmethod - async def _tearDown_class(cls): - pass - - def setUp(self): - super().setUp() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) - self.client_knobs.enable() - def tearDown(self): + async def asyncTearDown(self) -> None: self.client_knobs.disable() - super().tearDown() + await super().asyncTearDown() async def async_setup(): diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 42a3311072..e01dd53d7e 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -42,15 +42,11 @@ class AsyncBulkTestBase(AsyncIntegrationTest): coll: AsyncCollection coll_w0: AsyncCollection - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.coll = cls.db.test - cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0)) - async def asyncSetUp(self): - super().setUp() + await super().asyncSetUp() + self.coll = self.db.test await self.coll.drop() + self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0)) def assertEqualResponse(self, expected, actual): """Compare response from bulk.execute() to expected response.""" @@ -787,14 +783,10 @@ async def test_large_inserts_unordered(self): class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase): - @classmethod @async_client_context.require_auth @async_client_context.require_no_api_version - async def _setup_class(cls): - await super()._setup_class() - async def asyncSetUp(self): - super().setUp() + await super().asyncSetUp() await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"]) await self.db.command( "createRole", @@ -937,21 +929,19 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase): w: Optional[int] secondary: AsyncMongoClient - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.w = async_client_context.w - cls.secondary = None - if cls.w is not None and cls.w > 1: + async def asyncSetUp(self): + await super().asyncSetUp() + self.w = async_client_context.w + self.secondary = None + if self.w is not None and self.w > 1: for member in (await async_client_context.hello)["hosts"]: if member != (await async_client_context.hello)["primary"]: - cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member)) + self.secondary = await self.async_single_client(*partition_node(member)) break - @classmethod - async def async_tearDownClass(cls): - if cls.secondary: - await cls.secondary.close() + async def asyncTearDown(self): + if self.secondary: + await self.secondary.close() async def cause_wtimeout(self, requests, ordered): if not async_client_context.test_commands_enabled: diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 883ed72c4c..db8a74f55a 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -835,18 +835,16 @@ async def test_split_large_change(self): class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin): dbs: list - @classmethod @async_client_context.require_version_min(4, 0, 0, -1) @async_client_context.require_change_streams - async def _setup_class(cls): - await super()._setup_class() - cls.dbs = [cls.db, cls.client.pymongo_test_2] + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.dbs = [self.db, self.client.pymongo_test_2] - @classmethod - async def _tearDown_class(cls): - for db in cls.dbs: - await cls.client.drop_database(db) - await super()._tearDown_class() + async def asyncTearDown(self): + for db in self.dbs: + await self.client.drop_database(db) + await super().asyncTearDown() async def change_stream_with_client(self, client, *args, **kwargs): return await client.watch(*args, **kwargs) @@ -897,11 +895,10 @@ async def test_full_pipeline(self): class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin): - @classmethod @async_client_context.require_version_min(4, 0, 0, -1) @async_client_context.require_change_streams - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() async def change_stream_with_client(self, client, *args, **kwargs): return await client[self.db.name].watch(*args, **kwargs) @@ -987,12 +984,9 @@ async def test_isolation(self): class TestAsyncCollectionAsyncChangeStream( TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin ): - @classmethod @async_client_context.require_change_streams - async def _setup_class(cls): - await super()._setup_class() - async def asyncSetUp(self): + await super().asyncSetUp() # Use a new collection for each test. await self.watched_collection().drop() await self.watched_collection().insert_one({}) @@ -1132,20 +1126,11 @@ class TestAllLegacyScenarios(AsyncIntegrationTest): RUN_ON_LOAD_BALANCER = True listener: AllowListEventListener - @classmethod @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() - cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - await super()._tearDown_class() - - def asyncSetUp(self): - super().asyncSetUp() + async def asyncSetUp(self): + await super().asyncSetUp() + self.listener = AllowListEventListener("aggregate", "getMore") + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) self.listener.reset() async def asyncSetUpCluster(self, scenario_dict): diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index ce396997e3..47cbff6d5b 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -130,16 +130,11 @@ class AsyncClientUnitTest(AsyncUnitTest): client: AsyncMongoClient - @classmethod - async def _setup_class(cls): - cls.client = await cls.unmanaged_async_rs_or_single_client( + async def asyncSetUp(self) -> None: + self.client = await self.async_rs_or_single_client( connect=False, serverSelectionTimeoutMS=100 ) - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): self._caplog = caplog diff --git a/test/asynchronous/test_collation.py b/test/asynchronous/test_collation.py index be3ea22e42..abbca1aff9 100644 --- a/test/asynchronous/test_collation.py +++ b/test/asynchronous/test_collation.py @@ -97,28 +97,22 @@ class TestCollation(AsyncIntegrationTest): warn_context: Any collation: Collation - @classmethod @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() - cls.listener = EventListener() - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - cls.db = cls.client.pymongo_test - cls.collation = Collation("en_US") - cls.warn_context = warnings.catch_warnings() - cls.warn_context.__enter__() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.listener = EventListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test + self.collation = Collation("en_US") + self.warn_context = warnings.catch_warnings() + self.warn_context.__enter__() warnings.simplefilter("ignore", DeprecationWarning) - @classmethod - async def _tearDown_class(cls): - cls.warn_context.__exit__() - cls.warn_context = None - await cls.client.close() - await super()._tearDown_class() - - def tearDown(self): + async def asyncTearDown(self) -> None: + self.warn_context.__exit__() + self.warn_context = None self.listener.reset() - super().tearDown() + await super().asyncTearDown() def last_command_started(self): return self.listener.started_events[-1].command diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 470425f4ce..a2ed4de388 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -86,14 +86,10 @@ class TestCollectionNoConnect(AsyncUnitTest): db: AsyncDatabase client: AsyncMongoClient - @classmethod - async def _setup_class(cls): - cls.client = AsyncMongoClient(connect=False) - cls.db = cls.client.pymongo_test - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.client = self.simple_client(connect=False) + self.db = self.client.pymongo_test def test_collection(self): self.assertRaises(TypeError, AsyncCollection, self.db, 5) @@ -163,27 +159,14 @@ def test_iteration(self): class AsyncTestCollection(AsyncIntegrationTest): w: int - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.w = async_client_context.w # type: ignore - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine] - else: - asyncio.run(cls.async_tearDownClass()) - - @classmethod - async def async_tearDownClass(cls): - await cls.db.drop_collection("test_large_limit") - async def asyncSetUp(self): - await self.db.test.drop() + await super().asyncSetUp() + self.w = async_client_context.w # type: ignore async def asyncTearDown(self): await self.db.test.drop() + await self.db.drop_collection("test_large_limit") + await super().asyncTearDown() @contextlib.contextmanager def write_concern_collection(self): diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py index dc04cb28a7..ffff428379 100644 --- a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -44,30 +44,22 @@ class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest): listener: CMAPListener coll: AsyncCollection - @classmethod + async def asyncTearDown(self): + await reset_client_context() + @async_client_context.require_replica_set - async def _setup_class(cls): - await super()._setup_class() - cls.listener = CMAPListener() - cls.client = await cls.unmanaged_async_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 + async def asyncSetUp(self): + self.listener = CMAPListener() + self.client = await self.async_rs_or_single_client( + event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500 ) # Ensure connections to all servers in replica set. This is to test # that the is_writable flag is properly updated for connections that # survive a replica set election. - await async_ensure_all_connected(cls.client) - cls.listener.reset() - - cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority")) - cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - await reset_client_context() - - async def asyncSetUp(self): + await async_ensure_all_connected(self.client) + self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority")) + self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority")) # Note that all ops use same write-concern as self.db (majority). await self.db.drop_collection("step-down") await self.db.create_collection("step-down") diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index b1ca8855de..09955ca66f 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1647,10 +1647,6 @@ async def test_monitoring(self): class TestRawBatchCommandCursor(AsyncIntegrationTest): - @classmethod - async def _setup_class(cls): - await super()._setup_class() - async def test_aggregate_raw(self): c = self.db.test await c.drop() diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index 61369c8542..b5a5960420 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -717,7 +717,8 @@ def test_with_options(self): class TestDatabaseAggregation(AsyncIntegrationTest): - def setUp(self): + async def asyncSetUp(self): + await super().asyncSetUp() self.pipeline: List[Mapping[str, Any]] = [ {"$listLocalSessions": {}}, {"$limit": 1}, diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 88b005c4b3..d75bad6862 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -211,11 +211,10 @@ async def test_kwargs(self): class AsyncEncryptionIntegrationTest(AsyncIntegrationTest): """Base class for encryption integration tests.""" - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @async_client_context.require_version_min(4, 2, -1) - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() def assertEncrypted(self, val): self.assertIsInstance(val, Binary) @@ -430,10 +429,9 @@ async def test_upsert_uuid_standard_encrypt(self): class TestClientMaxWireVersion(AsyncIntegrationTest): - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() @async_client_context.require_version_max(4, 0, 99) async def test_raise_max_wire_version_error(self): @@ -818,17 +816,16 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest): "local": None, } - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - async def _setup_class(cls): - await super()._setup_class() - cls.listener = OvertCommandListener() - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - await cls.client.db.coll.drop() - cls.vault = await create_key_vault(cls.client.keyvault.datakeys) + async def asyncSetUp(self): + await super().asyncSetUp() + self.listener = OvertCommandListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) + await self.client.db.coll.drop() + self.vault = await create_key_vault(self.client.keyvault.datakeys) # Configure the encrypted field via the local schema_map option. schemas = { @@ -846,25 +843,22 @@ async def _setup_class(cls): } } opts = AutoEncryptionOpts( - cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS + self.KMS_PROVIDERS, + "keyvault.datakeys", + schema_map=schemas, + kms_tls_options=KMS_TLS_OPTS, ) - cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( + self.client_encrypted = await self.async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = cls.unmanaged_create_client_encryption( - cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS + self.client_encryption = self.create_client_encryption( + self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) - - @classmethod - async def _tearDown_class(cls): - await cls.vault.drop() - await cls.client.close() - await cls.client_encrypted.close() - await cls.client_encryption.close() - - def setUp(self): self.listener.reset() + async def asyncTearDown(self) -> None: + await self.vault.drop() + async def run_test(self, provider_name): # Create data key. master_key: Any = self.MASTER_KEYS[provider_name] @@ -1011,10 +1005,9 @@ async def test_views_are_prohibited(self): class TestCorpus(AsyncEncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() @staticmethod def kms_providers(): @@ -1188,12 +1181,11 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest): client_encrypted: AsyncMongoClient listener: OvertCommandListener - @classmethod - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() db = async_client_context.client.db - cls.coll = db.coll - await cls.coll.drop() + self.coll = db.coll + await self.coll.drop() # Configure the encrypted 'db.coll' collection via jsonSchema. json_schema = json_data("limits", "limits-schema.json") await db.create_collection( @@ -1211,17 +1203,14 @@ async def _setup_class(cls): await coll.insert_one(json_data("limits", "limits-key.json")) opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") - cls.listener = OvertCommandListener() - cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( - auto_encryption_opts=opts, event_listeners=[cls.listener] + self.listener = OvertCommandListener() + self.client_encrypted = await self.async_rs_or_single_client( + auto_encryption_opts=opts, event_listeners=[self.listener] ) - cls.coll_encrypted = cls.client_encrypted.db.coll + self.coll_encrypted = self.client_encrypted.db.coll - @classmethod - async def _tearDown_class(cls): - await cls.coll_encrypted.drop() - await cls.client_encrypted.close() - await super()._tearDown_class() + async def asyncTearDown(self) -> None: + await self.coll_encrypted.drop() async def test_01_insert_succeeds_under_2MiB(self): doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB} @@ -1285,15 +1274,12 @@ async def test_06_insert_fails_over_16MiB(self): class TestCustomEndpoint(AsyncEncryptionIntegrationTest): """Prose tests for creating data keys with a custom endpoint.""" - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - async def _setup_class(cls): - await super()._setup_class() - - def setUp(self): + async def asyncSetUp(self): + await super().asyncSetUp() kms_providers = { "aws": AWS_CREDS, "azure": AZURE_CREDS, @@ -1322,10 +1308,6 @@ def setUp(self): self._kmip_host_error = None self._invalid_host_error = None - async def asyncTearDown(self): - await self.client_encryption.close() - await self.client_encryption_invalid.close() - async def run_test_expected_success(self, provider_name, master_key): data_key_id = await self.client_encryption.create_data_key( provider_name, master_key=master_key @@ -1501,6 +1483,7 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest): client: AsyncMongoClient async def asyncSetUp(self): + self.client = self.simple_client() keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL) await create_key_vault(keyvault, self.DEK) @@ -1559,13 +1542,12 @@ async def _test_automatic(self, expectation_extjson, payload): class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") - async def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} - cls.DEK = json_data(BASE, "custom", "azure-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - await super()._setup_class() + async def asyncSetUp(self): + self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} + self.DEK = json_data(BASE, "custom", "azure-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + await super().asyncSetUp() async def test_explicit(self): return await self._test_explicit( @@ -1585,13 +1567,12 @@ async def test_automatic(self): class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") - async def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} - cls.DEK = json_data(BASE, "custom", "gcp-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - await super()._setup_class() + async def asyncSetUp(self): + self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} + self.DEK = json_data(BASE, "custom", "gcp-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + await super().asyncSetUp() async def test_explicit(self): return await self._test_explicit( @@ -3089,17 +3070,11 @@ class TestNoSessionsSupport(AsyncEncryptionIntegrationTest): mongocryptd_client: AsyncMongoClient MONGOCRYPTD_PORT = 27020 - @classmethod @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") - async def _setup_class(cls): - await super()._setup_class() - start_mongocryptd(cls.MONGOCRYPTD_PORT) - - @classmethod - async def _tearDown_class(cls): - await super()._tearDown_class() - async def asyncSetUp(self) -> None: + await super().asyncSetUp() + start_mongocryptd(self.MONGOCRYPTD_PORT) + self.listener = OvertCommandListener() self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py index 9c57c15c5a..14446106e0 100644 --- a/test/asynchronous/test_grid_file.py +++ b/test/asynchronous/test_grid_file.py @@ -97,6 +97,7 @@ def test_grid_in_custom_opts(self): class AsyncTestGridFile(AsyncIntegrationTest): async def asyncSetUp(self): + await super().asyncSetUp() await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks) async def test_basic(self): diff --git a/test/asynchronous/test_monitoring.py b/test/asynchronous/test_monitoring.py index b5d8708dc3..a5f991b2f0 100644 --- a/test/asynchronous/test_monitoring.py +++ b/test/asynchronous/test_monitoring.py @@ -51,22 +51,16 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest): listener: EventListener @classmethod - @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() + def setUpClass(cls) -> None: cls.listener = EventListener() - cls.client = await cls.unmanaged_async_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False - ) - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - await super()._tearDown_class() - - async def asyncTearDown(self): + @async_client_context.require_connection + async def asyncSetUp(self) -> None: + await super().asyncSetUp() self.listener.reset() - await super().asyncTearDown() + self.client = await self.async_rs_or_single_client( + event_listeners=[self.listener], retryWrites=False + ) async def test_started_simple(self): await self.client.pymongo_test.command("ping") @@ -1137,27 +1131,30 @@ class AsyncTestGlobalListener(AsyncIntegrationTest): saved_listeners: Any @classmethod - @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() + def setUpClass(cls) -> None: cls.listener = EventListener() # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = await cls.unmanaged_async_single_client() - # Get one (authenticated) socket in the pool. - await cls.client.pymongo_test.command("ping") - - @classmethod - async def _tearDown_class(cls): - monitoring._LISTENERS = cls.saved_listeners - await cls.client.close() - await super()._tearDown_class() + @async_client_context.require_connection async def asyncSetUp(self): await super().asyncSetUp() + self.listener = EventListener() + # We plan to call register(), which internally modifies _LISTENERS. + self.saved_listeners = copy.deepcopy(monitoring._LISTENERS) + monitoring.register(self.listener) + self.client = await self.async_single_client() + # Get one (authenticated) socket in the pool. + await self.client.pymongo_test.command("ping") + + async def asyncTearDown(self) -> None: self.listener.reset() + @classmethod + def tearDownClass(cls): + monitoring._LISTENERS = cls.saved_listeners + async def test_simple(self): await self.client.pymongo_test.command("ping") started = self.listener.started_events[0] diff --git a/test/asynchronous/test_retryable_writes.py b/test/asynchronous/test_retryable_writes.py index accbbd003f..746f23ea48 100644 --- a/test/asynchronous/test_retryable_writes.py +++ b/test/asynchronous/test_retryable_writes.py @@ -133,34 +133,27 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest): RUN_ON_SERVERLESS = True deprecation_filter: DeprecationFilter - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.deprecation_filter = DeprecationFilter() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.deprecation_filter = DeprecationFilter() - @classmethod - async def _tearDown_class(cls): - cls.deprecation_filter.stop() - await super()._tearDown_class() + async def asyncTearDown(self) -> None: + self.deprecation_filter.stop() class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): knobs: client_knobs - @classmethod - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True) - cls.db = cls.client.pymongo_test + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.client = await self.async_rs_or_single_client(retryWrites=True) + self.db = self.client.pymongo_test - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - await cls.client.close() - await super()._tearDown_class() + async def asyncTearDown(self) -> None: + self.knobs.disable() @async_client_context.require_no_standalone async def test_actionable_error_message(self): @@ -181,26 +174,18 @@ class TestRetryableWrites(IgnoreDeprecationsTest): listener: OvertCommandListener knobs: client_knobs - @classmethod @async_client_context.require_no_mmap - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.listener = OvertCommandListener() - cls.client = await cls.unmanaged_async_rs_or_single_client( - retryWrites=True, event_listeners=[cls.listener] + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.listener = OvertCommandListener() + self.client = await self.async_rs_or_single_client( + retryWrites=True, event_listeners=[self.listener] ) - cls.db = cls.client.pymongo_test + self.db = self.client.pymongo_test - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - await cls.client.close() - await super()._tearDown_class() - - async def asyncSetUp(self): if async_client_context.is_rs and async_client_context.test_commands_enabled: await self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")]) @@ -211,6 +196,7 @@ async def asyncTearDown(self): await self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) ) + self.knobs.disable() async def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() @@ -480,13 +466,12 @@ class TestWriteConcernError(AsyncIntegrationTest): RUN_ON_SERVERLESS = True fail_insert: dict - @classmethod @async_client_context.require_replica_set @async_client_context.require_no_mmap @async_client_context.require_failCommand_fail_point - async def _setup_class(cls): - await super()._setup_class() - cls.fail_insert = { + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.fail_insert = { "configureFailPoint": "failCommand", "mode": {"times": 2}, "data": { diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index c1dac6f56d..e424796ce0 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -81,36 +81,27 @@ class TestSession(AsyncIntegrationTest): client2: AsyncMongoClient sensitive_commands: Set[str] - @classmethod @async_client_context.require_sessions - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = await cls.unmanaged_async_rs_or_single_client() + self.client2 = await self.async_rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". - cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() + self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() monitoring._SENSITIVE_COMMANDS.clear() - @classmethod - async def _tearDown_class(cls): - monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) - await cls.client2.close() - await super()._tearDown_class() - - async def asyncSetUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() self.client = await self.async_rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) - self.addAsyncCleanup(self.client.close) self.db = self.client.pymongo_test self.initial_lsids = {s["id"] for s in session_ids(self.client)} async def asyncTearDown(self): - """All sessions used in the test must be returned to the pool.""" + monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands) await self.client.drop_database("pymongo_test") used_lsids = self.initial_lsids.copy() for event in self.session_checker_listener.started_events: @@ -120,6 +111,8 @@ async def asyncTearDown(self): current_lsids = {s["id"] for s in session_ids(self.client)} self.assertLessEqual(used_lsids, current_lsids) + await super().asyncTearDown() + async def _test_ops(self, client, *ops): listener = client.options.event_listeners[0] @@ -831,18 +824,11 @@ class TestCausalConsistency(AsyncUnitTest): listener: SessionTestListener client: AsyncMongoClient - @classmethod - async def _setup_class(cls): - cls.listener = SessionTestListener() - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - @async_client_context.require_sessions async def asyncSetUp(self): await super().asyncSetUp() + self.listener = SessionTestListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) @async_client_context.require_no_standalone async def test_core(self): diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 229046e79b..d11d0a9776 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -403,21 +403,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): class TestTransactionsConvenientAPI(AsyncTransactionsBase): - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.mongos_clients = [] + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.mongos_clients = [] if async_client_context.supports_transactions(): for address in async_client_context.mongoses: - cls.mongos_clients.append( - await cls.unmanaged_async_single_client("{}:{}".format(*address)) - ) - - @classmethod - async def _tearDown_class(cls): - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() + self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address))) async def _set_fail_point(self, client, command_args): cmd = {"configureFailPoint": "failCommand"} diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 2ff38f06e9..f25e96e04d 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -479,54 +479,47 @@ async def insert_initial_data(self, initial_data): await db.create_collection(coll_name, write_concern=wc, **opts) @classmethod - async def _setup_class(cls): + def setUpClass(cls) -> None: + # Speed up the tests by decreasing the heartbeat frequency. + cls.knobs = client_knobs( + heartbeat_frequency=0.1, + min_heartbeat_interval=0.1, + kill_cursor_frequency=0.1, + events_queue_frequency=0.1, + ) + cls.knobs.enable() + + @classmethod + def tearDownClass(cls) -> None: + cls.knobs.disable() + + async def asyncSetUp(self): # super call creates internal client cls.client - await super()._setup_class() + await super().asyncSetUp() # process file-level runOnRequirements - run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) - if not await cls.should_run_on(run_on_spec): - raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied") + run_on_spec = self.TEST_SPEC.get("runOnRequirements", []) + if not await self.should_run_on(run_on_spec): + raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied") # add any special-casing for skipping tests here if async_client_context.storage_engine == "mmapv1": - if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str( - cls.TEST_PATH + if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str( + self.TEST_PATH ): raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") # Handle mongos_clients for transactions tests. - cls.mongos_clients = [] + self.mongos_clients = [] if ( async_client_context.supports_transactions() and not async_client_context.load_balancer and not async_client_context.serverless ): for address in async_client_context.mongoses: - cls.mongos_clients.append( - await cls.unmanaged_async_single_client("{}:{}".format(*address)) - ) + self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address))) - # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs( - heartbeat_frequency=0.1, - min_heartbeat_interval=0.1, - kill_cursor_frequency=0.1, - events_queue_frequency=0.1, - ) - cls.knobs.enable() - - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() - - async def asyncSetUp(self): - await super().asyncSetUp() # process schemaVersion # note: we check major schema version during class generation - # note: we do this here because we cannot run assertions in setUpClass version = Version.from_string(self.TEST_SPEC["schemaVersion"]) self.assertLessEqual( version, @@ -537,6 +530,11 @@ async def asyncSetUp(self): # initialize internals self.match_evaluator = MatchEvaluatorUtil(self) + async def asyncTearDown(self): + for client in self.mongos_clients: + await client.close() + await super().asyncTearDown() + def maybe_skip_test(self, spec): # add any special-casing for skipping tests here if async_client_context.storage_engine == "mmapv1": diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 4d9c4c8f20..f0463244d7 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -249,30 +249,24 @@ class AsyncSpecRunner(AsyncIntegrationTest): knobs: client_knobs listener: EventListener - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.mongos_clients = [] + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() - - def setUp(self): - super().setUp() + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() self.targets = {} self.listener = None # type: ignore self.pool_listener = None self.server_listener = None self.maxDiff = None + async def asyncTearDown(self) -> None: + self.knobs.disable() + for client in self.mongos_clients: + await client.close() + async def _set_fail_point(self, client, command_args): cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) diff --git a/test/test_bulk.py b/test/test_bulk.py index 64fd48e8cd..ad22c1ce9a 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -42,15 +42,11 @@ class BulkTestBase(IntegrationTest): coll: Collection coll_w0: Collection - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.coll = cls.db.test - cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0)) - def setUp(self): super().setUp() + self.coll = self.db.test self.coll.drop() + self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0)) def assertEqualResponse(self, expected, actual): """Compare response from bulk.execute() to expected response.""" @@ -785,12 +781,8 @@ def test_large_inserts_unordered(self): class BulkAuthorizationTestBase(BulkTestBase): - @classmethod @client_context.require_auth @client_context.require_no_api_version - def _setup_class(cls): - super()._setup_class() - def setUp(self): super().setUp() client_context.create_user(self.db.name, "readonly", "pw", ["read"]) @@ -935,21 +927,19 @@ class TestBulkWriteConcern(BulkTestBase): w: Optional[int] secondary: MongoClient - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.w = client_context.w - cls.secondary = None - if cls.w is not None and cls.w > 1: + def setUp(self): + super().setUp() + self.w = client_context.w + self.secondary = None + if self.w is not None and self.w > 1: for member in (client_context.hello)["hosts"]: if member != (client_context.hello)["primary"]: - cls.secondary = cls.unmanaged_single_client(*partition_node(member)) + self.secondary = self.single_client(*partition_node(member)) break - @classmethod - def async_tearDownClass(cls): - if cls.secondary: - cls.secondary.close() + def tearDown(self): + if self.secondary: + self.secondary.close() def cause_wtimeout(self, requests, ordered): if not client_context.test_commands_enabled: diff --git a/test/test_change_stream.py b/test/test_change_stream.py index dae224c5e0..0742384184 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -819,18 +819,16 @@ def test_split_large_change(self): class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin): dbs: list - @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_change_streams - def _setup_class(cls): - super()._setup_class() - cls.dbs = [cls.db, cls.client.pymongo_test_2] + def setUp(self) -> None: + super().setUp() + self.dbs = [self.db, self.client.pymongo_test_2] - @classmethod - def _tearDown_class(cls): - for db in cls.dbs: - cls.client.drop_database(db) - super()._tearDown_class() + def tearDown(self): + for db in self.dbs: + self.client.drop_database(db) + super().tearDown() def change_stream_with_client(self, client, *args, **kwargs): return client.watch(*args, **kwargs) @@ -881,11 +879,10 @@ def test_full_pipeline(self): class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin): - @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_change_streams - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() def change_stream_with_client(self, client, *args, **kwargs): return client[self.db.name].watch(*args, **kwargs) @@ -967,12 +964,9 @@ def test_isolation(self): class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin): - @classmethod @client_context.require_change_streams - def _setup_class(cls): - super()._setup_class() - def setUp(self): + super().setUp() # Use a new collection for each test. self.watched_collection().drop() self.watched_collection().insert_one({}) @@ -1110,20 +1104,11 @@ class TestAllLegacyScenarios(IntegrationTest): RUN_ON_LOAD_BALANCER = True listener: AllowListEventListener - @classmethod @client_context.require_connection - def _setup_class(cls): - super()._setup_class() - cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - def _tearDown_class(cls): - cls.client.close() - super()._tearDown_class() - def setUp(self): super().setUp() + self.listener = AllowListEventListener("aggregate", "getMore") + self.client = self.rs_or_single_client(event_listeners=[self.listener]) self.listener.reset() def setUpCluster(self, scenario_dict): diff --git a/test/test_client.py b/test/test_client.py index 07f3e560fe..d41b0bbfda 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -129,13 +129,8 @@ class ClientUnitTest(UnitTest): client: MongoClient - @classmethod - def _setup_class(cls): - cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) - - @classmethod - def _tearDown_class(cls): - cls.client.close() + def setUp(self) -> None: + self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): diff --git a/test/test_collation.py b/test/test_collation.py index e5c1c7eb11..6d4e958a1f 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -97,26 +97,20 @@ class TestCollation(IntegrationTest): warn_context: Any collation: Collation - @classmethod @client_context.require_connection - def _setup_class(cls): - super()._setup_class() - cls.listener = EventListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - cls.db = cls.client.pymongo_test - cls.collation = Collation("en_US") - cls.warn_context = warnings.catch_warnings() - cls.warn_context.__enter__() + def setUp(self) -> None: + super().setUp() + self.listener = EventListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test + self.collation = Collation("en_US") + self.warn_context = warnings.catch_warnings() + self.warn_context.__enter__() warnings.simplefilter("ignore", DeprecationWarning) - @classmethod - def _tearDown_class(cls): - cls.warn_context.__exit__() - cls.warn_context = None - cls.client.close() - super()._tearDown_class() - - def tearDown(self): + def tearDown(self) -> None: + self.warn_context.__exit__() + self.warn_context = None self.listener.reset() super().tearDown() diff --git a/test/test_collection.py b/test/test_collection.py index f2f01ac686..9364d34e34 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -86,14 +86,10 @@ class TestCollectionNoConnect(UnitTest): db: Database client: MongoClient - @classmethod - def _setup_class(cls): - cls.client = MongoClient(connect=False) - cls.db = cls.client.pymongo_test - - @classmethod - def _tearDown_class(cls): - cls.client.close() + def setUp(self) -> None: + super().setUp() + self.client = self.simple_client(connect=False) + self.db = self.client.pymongo_test def test_collection(self): self.assertRaises(TypeError, Collection, self.db, 5) @@ -163,27 +159,14 @@ def test_iteration(self): class TestCollection(IntegrationTest): w: int - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.w = client_context.w # type: ignore - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine] - else: - asyncio.run(cls.async_tearDownClass()) - - @classmethod - def async_tearDownClass(cls): - cls.db.drop_collection("test_large_limit") - def setUp(self): - self.db.test.drop() + super().setUp() + self.w = client_context.w # type: ignore def tearDown(self): self.db.test.drop() + self.db.drop_collection("test_large_limit") + super().tearDown() @contextlib.contextmanager def write_concern_collection(self): diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 984d700fb3..4387850a00 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -44,30 +44,22 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): listener: CMAPListener coll: Collection - @classmethod + def tearDown(self): + reset_client_context() + @client_context.require_replica_set - def _setup_class(cls): - super()._setup_class() - cls.listener = CMAPListener() - cls.client = cls.unmanaged_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 + def setUp(self): + self.listener = CMAPListener() + self.client = self.rs_or_single_client( + event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500 ) # Ensure connections to all servers in replica set. This is to test # that the is_writable flag is properly updated for connections that # survive a replica set election. - ensure_all_connected(cls.client) - cls.listener.reset() - - cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority")) - cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) - - @classmethod - def _tearDown_class(cls): - cls.client.close() - reset_client_context() - - def setUp(self): + ensure_all_connected(self.client) + self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority")) + self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority")) # Note that all ops use same write-concern as self.db (majority). self.db.drop_collection("step-down") self.db.create_collection("step-down") diff --git a/test/test_cursor.py b/test/test_cursor.py index 7a6dfc9429..e687abcfbf 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1636,10 +1636,6 @@ def test_monitoring(self): class TestRawBatchCommandCursor(IntegrationTest): - @classmethod - def _setup_class(cls): - super()._setup_class() - def test_aggregate_raw(self): c = self.db.test c.drop() diff --git a/test/test_custom_types.py b/test/test_custom_types.py index abaa820cb7..6771ea25f9 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -633,6 +633,7 @@ class MyType(pytype): # type: ignore class TestCollectionWCustomType(IntegrationTest): def setUp(self): + super().setUp() self.db.test.drop() def tearDown(self): @@ -754,6 +755,7 @@ def test_find_one_and__w_custom_type_decoder(self): class TestGridFileCustomType(IntegrationTest): def setUp(self): + super().setUp() self.db.drop_collection("fs.files") self.db.drop_collection("fs.chunks") @@ -917,11 +919,10 @@ def run_test(doc_cls): class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): - @classmethod @client_context.require_change_streams - def setUpClass(cls): - super().setUpClass() - cls.db.test.delete_many({}) + def setUp(self): + super().setUp() + self.db.test.delete_many({}) def tearDown(self): self.input_target.drop() @@ -935,12 +936,11 @@ def create_targets(self, *args, **kwargs): class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): - @classmethod @client_context.require_version_min(4, 0, 0) @client_context.require_change_streams - def setUpClass(cls): - super().setUpClass() - cls.db.test.delete_many({}) + def setUp(self): + super().setUp() + self.db.test.delete_many({}) def tearDown(self): self.input_target.drop() @@ -954,12 +954,11 @@ def create_targets(self, *args, **kwargs): class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): - @classmethod @client_context.require_version_min(4, 0, 0) @client_context.require_change_streams - def setUpClass(cls): - super().setUpClass() - cls.db.test.delete_many({}) + def setUp(self): + super().setUp() + self.db.test.delete_many({}) def tearDown(self): self.input_target.drop() diff --git a/test/test_database.py b/test/test_database.py index 4973ed0134..5e854c941d 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -709,6 +709,7 @@ def test_with_options(self): class TestDatabaseAggregation(IntegrationTest): def setUp(self): + super().setUp() self.pipeline: List[Mapping[str, Any]] = [ {"$listLocalSessions": {}}, {"$limit": 1}, diff --git a/test/test_encryption.py b/test/test_encryption.py index 13a69ca9ad..3749354217 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -211,11 +211,10 @@ def test_kwargs(self): class EncryptionIntegrationTest(IntegrationTest): """Base class for encryption integration tests.""" - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @client_context.require_version_min(4, 2, -1) - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() def assertEncrypted(self, val): self.assertIsInstance(val, Binary) @@ -430,10 +429,9 @@ def test_upsert_uuid_standard_encrypt(self): class TestClientMaxWireVersion(IntegrationTest): - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() @client_context.require_version_max(4, 0, 99) def test_raise_max_wire_version_error(self): @@ -816,17 +814,16 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): "local": None, } - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - def _setup_class(cls): - super()._setup_class() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - cls.client.db.coll.drop() - cls.vault = create_key_vault(cls.client.keyvault.datakeys) + def setUp(self): + super().setUp() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) + self.client.db.coll.drop() + self.vault = create_key_vault(self.client.keyvault.datakeys) # Configure the encrypted field via the local schema_map option. schemas = { @@ -844,25 +841,22 @@ def _setup_class(cls): } } opts = AutoEncryptionOpts( - cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS + self.KMS_PROVIDERS, + "keyvault.datakeys", + schema_map=schemas, + kms_tls_options=KMS_TLS_OPTS, ) - cls.client_encrypted = cls.unmanaged_rs_or_single_client( + self.client_encrypted = self.rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = cls.unmanaged_create_client_encryption( - cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS + self.client_encryption = self.create_client_encryption( + self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) - - @classmethod - def _tearDown_class(cls): - cls.vault.drop() - cls.client.close() - cls.client_encrypted.close() - cls.client_encryption.close() - - def setUp(self): self.listener.reset() + def tearDown(self) -> None: + self.vault.drop() + def run_test(self, provider_name): # Create data key. master_key: Any = self.MASTER_KEYS[provider_name] @@ -1007,10 +1001,9 @@ def test_views_are_prohibited(self): class TestCorpus(EncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() @staticmethod def kms_providers(): @@ -1184,12 +1177,11 @@ class TestBsonSizeBatches(EncryptionIntegrationTest): client_encrypted: MongoClient listener: OvertCommandListener - @classmethod - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() db = client_context.client.db - cls.coll = db.coll - cls.coll.drop() + self.coll = db.coll + self.coll.drop() # Configure the encrypted 'db.coll' collection via jsonSchema. json_schema = json_data("limits", "limits-schema.json") db.create_collection( @@ -1207,17 +1199,14 @@ def _setup_class(cls): coll.insert_one(json_data("limits", "limits-key.json")) opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") - cls.listener = OvertCommandListener() - cls.client_encrypted = cls.unmanaged_rs_or_single_client( - auto_encryption_opts=opts, event_listeners=[cls.listener] + self.listener = OvertCommandListener() + self.client_encrypted = self.rs_or_single_client( + auto_encryption_opts=opts, event_listeners=[self.listener] ) - cls.coll_encrypted = cls.client_encrypted.db.coll + self.coll_encrypted = self.client_encrypted.db.coll - @classmethod - def _tearDown_class(cls): - cls.coll_encrypted.drop() - cls.client_encrypted.close() - super()._tearDown_class() + def tearDown(self) -> None: + self.coll_encrypted.drop() def test_01_insert_succeeds_under_2MiB(self): doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB} @@ -1281,15 +1270,12 @@ def test_06_insert_fails_over_16MiB(self): class TestCustomEndpoint(EncryptionIntegrationTest): """Prose tests for creating data keys with a custom endpoint.""" - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - def _setup_class(cls): - super()._setup_class() - def setUp(self): + super().setUp() kms_providers = { "aws": AWS_CREDS, "azure": AZURE_CREDS, @@ -1318,10 +1304,6 @@ def setUp(self): self._kmip_host_error = None self._invalid_host_error = None - def tearDown(self): - self.client_encryption.close() - self.client_encryption_invalid.close() - def run_test_expected_success(self, provider_name, master_key): data_key_id = self.client_encryption.create_data_key(provider_name, master_key=master_key) encrypted = self.client_encryption.encrypt( @@ -1495,6 +1477,7 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest): client: MongoClient def setUp(self): + self.client = self.simple_client() keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL) create_key_vault(keyvault, self.DEK) @@ -1553,13 +1536,12 @@ def _test_automatic(self, expectation_extjson, payload): class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") - def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} - cls.DEK = json_data(BASE, "custom", "azure-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - super()._setup_class() + def setUp(self): + self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} + self.DEK = json_data(BASE, "custom", "azure-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + super().setUp() def test_explicit(self): return self._test_explicit( @@ -1579,13 +1561,12 @@ def test_automatic(self): class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") - def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} - cls.DEK = json_data(BASE, "custom", "gcp-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - super()._setup_class() + def setUp(self): + self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} + self.DEK = json_data(BASE, "custom", "gcp-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + super().setUp() def test_explicit(self): return self._test_explicit( @@ -3071,17 +3052,11 @@ class TestNoSessionsSupport(EncryptionIntegrationTest): mongocryptd_client: MongoClient MONGOCRYPTD_PORT = 27020 - @classmethod @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") - def _setup_class(cls): - super()._setup_class() - start_mongocryptd(cls.MONGOCRYPTD_PORT) - - @classmethod - def _tearDown_class(cls): - super()._tearDown_class() - def setUp(self) -> None: + super().setUp() + start_mongocryptd(self.MONGOCRYPTD_PORT) + self.listener = OvertCommandListener() self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] diff --git a/test/test_examples.py b/test/test_examples.py index ebf1d784a3..7f98226e7a 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -33,19 +33,14 @@ class TestSampleShellCommands(IntegrationTest): - @classmethod - def setUpClass(cls): - super().setUpClass() - # Run once before any tests run. - cls.db.inventory.drop() - - @classmethod - def tearDownClass(cls): - cls.client.drop_database("pymongo_test") + def setUp(self): + super().setUp() + self.db.inventory.drop() def tearDown(self): # Run after every test. self.db.inventory.drop() + self.client.drop_database("pymongo_test") def test_first_three_examples(self): db = self.db diff --git a/test/test_grid_file.py b/test/test_grid_file.py index fe88aec5ff..0a5b1ad40a 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -97,6 +97,7 @@ def test_grid_in_custom_opts(self): class TestGridFile(IntegrationTest): def setUp(self): + super().setUp() self.cleanup_colls(self.db.fs.files, self.db.fs.chunks) def test_basic(self): diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 549dc0b204..a36109f399 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -75,9 +75,9 @@ def run(self): class TestGridfsNoConnect(unittest.TestCase): db: Database - @classmethod - def setUpClass(cls): - cls.db = MongoClient(connect=False).pymongo_test + def setUp(self): + super().setUp() + self.db = MongoClient(connect=False).pymongo_test def test_gridfs(self): self.assertRaises(TypeError, gridfs.GridFS, "foo") @@ -88,13 +88,10 @@ class TestGridfs(IntegrationTest): fs: gridfs.GridFS alt: gridfs.GridFS - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.fs = gridfs.GridFS(cls.db) - cls.alt = gridfs.GridFS(cls.db, "alt") - def setUp(self): + super().setUp() + self.fs = gridfs.GridFS(self.db) + self.alt = gridfs.GridFS(self.db, "alt") self.cleanup_colls( self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks ) @@ -509,10 +506,9 @@ def test_md5(self): class TestGridfsReplicaSet(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() + def setUp(self): + super().setUp() @classmethod def tearDownClass(cls): diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 28adb7051a..04c7427350 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -79,13 +79,10 @@ class TestGridfs(IntegrationTest): fs: gridfs.GridFSBucket alt: gridfs.GridFSBucket - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.fs = gridfs.GridFSBucket(cls.db) - cls.alt = gridfs.GridFSBucket(cls.db, bucket_name="alt") - def setUp(self): + super().setUp() + self.fs = gridfs.GridFSBucket(self.db) + self.alt = gridfs.GridFSBucket(self.db, bucket_name="alt") self.cleanup_colls( self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks ) @@ -479,10 +476,9 @@ def test_md5(self): class TestGridfsBucketReplicaSet(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() + def setUp(self): + super().setUp() @classmethod def tearDownClass(cls): diff --git a/test/test_monitoring.py b/test/test_monitoring.py index a0c520ed27..31f546fe54 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -51,22 +51,14 @@ class TestCommandMonitoring(IntegrationTest): listener: EventListener @classmethod - @client_context.require_connection - def _setup_class(cls): - super()._setup_class() + def setUpClass(cls) -> None: cls.listener = EventListener() - cls.client = cls.unmanaged_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False - ) - @classmethod - def _tearDown_class(cls): - cls.client.close() - super()._tearDown_class() - - def tearDown(self): + @client_context.require_connection + def setUp(self) -> None: + super().setUp() self.listener.reset() - super().tearDown() + self.client = self.rs_or_single_client(event_listeners=[self.listener], retryWrites=False) def test_started_simple(self): self.client.pymongo_test.command("ping") @@ -1137,27 +1129,30 @@ class TestGlobalListener(IntegrationTest): saved_listeners: Any @classmethod - @client_context.require_connection - def _setup_class(cls): - super()._setup_class() + def setUpClass(cls) -> None: cls.listener = EventListener() # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = cls.unmanaged_single_client() - # Get one (authenticated) socket in the pool. - cls.client.pymongo_test.command("ping") - - @classmethod - def _tearDown_class(cls): - monitoring._LISTENERS = cls.saved_listeners - cls.client.close() - super()._tearDown_class() + @client_context.require_connection def setUp(self): super().setUp() + self.listener = EventListener() + # We plan to call register(), which internally modifies _LISTENERS. + self.saved_listeners = copy.deepcopy(monitoring._LISTENERS) + monitoring.register(self.listener) + self.client = self.single_client() + # Get one (authenticated) socket in the pool. + self.client.pymongo_test.command("ping") + + def tearDown(self) -> None: self.listener.reset() + @classmethod + def tearDownClass(cls): + monitoring._LISTENERS = cls.saved_listeners + def test_simple(self): self.client.pymongo_test.command("ping") started = self.listener.started_events[0] diff --git a/test/test_read_concern.py b/test/test_read_concern.py index ea9ce49a30..f7c0901422 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -31,24 +31,16 @@ class TestReadConcern(IntegrationTest): listener: OvertCommandListener - @classmethod @client_context.require_connection - def setUpClass(cls): - super().setUpClass() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - cls.db = cls.client.pymongo_test + def setUp(self): + super().setUp() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test client_context.client.pymongo_test.create_collection("coll") - @classmethod - def tearDownClass(cls): - cls.client.close() - client_context.client.pymongo_test.drop_collection("coll") - super().tearDownClass() - def tearDown(self): - self.listener.reset() - super().tearDown() + client_context.client.pymongo_test.drop_collection("coll") def test_read_concern(self): rc = ReadConcern() diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 5df6c41f7a..eb814c4ef9 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -133,34 +133,27 @@ class IgnoreDeprecationsTest(IntegrationTest): RUN_ON_SERVERLESS = True deprecation_filter: DeprecationFilter - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.deprecation_filter = DeprecationFilter() + def setUp(self) -> None: + super().setUp() + self.deprecation_filter = DeprecationFilter() - @classmethod - def _tearDown_class(cls): - cls.deprecation_filter.stop() - super()._tearDown_class() + def tearDown(self) -> None: + self.deprecation_filter.stop() class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): knobs: client_knobs - @classmethod - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True) - cls.db = cls.client.pymongo_test + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.client = self.rs_or_single_client(retryWrites=True) + self.db = self.client.pymongo_test - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - cls.client.close() - super()._tearDown_class() + def tearDown(self) -> None: + self.knobs.disable() @client_context.require_no_standalone def test_actionable_error_message(self): @@ -181,26 +174,16 @@ class TestRetryableWrites(IgnoreDeprecationsTest): listener: OvertCommandListener knobs: client_knobs - @classmethod @client_context.require_no_mmap - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client( - retryWrites=True, event_listeners=[cls.listener] - ) - cls.db = cls.client.pymongo_test - - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - cls.client.close() - super()._tearDown_class() + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(retryWrites=True, event_listeners=[self.listener]) + self.db = self.client.pymongo_test - def setUp(self): if client_context.is_rs and client_context.test_commands_enabled: self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")]) @@ -211,6 +194,7 @@ def tearDown(self): self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) ) + self.knobs.disable() def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() @@ -480,13 +464,12 @@ class TestWriteConcernError(IntegrationTest): RUN_ON_SERVERLESS = True fail_insert: dict - @classmethod @client_context.require_replica_set @client_context.require_no_mmap @client_context.require_failCommand_fail_point - def _setup_class(cls): - super()._setup_class() - cls.fail_insert = { + def setUp(self) -> None: + super().setUp() + self.fail_insert = { "configureFailPoint": "failCommand", "mode": {"times": 2}, "data": { diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index 81b208d511..6b808b159d 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -270,7 +270,7 @@ class TestSdamMonitoring(IntegrationTest): @classmethod @client_context.require_failCommand_fail_point def setUpClass(cls): - super().setUpClass() + super().setUp(cls) # Speed up the tests by decreasing the event publish frequency. cls.knobs = client_knobs( events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1 diff --git a/test/test_session.py b/test/test_session.py index 9f94ded927..980d9df688 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -81,36 +81,27 @@ class TestSession(IntegrationTest): client2: MongoClient sensitive_commands: Set[str] - @classmethod @client_context.require_sessions - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = cls.unmanaged_rs_or_single_client() + self.client2 = self.rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". - cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() + self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() monitoring._SENSITIVE_COMMANDS.clear() - @classmethod - def _tearDown_class(cls): - monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) - cls.client2.close() - super()._tearDown_class() - - def setUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() self.client = self.rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) - self.addCleanup(self.client.close) self.db = self.client.pymongo_test self.initial_lsids = {s["id"] for s in session_ids(self.client)} def tearDown(self): - """All sessions used in the test must be returned to the pool.""" + monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands) self.client.drop_database("pymongo_test") used_lsids = self.initial_lsids.copy() for event in self.session_checker_listener.started_events: @@ -120,6 +111,8 @@ def tearDown(self): current_lsids = {s["id"] for s in session_ids(self.client)} self.assertLessEqual(used_lsids, current_lsids) + super().tearDown() + def _test_ops(self, client, *ops): listener = client.options.event_listeners[0] @@ -831,18 +824,11 @@ class TestCausalConsistency(UnitTest): listener: SessionTestListener client: MongoClient - @classmethod - def _setup_class(cls): - cls.listener = SessionTestListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - def _tearDown_class(cls): - cls.client.close() - @client_context.require_sessions def setUp(self): super().setUp() + self.listener = SessionTestListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) @client_context.require_no_standalone def test_core(self): diff --git a/test/test_threads.py b/test/test_threads.py index b3dadbb1a3..3e469e28fe 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -105,6 +105,7 @@ def run(self): class TestThreads(IntegrationTest): def setUp(self): + super().setUp() self.db = self.client.pymongo_test def test_threading(self): diff --git a/test/test_transactions.py b/test/test_transactions.py index 3cecbe9d38..949b88e60b 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -395,19 +395,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): class TestTransactionsConvenientAPI(TransactionsBase): - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.mongos_clients = [] + def setUp(self) -> None: + super().setUp() + self.mongos_clients = [] if client_context.supports_transactions(): for address in client_context.mongoses: - cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) - - @classmethod - def _tearDown_class(cls): - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() + self.mongos_clients.append(self.single_client("{}:{}".format(*address))) def _set_fail_point(self, client, command_args): cmd = {"configureFailPoint": "failCommand"} diff --git a/test/test_typing.py b/test/test_typing.py index 441707616e..bfe4d032c1 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -114,10 +114,9 @@ def test_mypy_failures(self) -> None: class TestPymongo(IntegrationTest): coll: Collection - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.coll = cls.client.test.test + def setUp(self): + super().setUp() + self.coll = self.client.test.test def test_insert_find(self) -> None: doc = {"my": "doc"} diff --git a/test/unified_format.py b/test/unified_format.py index 13ab0af69b..7d5c4e4e03 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -478,52 +478,47 @@ def insert_initial_data(self, initial_data): db.create_collection(coll_name, write_concern=wc, **opts) @classmethod - def _setup_class(cls): + def setUpClass(cls) -> None: + # Speed up the tests by decreasing the heartbeat frequency. + cls.knobs = client_knobs( + heartbeat_frequency=0.1, + min_heartbeat_interval=0.1, + kill_cursor_frequency=0.1, + events_queue_frequency=0.1, + ) + cls.knobs.enable() + + @classmethod + def tearDownClass(cls) -> None: + cls.knobs.disable() + + def setUp(self): # super call creates internal client cls.client - super()._setup_class() + super().setUp() # process file-level runOnRequirements - run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) - if not cls.should_run_on(run_on_spec): - raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied") + run_on_spec = self.TEST_SPEC.get("runOnRequirements", []) + if not self.should_run_on(run_on_spec): + raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied") # add any special-casing for skipping tests here if client_context.storage_engine == "mmapv1": - if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str( - cls.TEST_PATH + if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str( + self.TEST_PATH ): raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") # Handle mongos_clients for transactions tests. - cls.mongos_clients = [] + self.mongos_clients = [] if ( client_context.supports_transactions() and not client_context.load_balancer and not client_context.serverless ): for address in client_context.mongoses: - cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) - - # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs( - heartbeat_frequency=0.1, - min_heartbeat_interval=0.1, - kill_cursor_frequency=0.1, - events_queue_frequency=0.1, - ) - cls.knobs.enable() + self.mongos_clients.append(self.single_client("{}:{}".format(*address))) - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() - - def setUp(self): - super().setUp() # process schemaVersion # note: we check major schema version during class generation - # note: we do this here because we cannot run assertions in setUpClass version = Version.from_string(self.TEST_SPEC["schemaVersion"]) self.assertLessEqual( version, @@ -534,6 +529,11 @@ def setUp(self): # initialize internals self.match_evaluator = MatchEvaluatorUtil(self) + def tearDown(self): + for client in self.mongos_clients: + client.close() + super().tearDown() + def maybe_skip_test(self, spec): # add any special-casing for skipping tests here if client_context.storage_engine == "mmapv1": diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 8a061de0b1..682cf0b0f8 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -249,30 +249,24 @@ class SpecRunner(IntegrationTest): knobs: client_knobs listener: EventListener - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.mongos_clients = [] + def setUp(self) -> None: + super().setUp() + self.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() - - def setUp(self): - super().setUp() + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() self.targets = {} self.listener = None # type: ignore self.pool_listener = None self.server_listener = None self.maxDiff = None + def tearDown(self) -> None: + self.knobs.disable() + for client in self.mongos_clients: + client.close() + def _set_fail_point(self, client, command_args): cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args)