Skip to content

Commit 1fec0a1

Browse files
authored
Merge branch 'async-improvements' into PYTHON-4860
2 parents 6e5b730 + e4ebfa4 commit 1fec0a1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+473
-881
lines changed

test/__init__.py

Lines changed: 13 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,26 +1115,10 @@ def enable_replication(self, client):
11151115
class UnitTest(PyMongoTestCase):
11161116
"""Async base class for TestCases that don't require a connection to MongoDB."""
11171117

1118-
@classmethod
1119-
def setUpClass(cls):
1120-
if _IS_SYNC:
1121-
cls._setup_class()
1122-
else:
1123-
asyncio.run(cls._setup_class())
1124-
1125-
@classmethod
1126-
def tearDownClass(cls):
1127-
if _IS_SYNC:
1128-
cls._tearDown_class()
1129-
else:
1130-
asyncio.run(cls._tearDown_class())
1131-
1132-
@classmethod
1133-
def _setup_class(cls):
1118+
def setUp(self) -> None:
11341119
pass
11351120

1136-
@classmethod
1137-
def _tearDown_class(cls):
1121+
def tearDown(self) -> None:
11381122
pass
11391123

11401124

@@ -1145,37 +1129,20 @@ class IntegrationTest(PyMongoTestCase):
11451129
db: Database
11461130
credentials: Dict[str, str]
11471131

1148-
@classmethod
1149-
def setUpClass(cls):
1150-
if _IS_SYNC:
1151-
cls._setup_class()
1152-
else:
1153-
asyncio.run(cls._setup_class())
1154-
1155-
@classmethod
1156-
def tearDownClass(cls):
1157-
if _IS_SYNC:
1158-
cls._tearDown_class()
1159-
else:
1160-
asyncio.run(cls._tearDown_class())
1161-
1162-
@classmethod
11631132
@client_context.require_connection
1164-
def _setup_class(cls):
1165-
if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
1133+
def setUp(self) -> None:
1134+
if not _IS_SYNC:
1135+
reset_client_context()
1136+
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
11661137
raise SkipTest("this test does not support load balancers")
1167-
if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
1138+
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
11681139
raise SkipTest("this test does not support serverless")
1169-
cls.client = client_context.client
1170-
cls.db = cls.client.pymongo_test
1140+
self.client = client_context.client
1141+
self.db = self.client.pymongo_test
11711142
if client_context.auth_enabled:
1172-
cls.credentials = {"username": db_user, "password": db_pwd}
1143+
self.credentials = {"username": db_user, "password": db_pwd}
11731144
else:
1174-
cls.credentials = {}
1175-
1176-
@classmethod
1177-
def _tearDown_class(cls):
1178-
pass
1145+
self.credentials = {}
11791146

11801147
def cleanup_colls(self, *collections):
11811148
"""Cleanup collections faster than drop_collection."""
@@ -1201,37 +1168,14 @@ class MockClientTest(UnitTest):
12011168
# MockClients tests that use replicaSet, directConnection=True, pass
12021169
# multiple seed addresses, or wait for heartbeat events are incompatible
12031170
# with loadBalanced=True.
1204-
@classmethod
1205-
def setUpClass(cls):
1206-
if _IS_SYNC:
1207-
cls._setup_class()
1208-
else:
1209-
asyncio.run(cls._setup_class())
1210-
1211-
@classmethod
1212-
def tearDownClass(cls):
1213-
if _IS_SYNC:
1214-
cls._tearDown_class()
1215-
else:
1216-
asyncio.run(cls._tearDown_class())
1217-
1218-
@classmethod
12191171
@client_context.require_no_load_balancer
1220-
def _setup_class(cls):
1221-
pass
1222-
1223-
@classmethod
1224-
def _tearDown_class(cls):
1225-
pass
1226-
1227-
def setUp(self):
1172+
def setUp(self) -> None:
12281173
super().setUp()
12291174

12301175
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
1231-
12321176
self.client_knobs.enable()
12331177

1234-
def tearDown(self):
1178+
def tearDown(self) -> None:
12351179
self.client_knobs.disable()
12361180
super().tearDown()
12371181

test/asynchronous/__init__.py

Lines changed: 15 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,26 +1133,10 @@ async def enable_replication(self, client):
11331133
class AsyncUnitTest(AsyncPyMongoTestCase):
11341134
"""Async base class for TestCases that don't require a connection to MongoDB."""
11351135

1136-
@classmethod
1137-
def setUpClass(cls):
1138-
if _IS_SYNC:
1139-
cls._setup_class()
1140-
else:
1141-
asyncio.run(cls._setup_class())
1142-
1143-
@classmethod
1144-
def tearDownClass(cls):
1145-
if _IS_SYNC:
1146-
cls._tearDown_class()
1147-
else:
1148-
asyncio.run(cls._tearDown_class())
1149-
1150-
@classmethod
1151-
async def _setup_class(cls):
1136+
async def asyncSetUp(self) -> None:
11521137
pass
11531138

1154-
@classmethod
1155-
async def _tearDown_class(cls):
1139+
async def asyncTearDown(self) -> None:
11561140
pass
11571141

11581142

@@ -1163,37 +1147,20 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
11631147
db: AsyncDatabase
11641148
credentials: Dict[str, str]
11651149

1166-
@classmethod
1167-
def setUpClass(cls):
1168-
if _IS_SYNC:
1169-
cls._setup_class()
1170-
else:
1171-
asyncio.run(cls._setup_class())
1172-
1173-
@classmethod
1174-
def tearDownClass(cls):
1175-
if _IS_SYNC:
1176-
cls._tearDown_class()
1177-
else:
1178-
asyncio.run(cls._tearDown_class())
1179-
1180-
@classmethod
11811150
@async_client_context.require_connection
1182-
async def _setup_class(cls):
1183-
if async_client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False):
1151+
async def asyncSetUp(self) -> None:
1152+
if not _IS_SYNC:
1153+
await reset_client_context()
1154+
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
11841155
raise SkipTest("this test does not support load balancers")
1185-
if async_client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False):
1156+
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
11861157
raise SkipTest("this test does not support serverless")
1187-
cls.client = async_client_context.client
1188-
cls.db = cls.client.pymongo_test
1158+
self.client = async_client_context.client
1159+
self.db = self.client.pymongo_test
11891160
if async_client_context.auth_enabled:
1190-
cls.credentials = {"username": db_user, "password": db_pwd}
1161+
self.credentials = {"username": db_user, "password": db_pwd}
11911162
else:
1192-
cls.credentials = {}
1193-
1194-
@classmethod
1195-
async def _tearDown_class(cls):
1196-
pass
1163+
self.credentials = {}
11971164

11981165
async def cleanup_colls(self, *collections):
11991166
"""Cleanup collections faster than drop_collection."""
@@ -1219,39 +1186,16 @@ class AsyncMockClientTest(AsyncUnitTest):
12191186
# MockClients tests that use replicaSet, directConnection=True, pass
12201187
# multiple seed addresses, or wait for heartbeat events are incompatible
12211188
# with loadBalanced=True.
1222-
@classmethod
1223-
def setUpClass(cls):
1224-
if _IS_SYNC:
1225-
cls._setup_class()
1226-
else:
1227-
asyncio.run(cls._setup_class())
1228-
1229-
@classmethod
1230-
def tearDownClass(cls):
1231-
if _IS_SYNC:
1232-
cls._tearDown_class()
1233-
else:
1234-
asyncio.run(cls._tearDown_class())
1235-
1236-
@classmethod
12371189
@async_client_context.require_no_load_balancer
1238-
async def _setup_class(cls):
1239-
pass
1240-
1241-
@classmethod
1242-
async def _tearDown_class(cls):
1243-
pass
1244-
1245-
def setUp(self):
1246-
super().setUp()
1190+
async def asyncSetUp(self) -> None:
1191+
await super().asyncSetUp()
12471192

12481193
self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001)
1249-
12501194
self.client_knobs.enable()
12511195

1252-
def tearDown(self):
1196+
async def asyncTearDown(self) -> None:
12531197
self.client_knobs.disable()
1254-
super().tearDown()
1198+
await super().asyncTearDown()
12551199

12561200

12571201
async def async_setup():

test/asynchronous/test_bulk.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,11 @@ class AsyncBulkTestBase(AsyncIntegrationTest):
4242
coll: AsyncCollection
4343
coll_w0: AsyncCollection
4444

45-
@classmethod
46-
async def _setup_class(cls):
47-
await super()._setup_class()
48-
cls.coll = cls.db.test
49-
cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0))
50-
5145
async def asyncSetUp(self):
52-
super().setUp()
46+
await super().asyncSetUp()
47+
self.coll = self.db.test
5348
await self.coll.drop()
49+
self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0))
5450

5551
def assertEqualResponse(self, expected, actual):
5652
"""Compare response from bulk.execute() to expected response."""
@@ -787,14 +783,10 @@ async def test_large_inserts_unordered(self):
787783

788784

789785
class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase):
790-
@classmethod
791786
@async_client_context.require_auth
792787
@async_client_context.require_no_api_version
793-
async def _setup_class(cls):
794-
await super()._setup_class()
795-
796788
async def asyncSetUp(self):
797-
super().setUp()
789+
await super().asyncSetUp()
798790
await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"])
799791
await self.db.command(
800792
"createRole",
@@ -937,21 +929,19 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase):
937929
w: Optional[int]
938930
secondary: AsyncMongoClient
939931

940-
@classmethod
941-
async def _setup_class(cls):
942-
await super()._setup_class()
943-
cls.w = async_client_context.w
944-
cls.secondary = None
945-
if cls.w is not None and cls.w > 1:
932+
async def asyncSetUp(self):
933+
await super().asyncSetUp()
934+
self.w = async_client_context.w
935+
self.secondary = None
936+
if self.w is not None and self.w > 1:
946937
for member in (await async_client_context.hello)["hosts"]:
947938
if member != (await async_client_context.hello)["primary"]:
948-
cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member))
939+
self.secondary = await self.async_single_client(*partition_node(member))
949940
break
950941

951-
@classmethod
952-
async def async_tearDownClass(cls):
953-
if cls.secondary:
954-
await cls.secondary.close()
942+
async def asyncTearDown(self):
943+
if self.secondary:
944+
await self.secondary.close()
955945

956946
async def cause_wtimeout(self, requests, ordered):
957947
if not async_client_context.test_commands_enabled:

test/asynchronous/test_change_stream.py

Lines changed: 14 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -835,18 +835,16 @@ async def test_split_large_change(self):
835835
class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
836836
dbs: list
837837

838-
@classmethod
839838
@async_client_context.require_version_min(4, 0, 0, -1)
840839
@async_client_context.require_change_streams
841-
async def _setup_class(cls):
842-
await super()._setup_class()
843-
cls.dbs = [cls.db, cls.client.pymongo_test_2]
840+
async def asyncSetUp(self) -> None:
841+
await super().asyncSetUp()
842+
self.dbs = [self.db, self.client.pymongo_test_2]
844843

845-
@classmethod
846-
async def _tearDown_class(cls):
847-
for db in cls.dbs:
848-
await cls.client.drop_database(db)
849-
await super()._tearDown_class()
844+
async def asyncTearDown(self):
845+
for db in self.dbs:
846+
await self.client.drop_database(db)
847+
await super().asyncTearDown()
850848

851849
async def change_stream_with_client(self, client, *args, **kwargs):
852850
return await client.watch(*args, **kwargs)
@@ -897,11 +895,10 @@ async def test_full_pipeline(self):
897895

898896

899897
class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin):
900-
@classmethod
901898
@async_client_context.require_version_min(4, 0, 0, -1)
902899
@async_client_context.require_change_streams
903-
async def _setup_class(cls):
904-
await super()._setup_class()
900+
async def asyncSetUp(self) -> None:
901+
await super().asyncSetUp()
905902

906903
async def change_stream_with_client(self, client, *args, **kwargs):
907904
return await client[self.db.name].watch(*args, **kwargs)
@@ -987,12 +984,9 @@ async def test_isolation(self):
987984
class TestAsyncCollectionAsyncChangeStream(
988985
TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin
989986
):
990-
@classmethod
991987
@async_client_context.require_change_streams
992-
async def _setup_class(cls):
993-
await super()._setup_class()
994-
995988
async def asyncSetUp(self):
989+
await super().asyncSetUp()
996990
# Use a new collection for each test.
997991
await self.watched_collection().drop()
998992
await self.watched_collection().insert_one({})
@@ -1132,20 +1126,11 @@ class TestAllLegacyScenarios(AsyncIntegrationTest):
11321126
RUN_ON_LOAD_BALANCER = True
11331127
listener: AllowListEventListener
11341128

1135-
@classmethod
11361129
@async_client_context.require_connection
1137-
async def _setup_class(cls):
1138-
await super()._setup_class()
1139-
cls.listener = AllowListEventListener("aggregate", "getMore")
1140-
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
1141-
1142-
@classmethod
1143-
async def _tearDown_class(cls):
1144-
await cls.client.close()
1145-
await super()._tearDown_class()
1146-
1147-
def asyncSetUp(self):
1148-
super().asyncSetUp()
1130+
async def asyncSetUp(self):
1131+
await super().asyncSetUp()
1132+
self.listener = AllowListEventListener("aggregate", "getMore")
1133+
self.client = await self.async_rs_or_single_client(event_listeners=[self.listener])
11491134
self.listener.reset()
11501135

11511136
async def asyncSetUpCluster(self, scenario_dict):

test/asynchronous/test_client.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,11 @@ class AsyncClientUnitTest(AsyncUnitTest):
130130

131131
client: AsyncMongoClient
132132

133-
@classmethod
134-
async def _setup_class(cls):
135-
cls.client = await cls.unmanaged_async_rs_or_single_client(
133+
async def asyncSetUp(self) -> None:
134+
self.client = await self.async_rs_or_single_client(
136135
connect=False, serverSelectionTimeoutMS=100
137136
)
138137

139-
@classmethod
140-
async def _tearDown_class(cls):
141-
await cls.client.close()
142-
143138
@pytest.fixture(autouse=True)
144139
def inject_fixtures(self, caplog):
145140
self._caplog = caplog

0 commit comments

Comments
 (0)