diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index bfae302dac..d2b45fd64a 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -59,8 +59,8 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, uri_parser -from pymongo.asynchronous import client_session, database, periodic_executor +from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser +from pymongo.asynchronous import client_session, database from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_session import _EmptyServerSession @@ -908,7 +908,7 @@ async def target() -> bool: await AsyncMongoClient._process_periodic_tasks(client) return True - executor = periodic_executor.PeriodicExecutor( + executor = periodic_executor.AsyncPeriodicExecutor( interval=common.KILL_CURSOR_FREQUENCY, min_interval=common.MIN_HEARTBEAT_INTERVAL, target=target, diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index f9e912b084..bbfd6a2998 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -22,14 +22,13 @@ import weakref from typing import TYPE_CHECKING, Any, Mapping, Optional, cast -from pymongo import common +from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum -from pymongo.asynchronous import periodic_executor -from pymongo.asynchronous.periodic_executor import _shutdown_executors from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled from pymongo.hello import Hello from pymongo.lock import _create_lock from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage +from pymongo.periodic_executor import _shutdown_executors from pymongo.pool_options import _is_faas from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription @@ -76,7 +75,7 @@ async def target() -> bool: await monitor._run() # type:ignore[attr-defined] return True - executor = periodic_executor.PeriodicExecutor( + executor = periodic_executor.AsyncPeriodicExecutor( interval=interval, min_interval=min_interval, target=target, name=name ) @@ -112,9 +111,9 @@ async def close(self) -> None: """ self.gc_safe_close() - def join(self, timeout: Optional[int] = None) -> None: + async def join(self, timeout: Optional[int] = None) -> None: """Wait for the monitor to stop.""" - self._executor.join(timeout) + await self._executor.join(timeout) def request_check(self) -> None: """If the monitor is sleeping, wake it soon.""" @@ -139,7 +138,7 @@ def __init__( """ super().__init__( topology, - "pymongo_server_monitor_thread", + "pymongo_server_monitor_task", topology_settings.heartbeat_frequency, common.MIN_HEARTBEAT_INTERVAL, ) @@ -250,7 +249,7 @@ async def _check_server(self) -> ServerDescription: except (OperationFailure, NotPrimaryError) as exc: # Update max cluster time even when hello fails. details = cast(Mapping[str, Any], exc.details) - self._topology.receive_cluster_time(details.get("$clusterTime")) + await self._topology.receive_cluster_time(details.get("$clusterTime")) raise except ReferenceError: raise @@ -434,7 +433,7 @@ def __init__(self, topology: Topology, topology_settings: TopologySettings, pool """ super().__init__( topology, - "pymongo_server_rtt_thread", + "pymongo_server_rtt_task", topology_settings.heartbeat_frequency, common.MIN_HEARTBEAT_INTERVAL, ) @@ -531,4 +530,5 @@ def _shutdown_resources() -> None: shutdown() -atexit.register(_shutdown_resources) +if _IS_SYNC: + atexit.register(_shutdown_resources) diff --git a/pymongo/asynchronous/periodic_executor.py b/pymongo/asynchronous/periodic_executor.py deleted file mode 100644 index f3d2fddba3..0000000000 --- a/pymongo/asynchronous/periodic_executor.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Run a target function on a background thread.""" - -from __future__ import annotations - -import asyncio -import sys -import threading -import time -import weakref -from typing import Any, Optional - -from pymongo.lock import _ALock, _create_lock - -_IS_SYNC = False - - -class PeriodicExecutor: - def __init__( - self, - interval: float, - min_interval: float, - target: Any, - name: Optional[str] = None, - ): - """Run a target function periodically on a background thread. - - If the target's return value is false, the executor stops. - - :param interval: Seconds between calls to `target`. - :param min_interval: Minimum seconds between calls if `wake` is - called very often. - :param target: A function. - :param name: A name to give the underlying thread. - """ - # threading.Event and its internal condition variable are expensive - # in Python 2, see PYTHON-983. Use a boolean to know when to wake. - # The executor's design is constrained by several Python issues, see - # "periodic_executor.rst" in this repository. - self._event = False - self._interval = interval - self._min_interval = min_interval - self._target = target - self._stopped = False - self._thread: Optional[threading.Thread] = None - self._name = name - self._skip_sleep = False - self._thread_will_exit = False - self._lock = _ALock(_create_lock()) - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>" - - def _run_async(self) -> None: - # The default asyncio loop implementation on Windows - # has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240) - # We explicitly use a different loop implementation here to prevent that issue - if sys.platform == "win32": - loop = asyncio.SelectorEventLoop() - try: - loop.run_until_complete(self._run()) # type: ignore[func-returns-value] - finally: - loop.close() - else: - asyncio.run(self._run()) # type: ignore[func-returns-value] - - def open(self) -> None: - """Start. Multiple calls have no effect. - - Not safe to call from multiple threads at once. - """ - with self._lock: - if self._thread_will_exit: - # If the background thread has read self._stopped as True - # there is a chance that it has not yet exited. The call to - # join should not block indefinitely because there is no - # other work done outside the while loop in self._run. - try: - assert self._thread is not None - self._thread.join() - except ReferenceError: - # Thread terminated. - pass - self._thread_will_exit = False - self._stopped = False - started: Any = False - try: - started = self._thread and self._thread.is_alive() - except ReferenceError: - # Thread terminated. - pass - - if not started: - if _IS_SYNC: - thread = threading.Thread(target=self._run, name=self._name) - else: - thread = threading.Thread(target=self._run_async, name=self._name) - thread.daemon = True - self._thread = weakref.proxy(thread) - _register_executor(self) - # Mitigation to RuntimeError firing when thread starts on shutdown - # https://github.com/python/cpython/issues/114570 - try: - thread.start() - except RuntimeError as e: - if "interpreter shutdown" in str(e) or sys.is_finalizing(): - self._thread = None - return - raise - - def close(self, dummy: Any = None) -> None: - """Stop. To restart, call open(). - - The dummy parameter allows an executor's close method to be a weakref - callback; see monitor.py. - """ - self._stopped = True - - def join(self, timeout: Optional[int] = None) -> None: - if self._thread is not None: - try: - self._thread.join(timeout) - except (ReferenceError, RuntimeError): - # Thread already terminated, or not yet started. - pass - - def wake(self) -> None: - """Execute the target function soon.""" - self._event = True - - def update_interval(self, new_interval: int) -> None: - self._interval = new_interval - - def skip_sleep(self) -> None: - self._skip_sleep = True - - async def _should_stop(self) -> bool: - async with self._lock: - if self._stopped: - self._thread_will_exit = True - return True - return False - - async def _run(self) -> None: - while not await self._should_stop(): - try: - if not await self._target(): - self._stopped = True - break - except BaseException: - async with self._lock: - self._stopped = True - self._thread_will_exit = True - - raise - - if self._skip_sleep: - self._skip_sleep = False - else: - deadline = time.monotonic() + self._interval - while not self._stopped and time.monotonic() < deadline: - await asyncio.sleep(self._min_interval) - if self._event: - break # Early wake. - - self._event = False - - -# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started, -# an executor is kept alive by a strong reference from its thread and perhaps -# from other objects. When the thread dies and all other referrers are freed, -# the executor is freed and removed from _EXECUTORS. If any threads are -# running when the interpreter begins to shut down, we try to halt and join -# them to avoid spurious errors. -_EXECUTORS = set() - - -def _register_executor(executor: PeriodicExecutor) -> None: - ref = weakref.ref(executor, _on_executor_deleted) - _EXECUTORS.add(ref) - - -def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None: - _EXECUTORS.remove(ref) - - -def _shutdown_executors() -> None: - if _EXECUTORS is None: - return - - # Copy the set. Stopping threads has the side effect of removing executors. - executors = list(_EXECUTORS) - - # First signal all executors to close... - for ref in executors: - executor = ref() - if executor: - executor.close() - - # ...then try to join them. - for ref in executors: - executor = ref() - if executor: - executor.join(1) - - executor = None diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 82af4257ba..f0cb56cbf1 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -27,8 +27,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast -from pymongo import _csot, common, helpers_shared -from pymongo.asynchronous import periodic_executor +from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool from pymongo.asynchronous.monitor import SrvMonitor from pymongo.asynchronous.pool import Pool @@ -185,7 +184,7 @@ def __init__(self, topology_settings: TopologySettings): async def target() -> bool: return process_events_queue(weak) - executor = periodic_executor.PeriodicExecutor( + executor = periodic_executor.AsyncPeriodicExecutor( interval=common.EVENTS_QUEUE_FREQUENCY, min_interval=common.MIN_HEARTBEAT_INTERVAL, target=target, @@ -742,7 +741,7 @@ async def close(self) -> None: if self._publish_server or self._publish_tp: # Make sure the events executor thread is fully closed before publishing the remaining events self.__events_executor.close() - self.__events_executor.join(1) + await self.__events_executor.join(1) process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type] @property diff --git a/pymongo/synchronous/periodic_executor.py b/pymongo/periodic_executor.py similarity index 69% rename from pymongo/synchronous/periodic_executor.py rename to pymongo/periodic_executor.py index 525268b14b..216a4457c7 100644 --- a/pymongo/synchronous/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -25,7 +25,96 @@ from pymongo.lock import _create_lock -_IS_SYNC = True +_IS_SYNC = False + + +class AsyncPeriodicExecutor: + def __init__( + self, + interval: float, + min_interval: float, + target: Any, + name: Optional[str] = None, + ): + """Run a target function periodically on a background task. + + If the target's return value is false, the executor stops. + + :param interval: Seconds between calls to `target`. + :param min_interval: Minimum seconds between calls if `wake` is + called very often. + :param target: A function. + :param name: A name to give the underlying task. + """ + self._event = False + self._interval = interval + self._min_interval = min_interval + self._target = target + self._stopped = False + self._task: Optional[asyncio.Task] = None + self._name = name + self._skip_sleep = False + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>" + + def open(self) -> None: + """Start. Multiple calls have no effect.""" + self._stopped = False + started = self._task and not self._task.done() + + if not started: + self._task = asyncio.get_event_loop().create_task(self._run(), name=self._name) + + def close(self, dummy: Any = None) -> None: + """Stop. To restart, call open(). + + The dummy parameter allows an executor's close method to be a weakref + callback; see monitor.py. + """ + self._stopped = True + + async def join(self, timeout: Optional[int] = None) -> None: + if self._task is not None: + try: + await asyncio.wait_for(self._task, timeout=timeout) # type-ignore: [arg-type] + except asyncio.TimeoutError: + # Task timed out + pass + except asyncio.exceptions.CancelledError: + # Task was already finished, or not yet started. + pass + + def wake(self) -> None: + """Execute the target function soon.""" + self._event = True + + def update_interval(self, new_interval: int) -> None: + self._interval = new_interval + + def skip_sleep(self) -> None: + self._skip_sleep = True + + async def _run(self) -> None: + while not self._stopped: + try: + if not await self._target(): + self._stopped = True + break + except BaseException: + self._stopped = True + raise + + if self._skip_sleep: + self._skip_sleep = False + else: + deadline = time.monotonic() + self._interval + while not self._stopped and time.monotonic() < deadline: + await asyncio.sleep(self._min_interval) + if self._event: + break # Early wake. + + self._event = False class PeriodicExecutor: @@ -64,19 +153,6 @@ def __init__( def __repr__(self) -> str: return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>" - def _run_async(self) -> None: - # The default asyncio loop implementation on Windows - # has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240) - # We explicitly use a different loop implementation here to prevent that issue - if sys.platform == "win32": - loop = asyncio.SelectorEventLoop() - try: - loop.run_until_complete(self._run()) # type: ignore[func-returns-value] - finally: - loop.close() - else: - asyncio.run(self._run()) # type: ignore[func-returns-value] - def open(self) -> None: """Start. Multiple calls have no effect. @@ -104,10 +180,7 @@ def open(self) -> None: pass if not started: - if _IS_SYNC: - thread = threading.Thread(target=self._run, name=self._name) - else: - thread = threading.Thread(target=self._run_async, name=self._name) + thread = threading.Thread(target=self._run, name=self._name) thread.daemon = True self._thread = weakref.proxy(thread) _register_executor(self) diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 1351cb200f..8f4d9cacf2 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -58,7 +58,7 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, uri_parser +from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser from pymongo.client_options import ClientOptions from pymongo.errors import ( AutoReconnect, @@ -91,7 +91,7 @@ from pymongo.results import ClientBulkWriteResult from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import client_session, database, periodic_executor +from pymongo.synchronous import client_session, database from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_session import _EmptyServerSession diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index 3f9bb2ea75..a806670f2c 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -22,18 +22,17 @@ import weakref from typing import TYPE_CHECKING, Any, Mapping, Optional, cast -from pymongo import common +from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled from pymongo.hello import Hello from pymongo.lock import _create_lock from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage +from pymongo.periodic_executor import _shutdown_executors from pymongo.pool_options import _is_faas from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription from pymongo.srv_resolver import _SrvResolver -from pymongo.synchronous import periodic_executor -from pymongo.synchronous.periodic_executor import _shutdown_executors if TYPE_CHECKING: from pymongo.synchronous.pool import Connection, Pool, _CancellationContext @@ -531,4 +530,5 @@ def _shutdown_resources() -> None: shutdown() -atexit.register(_shutdown_resources) +if _IS_SYNC: + atexit.register(_shutdown_resources) diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index a350c1702e..e34de6bc50 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -27,7 +27,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast -from pymongo import _csot, common, helpers_shared +from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.errors import ( ConnectionFailure, InvalidOperation, @@ -56,7 +56,6 @@ secondary_server_selector, writable_server_selector, ) -from pymongo.synchronous import periodic_executor from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool from pymongo.synchronous.monitor import SrvMonitor from pymongo.synchronous.pool import Pool diff --git a/test/__init__.py b/test/__init__.py index af12bc032a..6be3b49ce6 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -859,6 +859,15 @@ def max_message_size_bytes(self): client_context = ClientContext() +def reset_client_context(): + if _IS_SYNC: + # sync tests don't need to reset a client context + return + client_context.client.close() + client_context.client = None + client_context._init_client() + + class PyMongoTestCase(unittest.TestCase): def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 2a44785b2f..1a386fe766 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -861,6 +861,15 @@ async def max_message_size_bytes(self): async_client_context = AsyncClientContext() +async def reset_client_context(): + if _IS_SYNC: + # sync tests don't need to reset a client context + return + await async_client_context.client.close() + async_client_context.client = None + await async_client_context._init_client() + + class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) diff --git a/test/asynchronous/conftest.py b/test/asynchronous/conftest.py index e443dff6c0..a27a9f213d 100644 --- a/test/asynchronous/conftest.py +++ b/test/asynchronous/conftest.py @@ -22,7 +22,7 @@ def event_loop_policy(): return asyncio.get_event_loop_policy() -@pytest_asyncio.fixture(scope="session", autouse=True) +@pytest_asyncio.fixture(scope="package", autouse=True) async def test_setup_and_teardown(): await async_setup() yield diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index faa23348c9..c4d71cdbe6 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -73,7 +73,6 @@ is_greenthread_patched, lazy_client_trial, one, - wait_until, ) import bson @@ -693,8 +692,8 @@ async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): # When the reaper runs at the same time as the get_socket, two # connections could be created and checked into the pool. self.assertGreaterEqual(len(server._pool.conns), 1) - wait_until(lambda: conn not in server._pool.conns, "remove stale socket") - wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") + await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket") + await async_wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): @@ -710,8 +709,8 @@ async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): # When the reaper runs at the same time as the get_socket, # maxPoolSize=1 should prevent two connections from being created. self.assertEqual(1, len(server._pool.conns)) - wait_until(lambda: conn not in server._pool.conns, "remove stale socket") - wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") + await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket") + await async_wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") async def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): @@ -727,7 +726,7 @@ async def test_max_idle_time_reaper_removes_stale(self): async with server._pool.checkout() as conn_two: pass self.assertIs(conn_one, conn_two) - wait_until( + await async_wait_until( lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) @@ -745,7 +744,7 @@ async def test_min_pool_size(self): server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) - wait_until( + await async_wait_until( lambda: len(server._pool.conns) == 10, "pool initialized with 10 connections", ) @@ -753,7 +752,7 @@ async def test_min_pool_size(self): # Assert that if a socket is closed, a new one takes its place async with server._pool.checkout() as conn: conn.close_conn(None) - wait_until( + await async_wait_until( lambda: len(server._pool.conns) == 10, "a closed socket gets replaced from the pool", ) @@ -941,8 +940,10 @@ async def test_repr(self): async with eval(the_repr) as client_two: self.assertEqual(client_two, client) - def test_getters(self): - wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes") + async def test_getters(self): + await async_wait_until( + lambda: async_client_context.nodes == self.client.nodes, "find all nodes" + ) async def test_list_databases(self): cmd_docs = (await self.client.admin.command("listDatabases"))["databases"] @@ -1067,14 +1068,21 @@ async def test_uri_connect_option(self): self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. - kc_thread = client._kill_cursors_executor._thread - self.assertFalse(kc_thread and kc_thread.is_alive()) - + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertFalse(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertFalse(kc_task and not kc_task.done()) # Using the client should open topology and start the thread. await client.admin.command("ping") self.assertTrue(client._topology._opened) - kc_thread = client._kill_cursors_executor._thread - self.assertTrue(kc_thread and kc_thread.is_alive()) + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertTrue(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertTrue(kc_task and not kc_task.done()) async def test_close_does_not_open_servers(self): client = await self.async_rs_client(connect=False) @@ -1610,7 +1618,7 @@ def init(self, *args): await async_client_context.port, ) await self.async_single_client(uri, event_listeners=[listener]) - wait_until( + await async_wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) @@ -1768,16 +1776,16 @@ async def test_background_connections_do_not_hold_locks(self): pool = await async_get_pool(client) original_connect = pool.connect - def stall_connect(*args, **kwargs): - time.sleep(2) - return original_connect(*args, **kwargs) + async def stall_connect(*args, **kwargs): + await asyncio.sleep(2) + return await original_connect(*args, **kwargs) pool.connect = stall_connect # Un-patch Pool.connect to break the cyclic reference. self.addCleanup(delattr, pool, "connect") # Wait for the background thread to start creating connections - wait_until(lambda: len(pool.conns) > 1, "start creating connections") + await async_wait_until(lambda: len(pool.conns) > 1, "start creating connections") # Assert that application operations do not block. for _ in range(10): @@ -1860,7 +1868,7 @@ async def test_process_periodic_tasks(self): await client.close() # Add cursor to kill cursors queue del cursor - wait_until( + await async_wait_until( lambda: client._kill_cursors_queue, "waited for cursor to be added to queue", ) @@ -2218,7 +2226,7 @@ async def test_exhaust_getmore_network_error(self): await cursor.to_list() self.assertTrue(conn.closed) - wait_until( + await async_wait_until( lambda: len(client._kill_cursors_queue) == 0, "waited for all killCursor requests to complete", ) @@ -2389,7 +2397,7 @@ async def test_discover_primary(self): ) self.addAsyncCleanup(c.close) - wait_until(lambda: len(c.nodes) == 3, "connect") + await async_wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(await c.address, ("a", 1)) # Fail over. @@ -2416,7 +2424,7 @@ async def test_reconnect(self): ) self.addAsyncCleanup(c.close) - wait_until(lambda: len(c.nodes) == 3, "connect") + await async_wait_until(lambda: len(c.nodes) == 3, "connect") # Total failure. c.kill_host("a:1") @@ -2458,7 +2466,7 @@ async def _test_network_error(self, operation_callback): c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION) c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1) await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST) - wait_until(lambda: len(c.nodes) == 2, "connect") + await async_wait_until(lambda: len(c.nodes) == 2, "connect") c.kill_host("a:1") @@ -2530,11 +2538,11 @@ async def test_rs_client_does_not_maintain_pool_to_arbiters(self): ) self.addAsyncCleanup(c.close) - wait_until(lambda: len(c.nodes) == 3, "connect") + await async_wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(await c.address, ("a", 1)) self.assertEqual(await c.arbiters, {("c", 3)}) # Assert that we create 2 and only 2 pooled connections. - listener.wait_for_event(monitoring.ConnectionReadyEvent, 2) + await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 2) self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2) # Assert that we do not create connections to arbiters. arbiter = c._topology.get_server_by_address(("c", 3)) @@ -2560,7 +2568,7 @@ async def test_direct_client_maintains_pool_to_arbiter(self): ) self.addAsyncCleanup(c.close) - wait_until(lambda: len(c.nodes) == 1, "connect") + await async_wait_until(lambda: len(c.nodes) == 1, "connect") self.assertEqual(await c.address, ("c", 3)) # Assert that we create 1 pooled connection. listener.wait_for_event(monitoring.ConnectionReadyEvent, 1) diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 612090b69f..470425f4ce 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -39,7 +39,6 @@ async_get_pool, async_is_mongos, async_wait_until, - wait_until, ) from bson import encode @@ -1022,7 +1021,10 @@ async def test_replace_bypass_document_validation(self): await db.test.insert_one({"y": 1}, bypass_document_validation=True) await db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) - await async_wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") + async def predicate(): + return await db_w0.test.find_one({"x": 1}) + + await async_wait_until(predicate, "find w:0 replaced document") async def test_update_bypass_document_validation(self): db = self.db @@ -1870,7 +1872,7 @@ async def test_exhaust(self): await cur.close() cur = None # Wait until the background thread returns the socket. - wait_until(lambda: pool.active_sockets == 0, "return socket") + await async_wait_until(lambda: pool.active_sockets == 0, "return socket") # The socket should be discarded. self.assertEqual(0, len(pool.conns)) diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py index 289cf49751..dc04cb28a7 100644 --- a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -19,7 +19,12 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + async_client_context, + reset_client_context, + unittest, +) from test.asynchronous.helpers import async_repl_set_step_down from test.utils import ( CMAPListener, @@ -60,6 +65,7 @@ async def _setup_class(cls): @classmethod async def _tearDown_class(cls): await cls.client.close() + await reset_client_context() async def asyncSetUp(self): # Note that all ops use same write-concern as self.db (majority). diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index e79ad00641..f7b795cdae 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -34,9 +34,9 @@ AllowListEventListener, EventListener, OvertCommandListener, + async_wait_until, delay, ignore_deprecations, - wait_until, ) from bson import decode_all @@ -1324,8 +1324,8 @@ async def test_timeout_kills_cursor_asynchronously(self): with self.assertRaises(ExecutionTimeout): await cursor.next() - def assertCursorKilled(): - wait_until( + async def assertCursorKilled(): + await async_wait_until( lambda: len(listener.succeeded_events), "find successful killCursors command", ) @@ -1335,7 +1335,7 @@ def assertCursorKilled(): self.assertEqual(1, len(listener.succeeded_events)) self.assertEqual("killCursors", listener.succeeded_events[0].command_name) - assertCursorKilled() + await assertCursorKilled() listener.reset() cursor = await coll.aggregate([], batchSize=1) @@ -1345,7 +1345,7 @@ def assertCursorKilled(): with self.assertRaises(ExecutionTimeout): await cursor.next() - assertCursorKilled() + await assertCursorKilled() def test_delete_not_initialized(self): # Creating a cursor with invalid arguments will not run __init__ diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index d264b5ecb0..c1dac6f56d 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -37,7 +37,6 @@ EventListener, ExceptionCatchingThread, async_wait_until, - wait_until, ) from bson import DBRef diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index b5d0686417..229046e79b 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -26,7 +26,7 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.utils import ( OvertCommandListener, - wait_until, + async_wait_until, ) from typing import List @@ -162,7 +162,7 @@ async def test_unpin_for_next_transaction(self): client = await self.async_rs_client( async_client_context.mongos_seeds(), localThresholdMS=1000 ) - wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") + await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. await coll.insert_one({}) @@ -191,7 +191,7 @@ async def test_unpin_for_non_transaction_operation(self): client = await self.async_rs_client( async_client_context.mongos_seeds(), localThresholdMS=1000 ) - wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") + await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. await coll.insert_one({}) diff --git a/test/conftest.py b/test/conftest.py index a3d954c7c3..91fad28d0a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -20,7 +20,7 @@ def event_loop_policy(): return asyncio.get_event_loop_policy() -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="package", autouse=True) def test_setup_and_teardown(): setup() yield diff --git a/test/test_client.py b/test/test_client.py index be1994dd93..a4c521157b 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1041,14 +1041,21 @@ def test_uri_connect_option(self): self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. - kc_thread = client._kill_cursors_executor._thread - self.assertFalse(kc_thread and kc_thread.is_alive()) - + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertFalse(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertFalse(kc_task and not kc_task.done()) # Using the client should open topology and start the thread. client.admin.command("ping") self.assertTrue(client._topology._opened) - kc_thread = client._kill_cursors_executor._thread - self.assertTrue(kc_thread and kc_thread.is_alive()) + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertTrue(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertTrue(kc_task and not kc_task.done()) def test_close_does_not_open_servers(self): client = self.rs_client(connect=False) diff --git a/test/test_collection.py b/test/test_collection.py index a2c3b0b0b6..f2f01ac686 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -1009,7 +1009,10 @@ def test_replace_bypass_document_validation(self): db.test.insert_one({"y": 1}, bypass_document_validation=True) db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) - wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") + def predicate(): + return db_w0.test.find_one({"x": 1}) + + wait_until(predicate, "find w:0 replaced document") def test_update_bypass_document_validation(self): db = self.db diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 54cc4e0482..984d700fb3 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -19,7 +19,12 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import ( + IntegrationTest, + client_context, + reset_client_context, + unittest, +) from test.helpers import repl_set_step_down from test.utils import ( CMAPListener, @@ -60,6 +65,7 @@ def _setup_class(cls): @classmethod def _tearDown_class(cls): cls.client.close() + reset_client_context() def setUp(self): # Note that all ops use same write-concern as self.db (majority). diff --git a/test/test_monitor.py b/test/test_monitor.py index f8e9443fae..a704f3d8cb 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -29,7 +29,7 @@ wait_until, ) -from pymongo.synchronous.periodic_executor import _EXECUTORS +from pymongo.periodic_executor import _EXECUTORS def unregistered(ref): diff --git a/test/utils.py b/test/utils.py index 9c78cff3ad..174b1708ba 100644 --- a/test/utils.py +++ b/test/utils.py @@ -98,6 +98,12 @@ def wait_for_event(self, event, count): """Wait for a number of events to be published, or fail.""" wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)") + async def async_wait_for_event(self, event, count): + """Wait for a number of events to be published, or fail.""" + await async_wait_until( + lambda: self.event_count(event) >= count, f"find {count} {event} event(s)" + ) + class CMAPListener(BaseListener, monitoring.ConnectionPoolListener): def connection_created(self, event): @@ -789,7 +795,10 @@ async def async_wait_until(predicate, success_description, timeout=10): start = time.time() interval = min(float(timeout) / 100, 0.1) while True: - retval = await predicate() + if iscoroutinefunction(predicate): + retval = await predicate() + else: + retval = predicate() if retval: return retval diff --git a/tools/synchro.py b/tools/synchro.py index 0ec8985a05..c3c0b568ed 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -108,6 +108,10 @@ "async_set_fail_point": "set_fail_point", "async_ensure_all_connected": "ensure_all_connected", "async_repl_set_step_down": "repl_set_step_down", + "AsyncPeriodicExecutor": "PeriodicExecutor", + "async_wait_for_event": "wait_for_event", + "pymongo_server_monitor_task": "pymongo_server_monitor_thread", + "pymongo_server_rtt_task": "pymongo_server_rtt_thread", } docstring_replacements: dict[tuple[str, str], str] = {