Skip to content

Commit 13d5ee2

Browse files
committed
refactor helper methods and constants
1 parent 69697a7 commit 13d5ee2

14 files changed

+21
-715
lines changed

test/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959

6060
sys.path[0:0] = [""]
6161

62-
from test.helpers import (
62+
from test.helpers_shared import (
6363
COMPRESSORS,
6464
IS_SRV,
6565
MONGODB_API_VERSION,

test/asynchronous/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959

6060
sys.path[0:0] = [""]
6161

62-
from test.helpers import (
62+
from test.helpers_shared import (
6363
COMPRESSORS,
6464
IS_SRV,
6565
MONGODB_API_VERSION,

test/asynchronous/helpers.py

Lines changed: 3 additions & 350 deletions
Original file line numberDiff line numberDiff line change
@@ -12,156 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Shared constants and helper methods for pymongo, bson, and gridfs test suites."""
15+
"""Shared helper methods for pymongo, bson, and gridfs test suites."""
1616
from __future__ import annotations
1717

1818
import asyncio
19-
import base64
20-
import gc
21-
import multiprocessing
22-
import os
23-
import signal
24-
import socket
25-
import subprocess
26-
import sys
2719
import threading
28-
import time
29-
import traceback
30-
import unittest
31-
import warnings
32-
from inspect import iscoroutinefunction
33-
from pathlib import Path
20+
from typing import Optional
3421

22+
from bson import SON
3523
from pymongo._asyncio_task import create_task
36-
37-
try:
38-
import ipaddress
39-
40-
HAVE_IPADDRESS = True
41-
except ImportError:
42-
HAVE_IPADDRESS = False
43-
from functools import wraps
44-
from typing import Any, Callable, Dict, Generator, Optional, no_type_check
45-
from unittest import SkipTest
46-
47-
from bson.son import SON
48-
from pymongo import common, message
4924
from pymongo.read_preferences import ReadPreference
50-
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
51-
from pymongo.synchronous.uri_parser import parse_uri
52-
53-
if HAVE_SSL:
54-
import ssl
5525

5626
_IS_SYNC = False
5727

58-
# Enable debug output for uncollectable objects. PyPy does not have set_debug.
59-
if hasattr(gc, "set_debug"):
60-
gc.set_debug(
61-
gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0)
62-
)
63-
64-
# The host and port of a single mongod or mongos, or the seed host
65-
# for a replica set.
66-
host = os.environ.get("DB_IP", "localhost")
67-
port = int(os.environ.get("DB_PORT", 27017))
68-
IS_SRV = "mongodb+srv" in host
69-
70-
db_user = os.environ.get("DB_USER", "user")
71-
db_pwd = os.environ.get("DB_PASSWORD", "password")
72-
73-
HERE = Path(__file__).absolute()
74-
if _IS_SYNC:
75-
CERT_PATH = str(HERE.parent / "certificates")
76-
else:
77-
CERT_PATH = str(HERE.parent.parent / "certificates")
78-
CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem"))
79-
CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem"))
80-
81-
TLS_OPTIONS: Dict = {"tls": True}
82-
if CLIENT_PEM:
83-
TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM
84-
if CA_PEM:
85-
TLS_OPTIONS["tlsCAFile"] = CA_PEM
86-
87-
COMPRESSORS = os.environ.get("COMPRESSORS")
88-
MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION")
89-
TEST_LOADBALANCER = bool(os.environ.get("TEST_LOAD_BALANCER"))
90-
SINGLE_MONGOS_LB_URI = os.environ.get("SINGLE_MONGOS_LB_URI")
91-
MULTI_MONGOS_LB_URI = os.environ.get("MULTI_MONGOS_LB_URI")
92-
93-
if TEST_LOADBALANCER:
94-
res = parse_uri(SINGLE_MONGOS_LB_URI or "")
95-
host, port = res["nodelist"][0]
96-
db_user = res["username"] or db_user
97-
db_pwd = res["password"] or db_pwd
98-
99-
100-
# Shared KMS data.
101-
LOCAL_MASTER_KEY = base64.b64decode(
102-
b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ"
103-
b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk"
104-
)
105-
AWS_CREDS = {
106-
"accessKeyId": os.environ.get("FLE_AWS_KEY", ""),
107-
"secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""),
108-
}
109-
AWS_CREDS_2 = {
110-
"accessKeyId": os.environ.get("FLE_AWS_KEY2", ""),
111-
"secretAccessKey": os.environ.get("FLE_AWS_SECRET2", ""),
112-
}
113-
AZURE_CREDS = {
114-
"tenantId": os.environ.get("FLE_AZURE_TENANTID", ""),
115-
"clientId": os.environ.get("FLE_AZURE_CLIENTID", ""),
116-
"clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""),
117-
}
118-
GCP_CREDS = {
119-
"email": os.environ.get("FLE_GCP_EMAIL", ""),
120-
"privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""),
121-
}
122-
KMIP_CREDS = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")}
123-
AWS_TEMP_CREDS = {
124-
"accessKeyId": os.environ.get("CSFLE_AWS_TEMP_ACCESS_KEY_ID", ""),
125-
"secretAccessKey": os.environ.get("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY", ""),
126-
"sessionToken": os.environ.get("CSFLE_AWS_TEMP_SESSION_TOKEN", ""),
127-
}
128-
129-
ALL_KMS_PROVIDERS = dict(
130-
aws=AWS_CREDS,
131-
azure=AZURE_CREDS,
132-
gcp=GCP_CREDS,
133-
local=dict(key=LOCAL_MASTER_KEY),
134-
kmip=KMIP_CREDS,
135-
)
136-
DEFAULT_KMS_TLS = dict(kmip=dict(tlsCAFile=CA_PEM, tlsCertificateKeyFile=CLIENT_PEM))
137-
138-
# Ensure Evergreen metadata doesn't result in truncation
139-
os.environ.setdefault("MONGOB_LOG_MAX_DOCUMENT_LENGTH", "2000")
140-
141-
142-
def is_server_resolvable():
143-
"""Returns True if 'server' is resolvable."""
144-
socket_timeout = socket.getdefaulttimeout()
145-
socket.setdefaulttimeout(1)
146-
try:
147-
try:
148-
socket.gethostbyname("server")
149-
return True
150-
except OSError:
151-
return False
152-
finally:
153-
socket.setdefaulttimeout(socket_timeout)
154-
155-
156-
def _create_user(authdb, user, pwd=None, roles=None, **kwargs):
157-
cmd = SON([("createUser", user)])
158-
# X509 doesn't use a password
159-
if pwd:
160-
cmd["pwd"] = pwd
161-
cmd["roles"] = roles or ["root"]
162-
cmd.update(**kwargs)
163-
return authdb.command(cmd)
164-
16528

16629
async def async_repl_set_step_down(client, **kwargs):
16730
"""Run replSetStepDown, first unfreezing a secondary with replSetFreeze."""
@@ -173,216 +36,6 @@ async def async_repl_set_step_down(client, **kwargs):
17336
await client.admin.command(cmd)
17437

17538

176-
class client_knobs:
177-
def __init__(
178-
self,
179-
heartbeat_frequency=None,
180-
min_heartbeat_interval=None,
181-
kill_cursor_frequency=None,
182-
events_queue_frequency=None,
183-
):
184-
self.heartbeat_frequency = heartbeat_frequency
185-
self.min_heartbeat_interval = min_heartbeat_interval
186-
self.kill_cursor_frequency = kill_cursor_frequency
187-
self.events_queue_frequency = events_queue_frequency
188-
189-
self.old_heartbeat_frequency = None
190-
self.old_min_heartbeat_interval = None
191-
self.old_kill_cursor_frequency = None
192-
self.old_events_queue_frequency = None
193-
self._enabled = False
194-
self._stack = None
195-
196-
def enable(self):
197-
self.old_heartbeat_frequency = common.HEARTBEAT_FREQUENCY
198-
self.old_min_heartbeat_interval = common.MIN_HEARTBEAT_INTERVAL
199-
self.old_kill_cursor_frequency = common.KILL_CURSOR_FREQUENCY
200-
self.old_events_queue_frequency = common.EVENTS_QUEUE_FREQUENCY
201-
202-
if self.heartbeat_frequency is not None:
203-
common.HEARTBEAT_FREQUENCY = self.heartbeat_frequency
204-
205-
if self.min_heartbeat_interval is not None:
206-
common.MIN_HEARTBEAT_INTERVAL = self.min_heartbeat_interval
207-
208-
if self.kill_cursor_frequency is not None:
209-
common.KILL_CURSOR_FREQUENCY = self.kill_cursor_frequency
210-
211-
if self.events_queue_frequency is not None:
212-
common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency
213-
self._enabled = True
214-
# Store the allocation traceback to catch non-disabled client_knobs.
215-
self._stack = "".join(traceback.format_stack())
216-
217-
def __enter__(self):
218-
self.enable()
219-
220-
@no_type_check
221-
def disable(self):
222-
common.HEARTBEAT_FREQUENCY = self.old_heartbeat_frequency
223-
common.MIN_HEARTBEAT_INTERVAL = self.old_min_heartbeat_interval
224-
common.KILL_CURSOR_FREQUENCY = self.old_kill_cursor_frequency
225-
common.EVENTS_QUEUE_FREQUENCY = self.old_events_queue_frequency
226-
self._enabled = False
227-
228-
def __exit__(self, exc_type, exc_val, exc_tb):
229-
self.disable()
230-
231-
def __call__(self, func):
232-
def make_wrapper(f):
233-
@wraps(f)
234-
async def wrap(*args, **kwargs):
235-
with self:
236-
return await f(*args, **kwargs)
237-
238-
return wrap
239-
240-
return make_wrapper(func)
241-
242-
def __del__(self):
243-
if self._enabled:
244-
msg = (
245-
"ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY={}, "
246-
"MIN_HEARTBEAT_INTERVAL={}, KILL_CURSOR_FREQUENCY={}, "
247-
"EVENTS_QUEUE_FREQUENCY={}, stack:\n{}".format(
248-
common.HEARTBEAT_FREQUENCY,
249-
common.MIN_HEARTBEAT_INTERVAL,
250-
common.KILL_CURSOR_FREQUENCY,
251-
common.EVENTS_QUEUE_FREQUENCY,
252-
self._stack,
253-
)
254-
)
255-
self.disable()
256-
raise Exception(msg)
257-
258-
259-
def _all_users(db):
260-
return {u["user"] for u in db.command("usersInfo").get("users", [])}
261-
262-
263-
def sanitize_cmd(cmd):
264-
cp = cmd.copy()
265-
cp.pop("$clusterTime", None)
266-
cp.pop("$db", None)
267-
cp.pop("$readPreference", None)
268-
cp.pop("lsid", None)
269-
if MONGODB_API_VERSION:
270-
# Stable API parameters
271-
cp.pop("apiVersion", None)
272-
# OP_MSG encoding may move the payload type one field to the
273-
# end of the command. Do the same here.
274-
name = next(iter(cp))
275-
try:
276-
identifier = message._FIELD_MAP[name]
277-
docs = cp.pop(identifier)
278-
cp[identifier] = docs
279-
except KeyError:
280-
pass
281-
return cp
282-
283-
284-
def sanitize_reply(reply):
285-
cp = reply.copy()
286-
cp.pop("$clusterTime", None)
287-
cp.pop("operationTime", None)
288-
return cp
289-
290-
291-
def print_thread_tracebacks() -> None:
292-
"""Print all Python thread tracebacks."""
293-
for thread_id, frame in sys._current_frames().items():
294-
sys.stderr.write(f"\n--- Traceback for thread {thread_id} ---\n")
295-
traceback.print_stack(frame, file=sys.stderr)
296-
297-
298-
def print_thread_stacks(pid: int) -> None:
299-
"""Print all C-level thread stacks for a given process id."""
300-
if sys.platform == "darwin":
301-
cmd = ["lldb", "--attach-pid", f"{pid}", "--batch", "--one-line", '"thread backtrace all"']
302-
else:
303-
cmd = ["gdb", f"--pid={pid}", "--batch", '--eval-command="thread apply all bt"']
304-
305-
try:
306-
res = subprocess.run(
307-
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8"
308-
)
309-
except Exception as exc:
310-
sys.stderr.write(f"Could not print C-level thread stacks because {cmd[0]} failed: {exc}")
311-
else:
312-
sys.stderr.write(res.stdout)
313-
314-
315-
# Global knobs to speed up the test suite.
316-
global_knobs = client_knobs(events_queue_frequency=0.05)
317-
318-
319-
def _get_executors(topology):
320-
executors = []
321-
for server in topology._servers.values():
322-
# Some MockMonitor do not have an _executor.
323-
if hasattr(server._monitor, "_executor"):
324-
executors.append(server._monitor._executor)
325-
if hasattr(server._monitor, "_rtt_monitor"):
326-
executors.append(server._monitor._rtt_monitor._executor)
327-
executors.append(topology._Topology__events_executor)
328-
if topology._srv_monitor:
329-
executors.append(topology._srv_monitor._executor)
330-
331-
return [e for e in executors if e is not None]
332-
333-
334-
def print_running_topology(topology):
335-
running = [e for e in _get_executors(topology) if not e._stopped]
336-
if running:
337-
print(
338-
"WARNING: found Topology with running threads:\n"
339-
f" Threads: {running}\n"
340-
f" Topology: {topology}\n"
341-
f" Creation traceback:\n{topology._settings._stack}"
342-
)
343-
344-
345-
def test_cases(suite):
346-
"""Iterator over all TestCases within a TestSuite."""
347-
for suite_or_case in suite._tests:
348-
if isinstance(suite_or_case, unittest.TestCase):
349-
# unittest.TestCase
350-
yield suite_or_case
351-
else:
352-
# unittest.TestSuite
353-
yield from test_cases(suite_or_case)
354-
355-
356-
# Helper method to workaround https://bugs.python.org/issue21724
357-
def clear_warning_registry():
358-
"""Clear the __warningregistry__ for all modules."""
359-
for _, module in list(sys.modules.items()):
360-
if hasattr(module, "__warningregistry__"):
361-
module.__warningregistry__ = {} # type:ignore[attr-defined]
362-
363-
364-
class SystemCertsPatcher:
365-
def __init__(self, ca_certs):
366-
if (
367-
ssl.OPENSSL_VERSION.lower().startswith("libressl")
368-
and sys.platform == "darwin"
369-
and not _ssl.IS_PYOPENSSL
370-
):
371-
raise SkipTest(
372-
"LibreSSL on OSX doesn't support setting CA certificates "
373-
"using SSL_CERT_FILE environment variable."
374-
)
375-
self.original_certs = os.environ.get("SSL_CERT_FILE")
376-
# Tell OpenSSL where CA certificates live.
377-
os.environ["SSL_CERT_FILE"] = ca_certs
378-
379-
def disable(self):
380-
if self.original_certs is None:
381-
os.environ.pop("SSL_CERT_FILE")
382-
else:
383-
os.environ["SSL_CERT_FILE"] = self.original_certs
384-
385-
38639
if _IS_SYNC:
38740
PARENT = threading.Thread
38841
else:

0 commit comments

Comments
 (0)