diff --git a/test/asynchronous/test_auth_spec.py b/test/asynchronous/test_auth_spec.py index a6ab1cb331..e9e43d5759 100644 --- a/test/asynchronous/test_auth_spec.py +++ b/test/asynchronous/test_auth_spec.py @@ -25,7 +25,7 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes from pymongo import AsyncMongoClient from pymongo.asynchronous.auth_oidc import OIDCCallback diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 873631bbe5..08da00cc1e 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -35,7 +35,7 @@ async_client_context, unittest, ) -from test.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, EventListener, diff --git a/test/asynchronous/test_connection_logging.py b/test/asynchronous/test_connection_logging.py index 6bc9835b70..945c6c59b5 100644 --- a/test/asynchronous/test_connection_logging.py +++ b/test/asynchronous/test_connection_logging.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes _IS_SYNC = False diff --git a/test/asynchronous/test_create_entities.py b/test/asynchronous/test_create_entities.py index cb2ec63f4c..1f68cf6ddc 100644 --- a/test/asynchronous/test_create_entities.py +++ b/test/asynchronous/test_create_entities.py @@ -56,6 +56,9 @@ async def test_store_events_as_entities(self): self.assertGreater(len(final_entity_map["events1"]), 0) for event in final_entity_map["events1"]: self.assertIn("PoolCreatedEvent", event["name"]) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + await client.close() async def test_store_all_others_as_entities(self): self.scenario_runner = UnifiedSpecTestMixinV1() @@ -122,6 +125,9 @@ async def test_store_all_others_as_entities(self): self.assertEqual(entity_map["failures"], []) self.assertEqual(entity_map["successes"], 2) self.assertEqual(entity_map["iterations"], 5) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + await client.close() if __name__ == "__main__": diff --git a/test/asynchronous/test_crud_unified.py b/test/asynchronous/test_crud_unified.py index 3d8deb36e9..e6f42d5bdf 100644 --- a/test/asynchronous/test_crud_unified.py +++ b/test/asynchronous/test_crud_unified.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes _IS_SYNC = False diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index d75bad6862..ba68960c5e 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -46,6 +46,7 @@ unittest, ) from test.asynchronous.test_bulk import AsyncBulkTestBase +from test.asynchronous.unified_format import generate_test_classes from test.asynchronous.utils_spec_runner import AsyncSpecRunner from test.helpers import ( AWS_CREDS, @@ -56,7 +57,6 @@ KMIP_CREDS, LOCAL_MASTER_KEY, ) -from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, OvertCommandListener, diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index e8d1e4380f..ec70c1dc13 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -304,7 +304,6 @@ async def _create_entity(self, entity_spec, uri=None): kwargs["h"] = uri client = await self.test.async_rs_or_single_client(**kwargs) self[spec["id"]] = client - self.test.addAsyncCleanup(client.close) return elif entity_type == "database": client = self[spec["client"]] @@ -1037,7 +1036,6 @@ async def _testOperation_targetedFailPoint(self, spec): ) client = await self.async_single_client("{}:{}".format(*session._pinned_address)) - self.addAsyncCleanup(client.close) await self.__set_fail_point(client=client, command_args=spec["failPoint"]) async def _testOperation_createEntities(self, spec): diff --git a/test/test_create_entities.py b/test/test_create_entities.py index ad75fe5702..9d77a08eee 100644 --- a/test/test_create_entities.py +++ b/test/test_create_entities.py @@ -56,6 +56,9 @@ def test_store_events_as_entities(self): self.assertGreater(len(final_entity_map["events1"]), 0) for event in final_entity_map["events1"]: self.assertIn("PoolCreatedEvent", event["name"]) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + client.close() def test_store_all_others_as_entities(self): self.scenario_runner = UnifiedSpecTestMixinV1() @@ -122,6 +125,9 @@ def test_store_all_others_as_entities(self): self.assertEqual(entity_map["failures"], []) self.assertEqual(entity_map["successes"], 2) self.assertEqual(entity_map["iterations"], 5) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + client.close() if __name__ == "__main__": diff --git a/test/unified_format.py b/test/unified_format.py index 435078989b..6fea541e8a 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -303,7 +303,6 @@ def _create_entity(self, entity_spec, uri=None): kwargs["h"] = uri client = self.test.rs_or_single_client(**kwargs) self[spec["id"]] = client - self.test.addCleanup(client.close) return elif entity_type == "database": client = self[spec["client"]] @@ -1028,7 +1027,6 @@ def _testOperation_targetedFailPoint(self, spec): ) client = self.single_client("{}:{}".format(*session._pinned_address)) - self.addCleanup(client.close) self.__set_fail_point(client=client, command_args=spec["failPoint"]) def _testOperation_createEntities(self, spec):