Skip to content

Commit d0193eb

Browse files
authored
PYTHON-4533 - Convert test/test_client.py to async (mongodb#1730)
1 parent 554ce7d commit d0193eb

21 files changed

+3216
-171
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ async def close(self) -> None:
299299
self.client_ref = None
300300
self.key_vault_coll = None
301301
if self.mongocryptd_client:
302-
await self.mongocryptd_client.close()
302+
await self.mongocryptd_client.aclose()
303303
self.mongocryptd_client = None
304304

305305

@@ -439,7 +439,7 @@ async def close(self) -> None:
439439
self._closed = True
440440
await self._auto_encrypter.close()
441441
if self._internal_client:
442-
await self._internal_client.close()
442+
await self._internal_client.aclose()
443443
self._internal_client = None
444444

445445

pymongo/asynchronous/mongo_client.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,10 @@ def __init__(
861861
# This will be used later if we fork.
862862
AsyncMongoClient._clients[self._topology._topology_id] = self
863863

864+
async def aconnect(self) -> None:
865+
"""Explicitly connect to MongoDB asynchronously instead of on the first operation."""
866+
await self._get_topology()
867+
864868
def _init_background(self, old_pid: Optional[int] = None) -> None:
865869
self._topology = Topology(self._topology_settings)
866870
# Seed the topology with the old one's pid so we can detect clients
@@ -1354,13 +1358,13 @@ async def __aenter__(self) -> AsyncMongoClient[_DocumentType]:
13541358
return self
13551359

13561360
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
1357-
await self.close()
1361+
await self.aclose()
13581362

13591363
# See PYTHON-3084.
13601364
__iter__ = None
13611365

13621366
def __next__(self) -> NoReturn:
1363-
raise TypeError("'MongoClient' object is not iterable")
1367+
raise TypeError("'AsyncMongoClient' object is not iterable")
13641368

13651369
next = __next__
13661370

@@ -1490,7 +1494,7 @@ async def _end_sessions(self, session_ids: list[_ServerSession]) -> None:
14901494
# command.
14911495
pass
14921496

1493-
async def close(self) -> None:
1497+
async def aclose(self) -> None:
14941498
"""Cleanup client resources and disconnect from MongoDB.
14951499
14961500
End all server sessions created by this client by sending one or more

pymongo/synchronous/mongo_client.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,10 @@ def __init__(
860860
# This will be used later if we fork.
861861
MongoClient._clients[self._topology._topology_id] = self
862862

863+
def _connect(self) -> None:
864+
"""Explicitly connect to MongoDB synchronously instead of on the first operation."""
865+
self._get_topology()
866+
863867
def _init_background(self, old_pid: Optional[int] = None) -> None:
864868
self._topology = Topology(self._topology_settings)
865869
# Seed the topology with the old one's pid so we can detect clients

test/__init__.py

Lines changed: 101 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import asyncio
1919
import base64
20+
import contextlib
2021
import gc
2122
import multiprocessing
2223
import os
@@ -39,8 +40,6 @@
3940
TEST_SERVERLESS,
4041
TLS_OPTIONS,
4142
SystemCertsPatcher,
42-
_all_users,
43-
_create_user,
4443
client_knobs,
4544
db_pwd,
4645
db_user,
@@ -62,9 +61,9 @@
6261
except ImportError:
6362
HAVE_IPADDRESS = False
6463
from contextlib import contextmanager
65-
from functools import wraps
64+
from functools import partial, wraps
6665
from test.version import Version
67-
from typing import Any, Callable, Dict, Generator
66+
from typing import Any, Callable, Dict, Generator, overload
6867
from unittest import SkipTest
6968
from urllib.parse import quote_plus
7069

@@ -812,6 +811,12 @@ def require_no_api_version(self, func):
812811
func=func,
813812
)
814813

814+
def require_sync(self, func):
815+
"""Run a test only if using the synchronous API."""
816+
return self._require(
817+
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
818+
)
819+
815820
def mongos_seeds(self):
816821
return ",".join("{}:{}".format(*address) for address in self.mongoses)
817822

@@ -919,6 +924,32 @@ def _target() -> None:
919924
self.assertEqual(proc.exitcode, 0)
920925

921926

927+
class UnitTest(PyMongoTestCase):
928+
"""Async base class for TestCases that don't require a connection to MongoDB."""
929+
930+
@classmethod
931+
def setUpClass(cls):
932+
if _IS_SYNC:
933+
cls._setup_class()
934+
else:
935+
asyncio.run(cls._setup_class())
936+
937+
@classmethod
938+
def tearDownClass(cls):
939+
if _IS_SYNC:
940+
cls._tearDown_class()
941+
else:
942+
asyncio.run(cls._tearDown_class())
943+
944+
@classmethod
945+
def _setup_class(cls):
946+
cls._setup_class()
947+
948+
@classmethod
949+
def _tearDown_class(cls):
950+
cls._tearDown_class()
951+
952+
922953
class IntegrationTest(PyMongoTestCase):
923954
"""Async base class for TestCases that need a connection to MongoDB to pass."""
924955

@@ -933,6 +964,13 @@ def setUpClass(cls):
933964
else:
934965
asyncio.run(cls._setup_class())
935966

967+
@classmethod
968+
def tearDownClass(cls):
969+
if _IS_SYNC:
970+
cls._tearDown_class()
971+
else:
972+
asyncio.run(cls._tearDown_class())
973+
936974
@classmethod
937975
@client_context.require_connection
938976
def _setup_class(cls):
@@ -947,6 +985,10 @@ def _setup_class(cls):
947985
else:
948986
cls.credentials = {}
949987

988+
@classmethod
989+
def _tearDown_class(cls):
990+
pass
991+
950992
def cleanup_colls(self, *collections):
951993
"""Cleanup collections faster than drop_collection."""
952994
for c in collections:
@@ -959,7 +1001,7 @@ def patch_system_certs(self, ca_certs):
9591001
self.addCleanup(patcher.disable)
9601002

9611003

962-
class MockClientTest(unittest.TestCase):
1004+
class MockClientTest(UnitTest):
9631005
"""Base class for TestCases that use MockClient.
9641006
9651007
This class is *not* an IntegrationTest: if properly written, MockClient
@@ -972,8 +1014,26 @@ class MockClientTest(unittest.TestCase):
9721014
# multiple seed addresses, or wait for heartbeat events are incompatible
9731015
# with loadBalanced=True.
9741016
@classmethod
975-
@client_context.require_no_load_balancer
9761017
def setUpClass(cls):
1018+
if _IS_SYNC:
1019+
cls._setup_class()
1020+
else:
1021+
asyncio.run(cls._setup_class())
1022+
1023+
@classmethod
1024+
def tearDownClass(cls):
1025+
if _IS_SYNC:
1026+
cls._tearDown_class()
1027+
else:
1028+
asyncio.run(cls._tearDown_class())
1029+
1030+
@classmethod
1031+
@client_context.require_no_load_balancer
1032+
def _setup_class(cls):
1033+
pass
1034+
1035+
@classmethod
1036+
def _tearDown_class(cls):
9771037
pass
9781038

9791039
def setUp(self):
@@ -1051,3 +1111,38 @@ def print_running_clients():
10511111
processed.add(obj._topology_id)
10521112
except ReferenceError:
10531113
pass
1114+
1115+
1116+
def _all_users(db):
1117+
return {u["user"] for u in (db.command("usersInfo")).get("users", [])}
1118+
1119+
1120+
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
1121+
cmd = SON([("createUser", user)])
1122+
# X509 doesn't use a password
1123+
if pwd:
1124+
cmd["pwd"] = pwd
1125+
cmd["roles"] = roles or ["root"]
1126+
cmd.update(**kwargs)
1127+
return authdb.command(cmd)
1128+
1129+
1130+
def connected(client):
1131+
"""Convenience to wait for a newly-constructed client to connect."""
1132+
with warnings.catch_warnings():
1133+
# Ignore warning that ping is always routed to primary even
1134+
# if client's read preference isn't PRIMARY.
1135+
warnings.simplefilter("ignore", UserWarning)
1136+
client.admin.command("ping") # Force connection.
1137+
1138+
return client
1139+
1140+
1141+
def drop_collections(db: Database):
1142+
# Drop all non-system collections in this database.
1143+
for coll in db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}):
1144+
db.drop_collection(coll)
1145+
1146+
1147+
def remove_all_users(db: Database):
1148+
db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": client_context.w})

0 commit comments

Comments
 (0)