|
17 | 17 |
|
18 | 18 | import asyncio |
19 | 19 | import gc |
| 20 | +import inspect |
20 | 21 | import logging |
21 | 22 | import multiprocessing |
22 | 23 | import os |
|
30 | 31 | import unittest |
31 | 32 | import warnings |
32 | 33 | from asyncio import iscoroutinefunction |
33 | | -from test.helpers import ( |
34 | | - COMPRESSORS, |
35 | | - IS_SRV, |
36 | | - MONGODB_API_VERSION, |
37 | | - MULTI_MONGOS_LB_URI, |
38 | | - TEST_LOADBALANCER, |
39 | | - TEST_SERVERLESS, |
40 | | - TLS_OPTIONS, |
41 | | - SystemCertsPatcher, |
42 | | - client_knobs, |
43 | | - db_pwd, |
44 | | - db_user, |
45 | | - global_knobs, |
46 | | - host, |
47 | | - is_server_resolvable, |
48 | | - port, |
49 | | - print_running_topology, |
50 | | - print_thread_stacks, |
51 | | - print_thread_tracebacks, |
52 | | - sanitize_cmd, |
53 | | - sanitize_reply, |
54 | | -) |
55 | 34 |
|
56 | 35 | from pymongo.uri_parser import parse_uri |
57 | 36 |
|
|
63 | 42 | HAVE_IPADDRESS = False |
64 | 43 | from contextlib import asynccontextmanager, contextmanager |
65 | 44 | from functools import partial, wraps |
66 | | -from test.version import Version |
67 | 45 | from typing import Any, Callable, Dict, Generator, overload |
68 | 46 | from unittest import SkipTest |
69 | 47 | from urllib.parse import quote_plus |
|
78 | 56 | from pymongo.server_api import ServerApi |
79 | 57 | from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] |
80 | 58 |
|
| 59 | +sys.path[0:0] = [""] |
| 60 | + |
| 61 | +from test.helpers import ( |
| 62 | + COMPRESSORS, |
| 63 | + IS_SRV, |
| 64 | + MONGODB_API_VERSION, |
| 65 | + MULTI_MONGOS_LB_URI, |
| 66 | + TEST_LOADBALANCER, |
| 67 | + TEST_SERVERLESS, |
| 68 | + TLS_OPTIONS, |
| 69 | + SystemCertsPatcher, |
| 70 | + client_knobs, |
| 71 | + db_pwd, |
| 72 | + db_user, |
| 73 | + global_knobs, |
| 74 | + host, |
| 75 | + is_server_resolvable, |
| 76 | + port, |
| 77 | + print_running_topology, |
| 78 | + print_thread_stacks, |
| 79 | + print_thread_tracebacks, |
| 80 | + sanitize_cmd, |
| 81 | + sanitize_reply, |
| 82 | +) |
| 83 | +from test.version import Version |
| 84 | + |
81 | 85 | _IS_SYNC = False |
82 | 86 |
|
83 | 87 |
|
@@ -865,18 +869,66 @@ async def max_message_size_bytes(self): |
865 | 869 | # Reusable client context |
866 | 870 | async_client_context = AsyncClientContext() |
867 | 871 |
|
| 872 | +# Global event loop for async tests. |
| 873 | +LOOP = None |
| 874 | + |
| 875 | + |
| 876 | +def get_loop() -> asyncio.AbstractEventLoop: |
| 877 | + """Get the test suite's global event loop.""" |
| 878 | + global LOOP |
| 879 | + if LOOP is None: |
| 880 | + try: |
| 881 | + LOOP = asyncio.get_running_loop() |
| 882 | + except RuntimeError: |
| 883 | + # no running event loop, fallback to get_event_loop. |
| 884 | + try: |
| 885 | + # Ignore DeprecationWarning: There is no current event loop |
| 886 | + with warnings.catch_warnings(): |
| 887 | + warnings.simplefilter("ignore", DeprecationWarning) |
| 888 | + LOOP = asyncio.get_event_loop() |
| 889 | + except RuntimeError: |
| 890 | + LOOP = asyncio.new_event_loop() |
| 891 | + asyncio.set_event_loop(LOOP) |
| 892 | + return LOOP |
| 893 | + |
| 894 | + |
| 895 | +class AsyncPyMongoTestCase(unittest.TestCase): |
| 896 | + if not _IS_SYNC: |
| 897 | + # An async TestCase that uses a single event loop for all tests. |
| 898 | + # Inspired by IsolatedAsyncioTestCase. |
| 899 | + async def asyncSetUp(self): |
| 900 | + pass |
868 | 901 |
|
869 | | -async def reset_client_context(): |
870 | | - if _IS_SYNC: |
871 | | - # sync tests don't need to reset a client context |
872 | | - return |
873 | | - elif async_client_context.client is not None: |
874 | | - await async_client_context.client.close() |
875 | | - async_client_context.client = None |
876 | | - await async_client_context._init_client() |
| 902 | + async def asyncTearDown(self): |
| 903 | + pass |
877 | 904 |
|
| 905 | + def addAsyncCleanup(self, func, /, *args, **kwargs): |
| 906 | + self.addCleanup(*(func, *args), **kwargs) |
| 907 | + |
| 908 | + def _callSetUp(self): |
| 909 | + self.setUp() |
| 910 | + self._callAsync(self.asyncSetUp) |
| 911 | + |
| 912 | + def _callTestMethod(self, method): |
| 913 | + self._callMaybeAsync(method) |
| 914 | + |
| 915 | + def _callTearDown(self): |
| 916 | + self._callAsync(self.asyncTearDown) |
| 917 | + self.tearDown() |
| 918 | + |
| 919 | + def _callCleanup(self, function, *args, **kwargs): |
| 920 | + self._callMaybeAsync(function, *args, **kwargs) |
| 921 | + |
| 922 | + def _callAsync(self, func, /, *args, **kwargs): |
| 923 | + assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" |
| 924 | + return get_loop().run_until_complete(func(*args, **kwargs)) |
| 925 | + |
| 926 | + def _callMaybeAsync(self, func, /, *args, **kwargs): |
| 927 | + if inspect.iscoroutinefunction(func): |
| 928 | + return get_loop().run_until_complete(func(*args, **kwargs)) |
| 929 | + else: |
| 930 | + return func(*args, **kwargs) |
878 | 931 |
|
879 | | -class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): |
880 | 932 | def assertEqualCommand(self, expected, actual, msg=None): |
881 | 933 | self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) |
882 | 934 |
|
@@ -1154,8 +1206,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): |
1154 | 1206 |
|
1155 | 1207 | @async_client_context.require_connection |
1156 | 1208 | async def asyncSetUp(self) -> None: |
1157 | | - if not _IS_SYNC: |
1158 | | - await reset_client_context() |
1159 | 1209 | if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): |
1160 | 1210 | raise SkipTest("this test does not support load balancers") |
1161 | 1211 | if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): |
@@ -1204,6 +1254,9 @@ async def asyncTearDown(self) -> None: |
1204 | 1254 |
|
1205 | 1255 |
|
1206 | 1256 | async def async_setup(): |
| 1257 | + if not _IS_SYNC: |
| 1258 | + # Set up the event loop. |
| 1259 | + get_loop() |
1207 | 1260 | await async_client_context.init() |
1208 | 1261 | warnings.resetwarnings() |
1209 | 1262 | warnings.simplefilter("always") |
|
0 commit comments