Skip to content

Commit 389127a

Browse files
committed
PYTHON-4843 - Async test suite should use a single event loop per test
1 parent ee18313 commit 389127a

33 files changed

+389
-774
lines changed

test/__init__.py

Lines changed: 13 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,26 +1114,10 @@ def enable_replication(self, client):
11141114
class UnitTest(PyMongoTestCase):
11151115
"""Async base class for TestCases that don't require a connection to MongoDB."""
11161116

1117-
@classmethod
1118-
def setUpClass(cls):
1119-
if _IS_SYNC:
1120-
cls._setup_class()
1121-
else:
1122-
asyncio.run(cls._setup_class())
1123-
1124-
@classmethod
1125-
def tearDownClass(cls):
1126-
if _IS_SYNC:
1127-
cls._tearDown_class()
1128-
else:
1129-
asyncio.run(cls._tearDown_class())
1130-
1131-
@classmethod
1132-
def _setup_class(cls):
1117+
def setUp(self) -> None:
11331118
pass
11341119

1135-
@classmethod
1136-
def _tearDown_class(cls):
1120+
def tearDown(self) -> None:
11371121
pass
11381122

11391123

@@ -1144,37 +1128,20 @@ class IntegrationTest(PyMongoTestCase):
11441128
db: Database
11451129
credentials: Dict[str, str]
11461130

1147-
@classmethod
1148-
def setUpClass(cls):
1149-
if _IS_SYNC:
1150-
cls._setup_class()
1151-
else:
1152-
asyncio.run(cls._setup_class())
1153-
1154-
@classmethod
1155-
def tearDownClass(cls):
1156-
if _IS_SYNC:
1157-
cls._tearDown_class()
1158-
else:
1159-
asyncio.run(cls._tearDown_class())
1160-
1161-
@classmethod
11621131
@client_context.require_connection
1163-
def _setup_class(cls):
1164-
if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
1132+
def setUp(self) -> None:
1133+
if not _IS_SYNC:
1134+
reset_client_context()
1135+
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
11651136
raise SkipTest("this test does not support load balancers")
1166-
if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
1137+
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
11671138
raise SkipTest("this test does not support serverless")
1168-
cls.client = client_context.client
1169-
cls.db = cls.client.pymongo_test
1139+
self.client = client_context.client
1140+
self.db = self.client.pymongo_test
11701141
if client_context.auth_enabled:
1171-
cls.credentials = {"username": db_user, "password": db_pwd}
1142+
self.credentials = {"username": db_user, "password": db_pwd}
11721143
else:
1173-
cls.credentials = {}
1174-
1175-
@classmethod
1176-
def _tearDown_class(cls):
1177-
pass
1144+
self.credentials = {}
11781145

11791146
def cleanup_colls(self, *collections):
11801147
"""Cleanup collections faster than drop_collection."""
@@ -1200,37 +1167,15 @@ class MockClientTest(UnitTest):
12001167
# MockClients tests that use replicaSet, directConnection=True, pass
12011168
# multiple seed addresses, or wait for heartbeat events are incompatible
12021169
# with loadBalanced=True.
1203-
@classmethod
1204-
def setUpClass(cls):
1205-
if _IS_SYNC:
1206-
cls._setup_class()
1207-
else:
1208-
asyncio.run(cls._setup_class())
1209-
1210-
@classmethod
1211-
def tearDownClass(cls):
1212-
if _IS_SYNC:
1213-
cls._tearDown_class()
1214-
else:
1215-
asyncio.run(cls._tearDown_class())
1216-
1217-
@classmethod
12181170
@client_context.require_no_load_balancer
1219-
def _setup_class(cls):
1220-
pass
1221-
1222-
@classmethod
1223-
def _tearDown_class(cls):
1224-
pass
1225-
1226-
def setUp(self):
1171+
def setUp(self) -> None:
12271172
super().setUp()
12281173

12291174
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
12301175

12311176
self.client_knobs.enable()
12321177

1233-
def tearDown(self):
1178+
def tearDown(self) -> None:
12341179
self.client_knobs.disable()
12351180
super().tearDown()
12361181

test/asynchronous/__init__.py

Lines changed: 15 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,26 +1132,10 @@ async def enable_replication(self, client):
11321132
class AsyncUnitTest(AsyncPyMongoTestCase):
11331133
"""Async base class for TestCases that don't require a connection to MongoDB."""
11341134

1135-
@classmethod
1136-
def setUpClass(cls):
1137-
if _IS_SYNC:
1138-
cls._setup_class()
1139-
else:
1140-
asyncio.run(cls._setup_class())
1141-
1142-
@classmethod
1143-
def tearDownClass(cls):
1144-
if _IS_SYNC:
1145-
cls._tearDown_class()
1146-
else:
1147-
asyncio.run(cls._tearDown_class())
1148-
1149-
@classmethod
1150-
async def _setup_class(cls):
1135+
async def asyncSetUp(self) -> None:
11511136
pass
11521137

1153-
@classmethod
1154-
async def _tearDown_class(cls):
1138+
async def asyncTearDown(self) -> None:
11551139
pass
11561140

11571141

@@ -1162,37 +1146,20 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
11621146
db: AsyncDatabase
11631147
credentials: Dict[str, str]
11641148

1165-
@classmethod
1166-
def setUpClass(cls):
1167-
if _IS_SYNC:
1168-
cls._setup_class()
1169-
else:
1170-
asyncio.run(cls._setup_class())
1171-
1172-
@classmethod
1173-
def tearDownClass(cls):
1174-
if _IS_SYNC:
1175-
cls._tearDown_class()
1176-
else:
1177-
asyncio.run(cls._tearDown_class())
1178-
1179-
@classmethod
11801149
@async_client_context.require_connection
1181-
async def _setup_class(cls):
1182-
if async_client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
1150+
async def asyncSetUp(self) -> None:
1151+
if not _IS_SYNC:
1152+
await reset_client_context()
1153+
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
11831154
raise SkipTest("this test does not support load balancers")
1184-
if async_client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
1155+
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
11851156
raise SkipTest("this test does not support serverless")
1186-
cls.client = async_client_context.client
1187-
cls.db = cls.client.pymongo_test
1157+
self.client = async_client_context.client
1158+
self.db = self.client.pymongo_test
11881159
if async_client_context.auth_enabled:
1189-
cls.credentials = {"username": db_user, "password": db_pwd}
1160+
self.credentials = {"username": db_user, "password": db_pwd}
11901161
else:
1191-
cls.credentials = {}
1192-
1193-
@classmethod
1194-
async def _tearDown_class(cls):
1195-
pass
1162+
self.credentials = {}
11961163

11971164
async def cleanup_colls(self, *collections):
11981165
"""Cleanup collections faster than drop_collection."""
@@ -1218,39 +1185,17 @@ class AsyncMockClientTest(AsyncUnitTest):
12181185
# MockClients tests that use replicaSet, directConnection=True, pass
12191186
# multiple seed addresses, or wait for heartbeat events are incompatible
12201187
# with loadBalanced=True.
1221-
@classmethod
1222-
def setUpClass(cls):
1223-
if _IS_SYNC:
1224-
cls._setup_class()
1225-
else:
1226-
asyncio.run(cls._setup_class())
1227-
1228-
@classmethod
1229-
def tearDownClass(cls):
1230-
if _IS_SYNC:
1231-
cls._tearDown_class()
1232-
else:
1233-
asyncio.run(cls._tearDown_class())
1234-
1235-
@classmethod
12361188
@async_client_context.require_no_load_balancer
1237-
async def _setup_class(cls):
1238-
pass
1239-
1240-
@classmethod
1241-
async def _tearDown_class(cls):
1242-
pass
1243-
1244-
def setUp(self):
1245-
super().setUp()
1189+
async def asyncSetUp(self) -> None:
1190+
await super().asyncSetUp()
12461191

12471192
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
12481193

12491194
self.client_knobs.enable()
12501195

1251-
def tearDown(self):
1196+
async def asyncTearDown(self) -> None:
12521197
self.client_knobs.disable()
1253-
super().tearDown()
1198+
await super().asyncTearDown()
12541199

12551200

12561201
async def async_setup():

test/asynchronous/test_bulk.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,11 @@ class AsyncBulkTestBase(AsyncIntegrationTest):
4242
coll: AsyncCollection
4343
coll_w0: AsyncCollection
4444

45-
@classmethod
46-
async def _setup_class(cls):
47-
await super()._setup_class()
48-
cls.coll = cls.db.test
49-
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
50-
5145
async def asyncSetUp(self):
52-
super().setUp()
46+
await super().asyncSetUp()
47+
self.coll = self.db.test
5348
await self.coll.drop()
49+
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
5450

5551
def assertEqualResponse(self, expected, actual):
5652
"""Compare response from bulk.execute() to expected response."""
@@ -787,14 +783,10 @@ async def test_large_inserts_unordered(self):
787783

788784

789785
class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase):
790-
@classmethod
791786
@async_client_context.require_auth
792787
@async_client_context.require_no_api_version
793-
async def _setup_class(cls):
794-
await super()._setup_class()
795-
796788
async def asyncSetUp(self):
797-
super().setUp()
789+
await super().asyncSetUp()
798790
await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"])
799791
await self.db.command(
800792
"createRole",
@@ -937,21 +929,19 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
937929
w: Optional[int]
938930
secondary: AsyncMongoClient
939931

940-
@classmethod
941-
async def _setup_class(cls):
942-
await super()._setup_class()
943-
cls.w = async_client_context.w
944-
cls.secondary = None
945-
if cls.w is not None and cls.w > 1:
932+
async def asyncSetUp(self):
933+
await super().asyncSetUp()
934+
self.w = async_client_context.w
935+
self.secondary = None
936+
if self.w is not None and self.w > 1:
946937
for member in (await async_client_context.hello)["hosts"]:
947938
if member != (await async_client_context.hello)["primary"]:
948-
cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member))
939+
self.secondary = await self.async_single_client(*partition_node(member))
949940
break
950941

951-
@classmethod
952-
async def async_tearDownClass(cls):
953-
if cls.secondary:
954-
await cls.secondary.close()
942+
async def asyncTearDown(self):
943+
if self.secondary:
944+
await self.secondary.close()
955945

956946
async def cause_wtimeout(self, requests, ordered):
957947
if not async_client_context.test_commands_enabled:

test/asynchronous/test_change_stream.py

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -835,18 +835,16 @@ async def test_split_large_change(self):
835835
class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
836836
dbs: list
837837

838-
@classmethod
839838
@async_client_context.require_version_min(4, 0, 0, -1)
840839
@async_client_context.require_change_streams
841-
async def _setup_class(cls):
842-
await super()._setup_class()
843-
cls.dbs = [cls.db, cls.client.pymongo_test_2]
840+
async def asyncSetUp(self) -> None:
841+
await super().asyncSetUp()
842+
self.dbs = [self.db, self.client.pymongo_test_2]
844843

845-
@classmethod
846-
async def _tearDown_class(cls):
847-
for db in cls.dbs:
848-
await cls.client.drop_database(db)
849-
await super()._tearDown_class()
844+
async def asyncTearDown(self):
845+
for db in self.dbs:
846+
await self.client.drop_database(db)
847+
await super().asyncTearDown()
850848

851849
async def change_stream_with_client(self, client, *args, **kwargs):
852850
return await client.watch(*args, **kwargs)
@@ -897,11 +895,10 @@ async def test_full_pipeline(self):
897895

898896

899897
class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
900-
@classmethod
901898
@async_client_context.require_version_min(4, 0, 0, -1)
902899
@async_client_context.require_change_streams
903-
async def _setup_class(cls):
904-
await super()._setup_class()
900+
async def asyncSetUp(self) -> None:
901+
await super().asyncSetUp()
905902

906903
async def change_stream_with_client(self, client, *args, **kwargs):
907904
return await client[self.db.name].watch(*args, **kwargs)
@@ -987,12 +984,9 @@ async def test_isolation(self):
987984
class TestAsyncCollectionAsyncChangeStream(
988985
TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin
989986
):
990-
@classmethod
991987
@async_client_context.require_change_streams
992-
async def _setup_class(cls):
993-
await super()._setup_class()
994-
995988
async def asyncSetUp(self):
989+
await super().asyncSetUp()
996990
# Use a new collection for each test.
997991
await self.watched_collection().drop()
998992
await self.watched_collection().insert_one({})
@@ -1132,20 +1126,11 @@ class TestAllLegacyScenarios(AsyncIntegrationTest):
11321126
RUN_ON_LOAD_BALANCER = True
11331127
listener: AllowListEventListener
11341128

1135-
@classmethod
11361129
@async_client_context.require_connection
1137-
async def _setup_class(cls):
1138-
await super()._setup_class()
1139-
cls.listener = AllowListEventListener("aggregate", "getMore")
1140-
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
1141-
1142-
@classmethod
1143-
async def _tearDown_class(cls):
1144-
await cls.client.close()
1145-
await super()._tearDown_class()
1146-
1147-
def asyncSetUp(self):
1148-
super().asyncSetUp()
1130+
async def asyncSetUp(self):
1131+
await super().asyncSetUp()
1132+
self.listener = AllowListEventListener("aggregate", "getMore")
1133+
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
11491134
self.listener.reset()
11501135

11511136
async def asyncSetUpCluster(self, scenario_dict):

test/asynchronous/test_client.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,11 @@ class AsyncClientUnitTest(AsyncUnitTest):
130130

131131
client: AsyncMongoClient
132132

133-
@classmethod
134-
async def _setup_class(cls):
135-
cls.client = await cls.unmanaged_async_rs_or_single_client(
133+
async def asyncSetUp(self) -> None:
134+
self.client = await self.async_rs_or_single_client(
136135
connect=False, serverSelectionTimeoutMS=100
137136
)
138137

139-
@classmethod
140-
async def _tearDown_class(cls):
141-
await cls.client.close()
142-
143138
@pytest.fixture(autouse=True)
144139
def inject_fixtures(self, caplog):
145140
self._caplog = caplog

0 commit comments

Comments
 (0)