17
17
18
18
import asyncio
19
19
import base64
20
+ import contextlib
20
21
import gc
21
22
import multiprocessing
22
23
import os
39
40
TEST_SERVERLESS ,
40
41
TLS_OPTIONS ,
41
42
SystemCertsPatcher ,
42
- _all_users ,
43
- _create_user ,
44
43
client_knobs ,
45
44
db_pwd ,
46
45
db_user ,
62
61
except ImportError :
63
62
HAVE_IPADDRESS = False
64
63
from contextlib import contextmanager
65
- from functools import wraps
64
+ from functools import partial , wraps
66
65
from test .version import Version
67
- from typing import Any , Callable , Dict , Generator
66
+ from typing import Any , Callable , Dict , Generator , overload
68
67
from unittest import SkipTest
69
68
from urllib .parse import quote_plus
70
69
@@ -812,6 +811,12 @@ def require_no_api_version(self, func):
812
811
func = func ,
813
812
)
814
813
814
+ def require_sync (self , func ):
815
+ """Run a test only if using the synchronous API."""
816
+ return self ._require (
817
+ lambda : _IS_SYNC , "This test only works with the synchronous API" , func = func
818
+ )
819
+
815
820
def mongos_seeds (self ):
816
821
return "," .join ("{}:{}" .format (* address ) for address in self .mongoses )
817
822
@@ -919,6 +924,32 @@ def _target() -> None:
919
924
self .assertEqual (proc .exitcode , 0 )
920
925
921
926
927
+ class UnitTest (PyMongoTestCase ):
928
+ """Async base class for TestCases that don't require a connection to MongoDB."""
929
+
930
+ @classmethod
931
+ def setUpClass (cls ):
932
+ if _IS_SYNC :
933
+ cls ._setup_class ()
934
+ else :
935
+ asyncio .run (cls ._setup_class ())
936
+
937
+ @classmethod
938
+ def tearDownClass (cls ):
939
+ if _IS_SYNC :
940
+ cls ._tearDown_class ()
941
+ else :
942
+ asyncio .run (cls ._tearDown_class ())
943
+
944
+ @classmethod
945
+ def _setup_class (cls ):
946
+ cls ._setup_class ()
947
+
948
+ @classmethod
949
+ def _tearDown_class (cls ):
950
+ cls ._tearDown_class ()
951
+
952
+
922
953
class IntegrationTest (PyMongoTestCase ):
923
954
"""Async base class for TestCases that need a connection to MongoDB to pass."""
924
955
@@ -933,6 +964,13 @@ def setUpClass(cls):
933
964
else :
934
965
asyncio .run (cls ._setup_class ())
935
966
967
+ @classmethod
968
+ def tearDownClass (cls ):
969
+ if _IS_SYNC :
970
+ cls ._tearDown_class ()
971
+ else :
972
+ asyncio .run (cls ._tearDown_class ())
973
+
936
974
@classmethod
937
975
@client_context .require_connection
938
976
def _setup_class (cls ):
@@ -947,6 +985,10 @@ def _setup_class(cls):
947
985
else :
948
986
cls .credentials = {}
949
987
988
+ @classmethod
989
+ def _tearDown_class (cls ):
990
+ pass
991
+
950
992
def cleanup_colls (self , * collections ):
951
993
"""Cleanup collections faster than drop_collection."""
952
994
for c in collections :
@@ -959,7 +1001,7 @@ def patch_system_certs(self, ca_certs):
959
1001
self .addCleanup (patcher .disable )
960
1002
961
1003
962
- class MockClientTest (unittest . TestCase ):
1004
+ class MockClientTest (UnitTest ):
963
1005
"""Base class for TestCases that use MockClient.
964
1006
965
1007
This class is *not* an IntegrationTest: if properly written, MockClient
@@ -972,8 +1014,26 @@ class MockClientTest(unittest.TestCase):
972
1014
# multiple seed addresses, or wait for heartbeat events are incompatible
973
1015
# with loadBalanced=True.
974
1016
@classmethod
975
- @client_context .require_no_load_balancer
976
1017
def setUpClass (cls ):
1018
+ if _IS_SYNC :
1019
+ cls ._setup_class ()
1020
+ else :
1021
+ asyncio .run (cls ._setup_class ())
1022
+
1023
+ @classmethod
1024
+ def tearDownClass (cls ):
1025
+ if _IS_SYNC :
1026
+ cls ._tearDown_class ()
1027
+ else :
1028
+ asyncio .run (cls ._tearDown_class ())
1029
+
1030
+ @classmethod
1031
+ @client_context .require_no_load_balancer
1032
+ def _setup_class (cls ):
1033
+ pass
1034
+
1035
+ @classmethod
1036
+ def _tearDown_class (cls ):
977
1037
pass
978
1038
979
1039
def setUp (self ):
@@ -1051,3 +1111,38 @@ def print_running_clients():
1051
1111
processed .add (obj ._topology_id )
1052
1112
except ReferenceError :
1053
1113
pass
1114
+
1115
+
1116
+ def _all_users (db ):
1117
+ return {u ["user" ] for u in (db .command ("usersInfo" )).get ("users" , [])}
1118
+
1119
+
1120
+ def _create_user (authdb , user , pwd = None , roles = None , ** kwargs ):
1121
+ cmd = SON ([("createUser" , user )])
1122
+ # X509 doesn't use a password
1123
+ if pwd :
1124
+ cmd ["pwd" ] = pwd
1125
+ cmd ["roles" ] = roles or ["root" ]
1126
+ cmd .update (** kwargs )
1127
+ return authdb .command (cmd )
1128
+
1129
+
1130
+ def connected (client ):
1131
+ """Convenience to wait for a newly-constructed client to connect."""
1132
+ with warnings .catch_warnings ():
1133
+ # Ignore warning that ping is always routed to primary even
1134
+ # if client's read preference isn't PRIMARY.
1135
+ warnings .simplefilter ("ignore" , UserWarning )
1136
+ client .admin .command ("ping" ) # Force connection.
1137
+
1138
+ return client
1139
+
1140
+
1141
+ def drop_collections (db : Database ):
1142
+ # Drop all non-system collections in this database.
1143
+ for coll in db .list_collection_names (filter = {"name" : {"$regex" : r"^(?!system\.)" }}):
1144
+ db .drop_collection (coll )
1145
+
1146
+
1147
+ def remove_all_users (db : Database ):
1148
+ db .command ("dropAllUsersFromDatabase" , 1 , writeConcern = {"w" : client_context .w })
0 commit comments