12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- """Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
15
+ """Shared helper methods for pymongo, bson, and gridfs test suites."""
16
16
from __future__ import annotations
17
17
18
18
import asyncio
19
- import base64
20
- import gc
21
- import multiprocessing
22
- import os
23
- import signal
24
- import socket
25
- import subprocess
26
- import sys
27
19
import threading
28
- import time
29
- import traceback
30
- import unittest
31
- import warnings
32
- from inspect import iscoroutinefunction
33
- from pathlib import Path
20
+ from typing import Optional
34
21
22
+ from bson import SON
35
23
from pymongo ._asyncio_task import create_task
36
-
37
- try :
38
- import ipaddress
39
-
40
- HAVE_IPADDRESS = True
41
- except ImportError :
42
- HAVE_IPADDRESS = False
43
- from functools import wraps
44
- from typing import Any , Callable , Dict , Generator , Optional , no_type_check
45
- from unittest import SkipTest
46
-
47
- from bson .son import SON
48
- from pymongo import common , message
49
24
from pymongo .read_preferences import ReadPreference
50
- from pymongo .ssl_support import HAVE_SSL , _ssl # type:ignore[attr-defined]
51
- from pymongo .synchronous .uri_parser import parse_uri
52
-
53
- if HAVE_SSL :
54
- import ssl
55
25
56
26
_IS_SYNC = False
57
27
58
- # Enable debug output for uncollectable objects. PyPy does not have set_debug.
59
- if hasattr (gc , "set_debug" ):
60
- gc .set_debug (
61
- gc .DEBUG_UNCOLLECTABLE | getattr (gc , "DEBUG_OBJECTS" , 0 ) | getattr (gc , "DEBUG_INSTANCES" , 0 )
62
- )
63
-
64
- # The host and port of a single mongod or mongos, or the seed host
65
- # for a replica set.
66
- host = os .environ .get ("DB_IP" , "localhost" )
67
- port = int (os .environ .get ("DB_PORT" , 27017 ))
68
- IS_SRV = "mongodb+srv" in host
69
-
70
- db_user = os .environ .get ("DB_USER" , "user" )
71
- db_pwd = os .environ .get ("DB_PASSWORD" , "password" )
72
-
73
- HERE = Path (__file__ ).absolute ()
74
- if _IS_SYNC :
75
- CERT_PATH = str (HERE .parent / "certificates" )
76
- else :
77
- CERT_PATH = str (HERE .parent .parent / "certificates" )
78
- CLIENT_PEM = os .environ .get ("CLIENT_PEM" , os .path .join (CERT_PATH , "client.pem" ))
79
- CA_PEM = os .environ .get ("CA_PEM" , os .path .join (CERT_PATH , "ca.pem" ))
80
-
81
- TLS_OPTIONS : Dict = {"tls" : True }
82
- if CLIENT_PEM :
83
- TLS_OPTIONS ["tlsCertificateKeyFile" ] = CLIENT_PEM
84
- if CA_PEM :
85
- TLS_OPTIONS ["tlsCAFile" ] = CA_PEM
86
-
87
- COMPRESSORS = os .environ .get ("COMPRESSORS" )
88
- MONGODB_API_VERSION = os .environ .get ("MONGODB_API_VERSION" )
89
- TEST_LOADBALANCER = bool (os .environ .get ("TEST_LOAD_BALANCER" ))
90
- SINGLE_MONGOS_LB_URI = os .environ .get ("SINGLE_MONGOS_LB_URI" )
91
- MULTI_MONGOS_LB_URI = os .environ .get ("MULTI_MONGOS_LB_URI" )
92
-
93
- if TEST_LOADBALANCER :
94
- res = parse_uri (SINGLE_MONGOS_LB_URI or "" )
95
- host , port = res ["nodelist" ][0 ]
96
- db_user = res ["username" ] or db_user
97
- db_pwd = res ["password" ] or db_pwd
98
-
99
-
100
- # Shared KMS data.
101
- LOCAL_MASTER_KEY = base64 .b64decode (
102
- b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ"
103
- b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
104
- )
105
- AWS_CREDS = {
106
- "accessKeyId" : os .environ .get ("FLE_AWS_KEY" , "" ),
107
- "secretAccessKey" : os .environ .get ("FLE_AWS_SECRET" , "" ),
108
- }
109
- AWS_CREDS_2 = {
110
- "accessKeyId" : os .environ .get ("FLE_AWS_KEY2" , "" ),
111
- "secretAccessKey" : os .environ .get ("FLE_AWS_SECRET2" , "" ),
112
- }
113
- AZURE_CREDS = {
114
- "tenantId" : os .environ .get ("FLE_AZURE_TENANTID" , "" ),
115
- "clientId" : os .environ .get ("FLE_AZURE_CLIENTID" , "" ),
116
- "clientSecret" : os .environ .get ("FLE_AZURE_CLIENTSECRET" , "" ),
117
- }
118
- GCP_CREDS = {
119
- "email" : os .environ .get ("FLE_GCP_EMAIL" , "" ),
120
- "privateKey" : os .environ .get ("FLE_GCP_PRIVATEKEY" , "" ),
121
- }
122
- KMIP_CREDS = {"endpoint" : os .environ .get ("FLE_KMIP_ENDPOINT" , "localhost:5698" )}
123
- AWS_TEMP_CREDS = {
124
- "accessKeyId" : os .environ .get ("CSFLE_AWS_TEMP_ACCESS_KEY_ID" , "" ),
125
- "secretAccessKey" : os .environ .get ("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY" , "" ),
126
- "sessionToken" : os .environ .get ("CSFLE_AWS_TEMP_SESSION_TOKEN" , "" ),
127
- }
128
-
129
- ALL_KMS_PROVIDERS = dict (
130
- aws = AWS_CREDS ,
131
- azure = AZURE_CREDS ,
132
- gcp = GCP_CREDS ,
133
- local = dict (key = LOCAL_MASTER_KEY ),
134
- kmip = KMIP_CREDS ,
135
- )
136
- DEFAULT_KMS_TLS = dict (kmip = dict (tlsCAFile = CA_PEM , tlsCertificateKeyFile = CLIENT_PEM ))
137
-
138
- # Ensure Evergreen metadata doesn't result in truncation
139
- os .environ .setdefault ("MONGOB_LOG_MAX_DOCUMENT_LENGTH" , "2000" )
140
-
141
-
142
- def is_server_resolvable ():
143
- """Returns True if 'server' is resolvable."""
144
- socket_timeout = socket .getdefaulttimeout ()
145
- socket .setdefaulttimeout (1 )
146
- try :
147
- try :
148
- socket .gethostbyname ("server" )
149
- return True
150
- except OSError :
151
- return False
152
- finally :
153
- socket .setdefaulttimeout (socket_timeout )
154
-
155
-
156
- def _create_user (authdb , user , pwd = None , roles = None , ** kwargs ):
157
- cmd = SON ([("createUser" , user )])
158
- # X509 doesn't use a password
159
- if pwd :
160
- cmd ["pwd" ] = pwd
161
- cmd ["roles" ] = roles or ["root" ]
162
- cmd .update (** kwargs )
163
- return authdb .command (cmd )
164
-
165
28
166
29
async def async_repl_set_step_down (client , ** kwargs ):
167
30
"""Run replSetStepDown, first unfreezing a secondary with replSetFreeze."""
@@ -173,216 +36,6 @@ async def async_repl_set_step_down(client, **kwargs):
173
36
await client .admin .command (cmd )
174
37
175
38
176
- class client_knobs :
177
- def __init__ (
178
- self ,
179
- heartbeat_frequency = None ,
180
- min_heartbeat_interval = None ,
181
- kill_cursor_frequency = None ,
182
- events_queue_frequency = None ,
183
- ):
184
- self .heartbeat_frequency = heartbeat_frequency
185
- self .min_heartbeat_interval = min_heartbeat_interval
186
- self .kill_cursor_frequency = kill_cursor_frequency
187
- self .events_queue_frequency = events_queue_frequency
188
-
189
- self .old_heartbeat_frequency = None
190
- self .old_min_heartbeat_interval = None
191
- self .old_kill_cursor_frequency = None
192
- self .old_events_queue_frequency = None
193
- self ._enabled = False
194
- self ._stack = None
195
-
196
- def enable (self ):
197
- self .old_heartbeat_frequency = common .HEARTBEAT_FREQUENCY
198
- self .old_min_heartbeat_interval = common .MIN_HEARTBEAT_INTERVAL
199
- self .old_kill_cursor_frequency = common .KILL_CURSOR_FREQUENCY
200
- self .old_events_queue_frequency = common .EVENTS_QUEUE_FREQUENCY
201
-
202
- if self .heartbeat_frequency is not None :
203
- common .HEARTBEAT_FREQUENCY = self .heartbeat_frequency
204
-
205
- if self .min_heartbeat_interval is not None :
206
- common .MIN_HEARTBEAT_INTERVAL = self .min_heartbeat_interval
207
-
208
- if self .kill_cursor_frequency is not None :
209
- common .KILL_CURSOR_FREQUENCY = self .kill_cursor_frequency
210
-
211
- if self .events_queue_frequency is not None :
212
- common .EVENTS_QUEUE_FREQUENCY = self .events_queue_frequency
213
- self ._enabled = True
214
- # Store the allocation traceback to catch non-disabled client_knobs.
215
- self ._stack = "" .join (traceback .format_stack ())
216
-
217
- def __enter__ (self ):
218
- self .enable ()
219
-
220
- @no_type_check
221
- def disable (self ):
222
- common .HEARTBEAT_FREQUENCY = self .old_heartbeat_frequency
223
- common .MIN_HEARTBEAT_INTERVAL = self .old_min_heartbeat_interval
224
- common .KILL_CURSOR_FREQUENCY = self .old_kill_cursor_frequency
225
- common .EVENTS_QUEUE_FREQUENCY = self .old_events_queue_frequency
226
- self ._enabled = False
227
-
228
- def __exit__ (self , exc_type , exc_val , exc_tb ):
229
- self .disable ()
230
-
231
- def __call__ (self , func ):
232
- def make_wrapper (f ):
233
- @wraps (f )
234
- async def wrap (* args , ** kwargs ):
235
- with self :
236
- return await f (* args , ** kwargs )
237
-
238
- return wrap
239
-
240
- return make_wrapper (func )
241
-
242
- def __del__ (self ):
243
- if self ._enabled :
244
- msg = (
245
- "ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, "
246
- "MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, "
247
- "EVENTS_QUEUE_FREQUENCY={}, stack:\n {}" .format (
248
- common .HEARTBEAT_FREQUENCY ,
249
- common .MIN_HEARTBEAT_INTERVAL ,
250
- common .KILL_CURSOR_FREQUENCY ,
251
- common .EVENTS_QUEUE_FREQUENCY ,
252
- self ._stack ,
253
- )
254
- )
255
- self .disable ()
256
- raise Exception (msg )
257
-
258
-
259
- def _all_users (db ):
260
- return {u ["user" ] for u in db .command ("usersInfo" ).get ("users" , [])}
261
-
262
-
263
- def sanitize_cmd (cmd ):
264
- cp = cmd .copy ()
265
- cp .pop ("$clusterTime" , None )
266
- cp .pop ("$db" , None )
267
- cp .pop ("$readPreference" , None )
268
- cp .pop ("lsid" , None )
269
- if MONGODB_API_VERSION :
270
- # Stable API parameters
271
- cp .pop ("apiVersion" , None )
272
- # OP_MSG encoding may move the payload type one field to the
273
- # end of the command. Do the same here.
274
- name = next (iter (cp ))
275
- try :
276
- identifier = message ._FIELD_MAP [name ]
277
- docs = cp .pop (identifier )
278
- cp [identifier ] = docs
279
- except KeyError :
280
- pass
281
- return cp
282
-
283
-
284
- def sanitize_reply (reply ):
285
- cp = reply .copy ()
286
- cp .pop ("$clusterTime" , None )
287
- cp .pop ("operationTime" , None )
288
- return cp
289
-
290
-
291
- def print_thread_tracebacks () -> None :
292
- """Print all Python thread tracebacks."""
293
- for thread_id , frame in sys ._current_frames ().items ():
294
- sys .stderr .write (f"\n --- Traceback for thread { thread_id } ---\n " )
295
- traceback .print_stack (frame , file = sys .stderr )
296
-
297
-
298
- def print_thread_stacks (pid : int ) -> None :
299
- """Print all C-level thread stacks for a given process id."""
300
- if sys .platform == "darwin" :
301
- cmd = ["lldb" , "--attach-pid" , f"{ pid } " , "--batch" , "--one-line" , '"thread backtrace all"' ]
302
- else :
303
- cmd = ["gdb" , f"--pid={ pid } " , "--batch" , '--eval-command="thread apply all bt"' ]
304
-
305
- try :
306
- res = subprocess .run (
307
- cmd , stdout = subprocess .PIPE , stderr = subprocess .STDOUT , encoding = "utf-8"
308
- )
309
- except Exception as exc :
310
- sys .stderr .write (f"Could not print C-level thread stacks because { cmd [0 ]} failed: { exc } " )
311
- else :
312
- sys .stderr .write (res .stdout )
313
-
314
-
315
- # Global knobs to speed up the test suite.
316
- global_knobs = client_knobs (events_queue_frequency = 0.05 )
317
-
318
-
319
- def _get_executors (topology ):
320
- executors = []
321
- for server in topology ._servers .values ():
322
- # Some MockMonitor do not have an _executor.
323
- if hasattr (server ._monitor , "_executor" ):
324
- executors .append (server ._monitor ._executor )
325
- if hasattr (server ._monitor , "_rtt_monitor" ):
326
- executors .append (server ._monitor ._rtt_monitor ._executor )
327
- executors .append (topology ._Topology__events_executor )
328
- if topology ._srv_monitor :
329
- executors .append (topology ._srv_monitor ._executor )
330
-
331
- return [e for e in executors if e is not None ]
332
-
333
-
334
- def print_running_topology (topology ):
335
- running = [e for e in _get_executors (topology ) if not e ._stopped ]
336
- if running :
337
- print (
338
- "WARNING: found Topology with running threads:\n "
339
- f" Threads: { running } \n "
340
- f" Topology: { topology } \n "
341
- f" Creation traceback:\n { topology ._settings ._stack } "
342
- )
343
-
344
-
345
- def test_cases (suite ):
346
- """Iterator over all TestCases within a TestSuite."""
347
- for suite_or_case in suite ._tests :
348
- if isinstance (suite_or_case , unittest .TestCase ):
349
- # unittest.TestCase
350
- yield suite_or_case
351
- else :
352
- # unittest.TestSuite
353
- yield from test_cases (suite_or_case )
354
-
355
-
356
- # Helper method to workaround https://bugs.python.org/issue21724
357
- def clear_warning_registry ():
358
- """Clear the __warningregistry__ for all modules."""
359
- for _ , module in list (sys .modules .items ()):
360
- if hasattr (module , "__warningregistry__" ):
361
- module .__warningregistry__ = {} # type:ignore[attr-defined]
362
-
363
-
364
- class SystemCertsPatcher :
365
- def __init__ (self , ca_certs ):
366
- if (
367
- ssl .OPENSSL_VERSION .lower ().startswith ("libressl" )
368
- and sys .platform == "darwin"
369
- and not _ssl .IS_PYOPENSSL
370
- ):
371
- raise SkipTest (
372
- "LibreSSL on OSX doesn't support setting CA certificates "
373
- "using SSL_CERT_FILE environment variable."
374
- )
375
- self .original_certs = os .environ .get ("SSL_CERT_FILE" )
376
- # Tell OpenSSL where CA certificates live.
377
- os .environ ["SSL_CERT_FILE" ] = ca_certs
378
-
379
- def disable (self ):
380
- if self .original_certs is None :
381
- os .environ .pop ("SSL_CERT_FILE" )
382
- else :
383
- os .environ ["SSL_CERT_FILE" ] = self .original_certs
384
-
385
-
386
39
if _IS_SYNC :
387
40
PARENT = threading .Thread
388
41
else :
0 commit comments