Skip to content

Commit 446cd96

Browse files
committed
add helper file
1 parent 13d5ee2 commit 446cd96

File tree

1 file changed

+358
-0
lines changed

1 file changed

+358
-0
lines changed

test/helpers_shared.py

Lines changed: 358 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
# Copyright 2019-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import base64
18+
import gc
19+
import os
20+
import socket
21+
import subprocess
22+
import sys
23+
import traceback
24+
import unittest
25+
from pathlib import Path
26+
27+
try:
28+
import ipaddress
29+
30+
HAVE_IPADDRESS = True
31+
except ImportError:
32+
HAVE_IPADDRESS = False
33+
from functools import wraps
34+
from typing import no_type_check
35+
from unittest import SkipTest
36+
37+
from bson.son import SON
38+
from pymongo import common, message
39+
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
40+
from pymongo.synchronous.uri_parser import parse_uri
41+
42+
if HAVE_SSL:
43+
import ssl
44+
45+
46+
# Enable debug output for uncollectable objects. PyPy does not have set_debug.
47+
if hasattr(gc, "set_debug"):
48+
gc.set_debug(
49+
gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0)
50+
)
51+
52+
# The host and port of a single mongod or mongos, or the seed host
53+
# for a replica set.
54+
host = os.environ.get("DB_IP", "localhost")
55+
port = int(os.environ.get("DB_PORT", 27017))
56+
IS_SRV = "mongodb+srv" in host
57+
58+
db_user = os.environ.get("DB_USER", "user")
59+
db_pwd = os.environ.get("DB_PASSWORD", "password")
60+
61+
HERE = Path(__file__).absolute()
62+
CERT_PATH = str(HERE.parent / "certificates")
63+
CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem"))
64+
CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem"))
65+
66+
TLS_OPTIONS: dict = {"tls": True}
67+
if CLIENT_PEM:
68+
TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM
69+
if CA_PEM:
70+
TLS_OPTIONS["tlsCAFile"] = CA_PEM
71+
72+
COMPRESSORS = os.environ.get("COMPRESSORS")
73+
MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION")
74+
TEST_LOADBALANCER = bool(os.environ.get("TEST_LOAD_BALANCER"))
75+
SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI")
76+
MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI")
77+
78+
if TEST_LOADBALANCER:
79+
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
80+
host, port = res["nodelist"][0]
81+
db_user = res["username"] or db_user
82+
db_pwd = res["password"] or db_pwd
83+
84+
85+
# Shared KMS data.
86+
LOCAL_MASTER_KEY = base64.b64decode(
87+
b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ"
88+
b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
89+
)
90+
AWS_CREDS = {
91+
"accessKeyId": os.environ.get("FLE_AWS_KEY", ""),
92+
"secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""),
93+
}
94+
AWS_CREDS_2 = {
95+
"accessKeyId": os.environ.get("FLE_AWS_KEY2", ""),
96+
"secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""),
97+
}
98+
AZURE_CREDS = {
99+
"tenantId": os.environ.get("FLE_AZURE_TENANTID", ""),
100+
"clientId": os.environ.get("FLE_AZURE_CLIENTID", ""),
101+
"clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""),
102+
}
103+
GCP_CREDS = {
104+
"email": os.environ.get("FLE_GCP_EMAIL", ""),
105+
"privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""),
106+
}
107+
KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")}
108+
AWS_TEMP_CREDS = {
109+
"accessKeyId": os.environ.get("CSFLE_AWS_TEMP_ACCESS_KEY_ID", ""),
110+
"secretAccessKey": os.environ.get("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY", ""),
111+
"sessionToken": os.environ.get("CSFLE_AWS_TEMP_SESSION_TOKEN", ""),
112+
}
113+
114+
ALL_KMS_PROVIDERS = dict(
115+
aws=AWS_CREDS,
116+
azure=AZURE_CREDS,
117+
gcp=GCP_CREDS,
118+
local=dict(key=LOCAL_MASTER_KEY),
119+
kmip=KMIP_CREDS,
120+
)
121+
DEFAULT_KMS_TLS = dict(kmip=dict(tlsCAFile=CA_PEM, tlsCertificateKeyFile=CLIENT_PEM))
122+
123+
# Ensure Evergreen metadata doesn't result in truncation
124+
os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000")
125+
126+
127+
def is_server_resolvable():
128+
"""Returns True if 'server' is resolvable."""
129+
socket_timeout = socket.getdefaulttimeout()
130+
socket.setdefaulttimeout(1)
131+
try:
132+
try:
133+
socket.gethostbyname("server")
134+
return True
135+
except OSError:
136+
return False
137+
finally:
138+
socket.setdefaulttimeout(socket_timeout)
139+
140+
141+
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
142+
cmd = SON([("createUser", user)])
143+
# X509 doesn't use a password
144+
if pwd:
145+
cmd["pwd"] = pwd
146+
cmd["roles"] = roles or ["root"]
147+
cmd.update(**kwargs)
148+
return authdb.command(cmd)
149+
150+
151+
class client_knobs:
152+
def __init__(
153+
self,
154+
heartbeat_frequency=None,
155+
min_heartbeat_interval=None,
156+
kill_cursor_frequency=None,
157+
events_queue_frequency=None,
158+
):
159+
self.heartbeat_frequency = heartbeat_frequency
160+
self.min_heartbeat_interval = min_heartbeat_interval
161+
self.kill_cursor_frequency = kill_cursor_frequency
162+
self.events_queue_frequency = events_queue_frequency
163+
164+
self.old_heartbeat_frequency = None
165+
self.old_min_heartbeat_interval = None
166+
self.old_kill_cursor_frequency = None
167+
self.old_events_queue_frequency = None
168+
self._enabled = False
169+
self._stack = None
170+
171+
def enable(self):
172+
self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
173+
self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL
174+
self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
175+
self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY
176+
177+
if self.heartbeat_frequency is not None:
178+
common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency
179+
180+
if self.min_heartbeat_interval is not None:
181+
common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval
182+
183+
if self.kill_cursor_frequency is not None:
184+
common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency
185+
186+
if self.events_queue_frequency is not None:
187+
common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency
188+
self._enabled = True
189+
# Store the allocation traceback to catch non-disabled client_knobs.
190+
self._stack = "".join(traceback.format_stack())
191+
192+
def __enter__(self):
193+
self.enable()
194+
195+
@no_type_check
196+
def disable(self):
197+
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
198+
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
199+
common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
200+
common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency
201+
self._enabled = False
202+
203+
def __exit__(self, exc_type, exc_val, exc_tb):
204+
self.disable()
205+
206+
def __call__(self, func):
207+
def make_wrapper(f):
208+
@wraps(f)
209+
async def wrap(*args, **kwargs):
210+
with self:
211+
return await f(*args, **kwargs)
212+
213+
return wrap
214+
215+
return make_wrapper(func)
216+
217+
def __del__(self):
218+
if self._enabled:
219+
msg = (
220+
"ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, "
221+
"MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, "
222+
"EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format(
223+
common.HEARTBEAT_FREQUENCY,
224+
common.MIN_HEARTBEAT_INTERVAL,
225+
common.KILL_CURSOR_FREQUENCY,
226+
common.EVENTS_QUEUE_FREQUENCY,
227+
self._stack,
228+
)
229+
)
230+
self.disable()
231+
raise Exception(msg)
232+
233+
234+
def _all_users(db):
235+
return {u["user"] for u in db.command("usersInfo").get("users", [])}
236+
237+
238+
def sanitize_cmd(cmd):
239+
cp = cmd.copy()
240+
cp.pop("$clusterTime", None)
241+
cp.pop("$db", None)
242+
cp.pop("$readPreference", None)
243+
cp.pop("lsid", None)
244+
if MONGODB_API_VERSION:
245+
# Stable API parameters
246+
cp.pop("apiVersion", None)
247+
# OP_MSG encoding may move the payload type one field to the
248+
# end of the command. Do the same here.
249+
name = next(iter(cp))
250+
try:
251+
identifier = message._FIELD_MAP[name]
252+
docs = cp.pop(identifier)
253+
cp[identifier] = docs
254+
except KeyError:
255+
pass
256+
return cp
257+
258+
259+
def sanitize_reply(reply):
260+
cp = reply.copy()
261+
cp.pop("$clusterTime", None)
262+
cp.pop("operationTime", None)
263+
return cp
264+
265+
266+
def print_thread_tracebacks() -> None:
267+
"""Print all Python thread tracebacks."""
268+
for thread_id, frame in sys._current_frames().items():
269+
sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n")
270+
traceback.print_stack(frame, file=sys.stderr)
271+
272+
273+
def print_thread_stacks(pid: int) -> None:
274+
"""Print all C-level thread stacks for a given process id."""
275+
if sys.platform == "darwin":
276+
cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"']
277+
else:
278+
cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"']
279+
280+
try:
281+
res = subprocess.run(
282+
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
283+
)
284+
except Exception as exc:
285+
sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}")
286+
else:
287+
sys.stderr.write(res.stdout)
288+
289+
290+
# Global knobs to speed up the test suite.
291+
global_knobs = client_knobs(events_queue_frequency=0.05)
292+
293+
294+
def _get_executors(topology):
295+
executors = []
296+
for server in topology._servers.values():
297+
# Some MockMonitor do not have an _executor.
298+
if hasattr(server._monitor, "_executor"):
299+
executors.append(server._monitor._executor)
300+
if hasattr(server._monitor, "_rtt_monitor"):
301+
executors.append(server._monitor._rtt_monitor._executor)
302+
executors.append(topology._Topology__events_executor)
303+
if topology._srv_monitor:
304+
executors.append(topology._srv_monitor._executor)
305+
306+
return [e for e in executors if e is not None]
307+
308+
309+
def print_running_topology(topology):
310+
running = [e for e in _get_executors(topology) if not e._stopped]
311+
if running:
312+
print(
313+
"WARNING: found Topology with running threads:\n"
314+
f" Threads: {running}\n"
315+
f" Topology: {topology}\n"
316+
f" Creation traceback:\n{topology._settings._stack}"
317+
)
318+
319+
320+
def test_cases(suite):
321+
"""Iterator over all TestCases within a TestSuite."""
322+
for suite_or_case in suite._tests:
323+
if isinstance(suite_or_case, unittest.TestCase):
324+
# unittest.TestCase
325+
yield suite_or_case
326+
else:
327+
# unittest.TestSuite
328+
yield from test_cases(suite_or_case)
329+
330+
331+
# Helper method to workaround https://bugs.python.org/issue21724
332+
def clear_warning_registry():
333+
"""Clear the __warningregistry__ for all modules."""
334+
for _, module in list(sys.modules.items()):
335+
if hasattr(module, "__warningregistry__"):
336+
module.__warningregistry__ = {} # type:ignore[attr-defined]
337+
338+
339+
class SystemCertsPatcher:
340+
def __init__(self, ca_certs):
341+
if (
342+
ssl.OPENSSL_VERSION.lower().startswith("libressl")
343+
and sys.platform == "darwin"
344+
and not _ssl.IS_PYOPENSSL
345+
):
346+
raise SkipTest(
347+
"LibreSSL on OSX doesn't support setting CA certificates "
348+
"using SSL_CERT_FILE environment variable."
349+
)
350+
self.original_certs = os.environ.get("SSL_CERT_FILE")
351+
# Tell OpenSSL where CA certificates live.
352+
os.environ["SSL_CERT_FILE"] = ca_certs
353+
354+
def disable(self):
355+
if self.original_certs is None:
356+
os.environ.pop("SSL_CERT_FILE")
357+
else:
358+
os.environ["SSL_CERT_FILE"] = self.original_certs

0 commit comments

Comments
 (0)