diff --git a/.evergreen/config.yml b/.evergreen/config.yml index 59b8a543fd..7ca3a72b1a 100644 --- a/.evergreen/config.yml +++ b/.evergreen/config.yml @@ -281,7 +281,7 @@ functions: "run tests": - command: subprocess.exec params: - include_expansions_in_env: ["TEST_DATA_LAKE", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"] + include_expansions_in_env: ["TEST_DATA_LAKE", "PYTHON_BINARY", "AUTH", "SSL", "TEST_INDEX_MANAGEMENT", "CRYPT_SHARED_LIB_PATH", "test_encryption", "test_encryption_pyopenssl", "test_crypt_shared", "test_pyopenssl", "test_loadbalancer", "test_serverless", "ORCHESTRATION_FILE"] binary: bash working_dir: "src" args: diff --git a/THIRD-PARTY-NOTICES b/THIRD-PARTY-NOTICES index 55b8ff7078..ad00831a2a 100644 --- a/THIRD-PARTY-NOTICES +++ b/THIRD-PARTY-NOTICES @@ -38,3 +38,61 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +2) License Notice for _asyncio_lock.py +----------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF hereby +grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +analyze, test, perform and/or display publicly, prepare derivative works, +distribute, and otherwise use Python alone or in any derivative version, +provided, however, that PSF's License Agreement and PSF's notice of copyright, +i.e., "Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved" +are retained in Python alone or in any derivative version prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee. This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR +OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +PERFORMANCE OF THIS SOFTWARE. diff --git a/pymongo/_asyncio_lock.py b/pymongo/_asyncio_lock.py new file mode 100644 index 0000000000..669b0f63a7 --- /dev/null +++ b/pymongo/_asyncio_lock.py @@ -0,0 +1,309 @@ +# Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved + +"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py +to port 3.13 fixes to older versions of Python. +Can be removed once we drop Python 3.12 support.""" + +from __future__ import annotations + +import collections +import threading +from asyncio import events, exceptions +from typing import Any, Coroutine, Optional + +_global_lock = threading.Lock() + + +class _LoopBoundMixin: + _loop = None + + def _get_loop(self) -> Any: + loop = events._get_running_loop() + + if self._loop is None: + with _global_lock: + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f"{self!r} is bound to a different event loop") + return loop + + +class _ContextManagerMixin: + async def __aenter__(self) -> None: + await self.acquire() # type: ignore[attr-defined] + # We have no use for the "as ..." clause in the with + # statement for locks. + return + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + self.release() # type: ignore[attr-defined] + + +class Lock(_ContextManagerMixin, _LoopBoundMixin): + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular task when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another task changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one task is blocked in acquire() waiting for + the state to turn to unlocked, only one task proceeds when a + release() call resets the state to unlocked; successive release() + calls will unblock tasks in FIFO order. + + Locks also support the asynchronous context management protocol. + 'async with lock' statement should be used. + + Usage: + + lock = Lock() + ... + await lock.acquire() + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + async with lock: + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + await lock.acquire() + else: + # lock is acquired + ... + + """ + + def __init__(self) -> None: + self._waiters: Optional[collections.deque] = None + self._locked = False + + def __repr__(self) -> str: + res = super().__repr__() + extra = "locked" if self._locked else "unlocked" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + def locked(self) -> bool: + """Return True if lock is acquired.""" + return self._locked + + async def acquire(self) -> bool: + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + # Implement fair scheduling, where thread always waits + # its turn. Jumping the queue if all are cancelled is an optimization. + if not self._locked and ( + self._waiters is None or all(w.cancelled() for w in self._waiters) + ): + self._locked = True + return True + + if self._waiters is None: + self._waiters = collections.deque() + fut = self._get_loop().create_future() + self._waiters.append(fut) + + try: + try: + await fut + finally: + self._waiters.remove(fut) + except exceptions.CancelledError: + # Currently the only exception designed be able to occur here. + + # Ensure the lock invariant: If lock is not claimed (or about + # to be claimed by us) and there is a Task in waiters, + # ensure that the Task at the head will run. + if not self._locked: + self._wake_up_first() + raise + + # assert self._locked is False + self._locked = True + return True + + def release(self) -> None: + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other tasks are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + self._wake_up_first() + else: + raise RuntimeError("Lock is not acquired.") + + def _wake_up_first(self) -> None: + """Ensure that the first waiter will wake up.""" + if not self._waiters: + return + try: + fut = next(iter(self._waiters)) + except StopIteration: + return + + # .done() means that the waiter is already set to wake up. + if not fut.done(): + fut.set_result(True) + + +class Condition(_ContextManagerMixin, _LoopBoundMixin): + """Asynchronous equivalent to threading.Condition. + + This class implements condition variable objects. A condition variable + allows one or more tasks to wait until they are notified by another + task. + + A new Lock object is created and used as the underlying lock. + """ + + def __init__(self, lock: Optional[Lock] = None) -> None: + if lock is None: + lock = Lock() + + self._lock = lock + # Export the lock's locked(), acquire() and release() methods. + self.locked = lock.locked + self.acquire = lock.acquire + self.release = lock.release + + self._waiters: collections.deque = collections.deque() + + def __repr__(self) -> str: + res = super().__repr__() + extra = "locked" if self.locked() else "unlocked" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + async def wait(self) -> bool: + """Wait until notified. + + If the calling task has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another task. Once + awakened, it re-acquires the lock and returns True. + + This method may return spuriously, + which is why the caller should always + re-check the state and be prepared to wait() again. + """ + if not self.locked(): + raise RuntimeError("cannot wait on un-acquired lock") + + fut = self._get_loop().create_future() + self.release() + try: + try: + self._waiters.append(fut) + try: + await fut + return True + finally: + self._waiters.remove(fut) + + finally: + # Must re-acquire lock even if wait is cancelled. + # We only catch CancelledError here, since we don't want any + # other (fatal) errors with the future to cause us to spin. + err = None + while True: + try: + await self.acquire() + break + except exceptions.CancelledError as e: + err = e + + if err is not None: + try: + raise err # Re-raise most recent exception instance. + finally: + err = None # Break reference cycles. + except BaseException: + # Any error raised out of here _may_ have occurred after this Task + # believed to have been successfully notified. + # Make sure to notify another Task instead. This may result + # in a "spurious wakeup", which is allowed as part of the + # Condition Variable protocol. + self._notify(1) + raise + + async def wait_for(self, predicate: Any) -> Coroutine: + """Wait until a predicate becomes true. + + The predicate should be a callable whose result will be + interpreted as a boolean value. The method will repeatedly + wait() until it evaluates to true. The final predicate value is + the return value. + """ + result = predicate() + while not result: + await self.wait() + result = predicate() + return result + + def notify(self, n: int = 1) -> None: + """By default, wake up one task waiting on this condition, if any. + If the calling task has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up n of the tasks waiting for the condition + variable; if fewer than n are waiting, they are all awoken. + + Note: an awakened task does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self.locked(): + raise RuntimeError("cannot notify on un-acquired lock") + self._notify(n) + + def _notify(self, n: int) -> None: + idx = 0 + for fut in self._waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self) -> None: + """Wake up all tasks waiting on this condition. This method acts + like notify(), but wakes up all waiting tasks instead of one. If the + calling task has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) diff --git a/pymongo/_asyncio_task.py b/pymongo/_asyncio_task.py new file mode 100644 index 0000000000..8e457763d9 --- /dev/null +++ b/pymongo/_asyncio_task.py @@ -0,0 +1,49 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A custom asyncio.Task that allows checking if a task has been sent a cancellation request. +Can be removed once we drop Python 3.10 support in favor of asyncio.Task.cancelling.""" + + +from __future__ import annotations + +import asyncio +import sys +from typing import Any, Coroutine, Optional + + +# TODO (https://jira.mongodb.org/browse/PYTHON-4981): Revisit once the underlying cause of the swallowed cancellations is uncovered +class _Task(asyncio.Task): + def __init__(self, coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> None: + super().__init__(coro, name=name) + self._cancel_requests = 0 + asyncio._register_task(self) + + def cancel(self, msg: Optional[str] = None) -> bool: + self._cancel_requests += 1 + return super().cancel(msg=msg) + + def uncancel(self) -> int: + if self._cancel_requests > 0: + self._cancel_requests -= 1 + return self._cancel_requests + + def cancelling(self) -> int: + return self._cancel_requests + + +def create_task(coro: Coroutine[Any, Any, Any], *, name: Optional[str] = None) -> asyncio.Task: + if sys.version_info >= (3, 11): + return asyncio.create_task(coro, name=name) + return _Task(coro, name=name) diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py index 0dcdaa6c07..45824256da 100644 --- a/pymongo/asynchronous/client_bulk.py +++ b/pymongo/asynchronous/client_bulk.py @@ -476,7 +476,6 @@ async def _process_results_cursor( if op_type == "delete": res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment] full_result[f"{op_type}Results"][original_index] = res - except Exception as exc: # Attempt to close the cursor, then raise top-level error. if cmd_cursor.alive: diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 4b4bb52a8e..7d7ae4a5db 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -45,7 +45,7 @@ ) from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure -from pymongo.lock import _ALock, _create_lock +from pymongo.lock import _async_create_lock from pymongo.message import ( _CursorAddress, _GetMore, @@ -77,7 +77,7 @@ class _ConnectionManager: def __init__(self, conn: AsyncConnection, more_to_come: bool): self.conn: Optional[AsyncConnection] = conn self.more_to_come = more_to_come - self._alock = _ALock(_create_lock()) + self._lock = _async_create_lock() def update_exhaust(self, more_to_come: bool) -> None: self.more_to_come = more_to_come diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index 735e543047..4802c3f54e 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -15,6 +15,7 @@ """Support for explicit client-side field level encryption.""" from __future__ import annotations +import asyncio import contextlib import enum import socket @@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]: # BSON encoding/decoding errors are unrelated to encryption so # we should propagate them unchanged. raise + except asyncio.CancelledError: + raise except Exception as exc: raise EncryptionError(exc) from exc @@ -200,6 +203,8 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None: conn.close() except (PyMongoError, MongoCryptError): raise # Propagate pymongo errors directly. + except asyncio.CancelledError: + raise except Exception as error: # Wrap I/O errors in PyMongo exceptions. _raise_connection_failure((host, port), error) @@ -722,6 +727,8 @@ async def create_encrypted_collection( await database.create_collection(name=name, **kwargs), encrypted_fields, ) + except asyncio.CancelledError: + raise except Exception as exc: raise EncryptedCollectionError(exc, encrypted_fields) from exc diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 3e4dc482d7..1600e50628 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -32,6 +32,7 @@ """ from __future__ import annotations +import asyncio import contextlib import os import warnings @@ -59,8 +60,8 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, uri_parser -from pymongo.asynchronous import client_session, database, periodic_executor +from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser +from pymongo.asynchronous import client_session, database from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream from pymongo.asynchronous.client_bulk import _AsyncClientBulk from pymongo.asynchronous.client_session import _EmptyServerSession @@ -82,7 +83,11 @@ WaitQueueTimeoutError, WriteConcernError, ) -from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks +from pymongo.lock import ( + _HAS_REGISTER_AT_FORK, + _async_create_lock, + _release_locks, +) from pymongo.logger import _CLIENT_LOGGER, _log_or_warn from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.monitoring import ConnectionClosedReason @@ -842,7 +847,7 @@ def __init__( self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase - self._lock = _ALock(_create_lock()) + self._lock = _async_create_lock() self._kill_cursors_queue: list = [] self._event_listeners = options.pool_options._event_listeners @@ -908,7 +913,7 @@ async def target() -> bool: await AsyncMongoClient._process_periodic_tasks(client) return True - executor = periodic_executor.PeriodicExecutor( + executor = periodic_executor.AsyncPeriodicExecutor( interval=common.KILL_CURSOR_FREQUENCY, min_interval=common.MIN_HEARTBEAT_INTERVAL, target=target, @@ -1722,7 +1727,7 @@ async def _run_operation( address=address, ) - async with operation.conn_mgr._alock: + async with operation.conn_mgr._lock: async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return await server.run_operation( @@ -1970,7 +1975,7 @@ async def _close_cursor_now( try: if conn_mgr: - async with conn_mgr._alock: + async with conn_mgr._lock: # Cursor is pinned to LB outside of a transaction. assert address is not None assert conn_mgr.conn is not None @@ -2033,6 +2038,8 @@ async def _process_kill_cursors(self) -> None: for address, cursor_id, conn_mgr in pinned_cursors: try: await self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it @@ -2047,6 +2054,8 @@ async def _process_kill_cursors(self) -> None: for address, cursor_ids in address_to_cursor_ids.items(): try: await self._kill_cursors(cursor_ids, address, topology, session=None) + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: raise @@ -2061,6 +2070,8 @@ async def _process_periodic_tasks(self) -> None: try: await self._process_kill_cursors() await self._topology.update_pool() + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: return diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index a4dc9b7f45..ad1bc70aba 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -16,20 +16,20 @@ from __future__ import annotations +import asyncio import atexit import logging import time import weakref from typing import TYPE_CHECKING, Any, Mapping, Optional, cast -from pymongo import common +from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum -from pymongo.asynchronous import periodic_executor -from pymongo.asynchronous.periodic_executor import _shutdown_executors from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled from pymongo.hello import Hello -from pymongo.lock import _create_lock +from pymongo.lock import _async_create_lock from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage +from pymongo.periodic_executor import _shutdown_executors from pymongo.pool_options import _is_faas from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription @@ -76,7 +76,7 @@ async def target() -> bool: await monitor._run() # type:ignore[attr-defined] return True - executor = periodic_executor.PeriodicExecutor( + executor = periodic_executor.AsyncPeriodicExecutor( interval=interval, min_interval=min_interval, target=target, name=name ) @@ -112,9 +112,9 @@ async def close(self) -> None: """ self.gc_safe_close() - def join(self, timeout: Optional[int] = None) -> None: + async def join(self, timeout: Optional[int] = None) -> None: """Wait for the monitor to stop.""" - self._executor.join(timeout) + await self._executor.join(timeout) def request_check(self) -> None: """If the monitor is sleeping, wake it soon.""" @@ -139,7 +139,7 @@ def __init__( """ super().__init__( topology, - "pymongo_server_monitor_thread", + "pymongo_server_monitor_task", topology_settings.heartbeat_frequency, common.MIN_HEARTBEAT_INTERVAL, ) @@ -238,6 +238,9 @@ async def _run(self) -> None: except ReferenceError: # Topology was garbage-collected. await self.close() + finally: + if self._executor._stopped: + await self._rtt_monitor.close() async def _check_server(self) -> ServerDescription: """Call hello or read the next streaming response. @@ -252,8 +255,10 @@ async def _check_server(self) -> ServerDescription: except (OperationFailure, NotPrimaryError) as exc: # Update max cluster time even when hello fails. details = cast(Mapping[str, Any], exc.details) - self._topology.receive_cluster_time(details.get("$clusterTime")) + await self._topology.receive_cluster_time(details.get("$clusterTime")) raise + except asyncio.CancelledError: + raise except ReferenceError: raise except Exception as error: @@ -280,7 +285,7 @@ async def _check_server(self) -> ServerDescription: await self._reset_connection() if isinstance(error, _OperationCancelled): raise - self._rtt_monitor.reset() + await self._rtt_monitor.reset() # Server type defaults to Unknown. return ServerDescription(address, error=error) @@ -321,9 +326,9 @@ async def _check_once(self) -> ServerDescription: self._conn_id = conn.id response, round_trip_time = await self._check_with_socket(conn) if not response.awaitable: - self._rtt_monitor.add_sample(round_trip_time) + await self._rtt_monitor.add_sample(round_trip_time) - avg_rtt, min_rtt = self._rtt_monitor.get() + avg_rtt, min_rtt = await self._rtt_monitor.get() sd = ServerDescription(address, response, avg_rtt, min_round_trip_time=min_rtt) if self._publish: assert self._listeners is not None @@ -419,6 +424,8 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception + except asyncio.CancelledError: + raise except Exception: # As per the spec, upon encountering an error: # - An error must not be raised @@ -439,7 +446,7 @@ def __init__(self, topology: Topology, topology_settings: TopologySettings, pool """ super().__init__( topology, - "pymongo_server_rtt_thread", + "pymongo_server_rtt_task", topology_settings.heartbeat_frequency, common.MIN_HEARTBEAT_INTERVAL, ) @@ -447,7 +454,7 @@ def __init__(self, topology: Topology, topology_settings: TopologySettings, pool self._pool = pool self._moving_average = MovingAverage() self._moving_min = MovingMinimum() - self._lock = _create_lock() + self._lock = _async_create_lock() async def close(self) -> None: self.gc_safe_close() @@ -455,20 +462,20 @@ async def close(self) -> None: # thread has the socket checked out, it will be closed when checked in. await self._pool.reset() - def add_sample(self, sample: float) -> None: + async def add_sample(self, sample: float) -> None: """Add a RTT sample.""" - with self._lock: + async with self._lock: self._moving_average.add_sample(sample) self._moving_min.add_sample(sample) - def get(self) -> tuple[Optional[float], float]: + async def get(self) -> tuple[Optional[float], float]: """Get the calculated average, or None if no samples yet and the min.""" - with self._lock: + async with self._lock: return self._moving_average.get(), self._moving_min.get() - def reset(self) -> None: + async def reset(self) -> None: """Reset the average RTT.""" - with self._lock: + async with self._lock: self._moving_average.reset() self._moving_min.reset() @@ -478,10 +485,12 @@ async def _run(self) -> None: # heartbeat protocol (MongoDB 4.4+). # XXX: Skip check if the server is unknown? rtt = await self._ping() - self.add_sample(rtt) + await self.add_sample(rtt) except ReferenceError: # Topology was garbage-collected. await self.close() + except asyncio.CancelledError: + raise except Exception: await self._pool.reset() @@ -536,4 +545,5 @@ def _shutdown_resources() -> None: shutdown() -atexit.register(_shutdown_resources) +if _IS_SYNC: + atexit.register(_shutdown_resources) diff --git a/pymongo/asynchronous/periodic_executor.py b/pymongo/asynchronous/periodic_executor.py deleted file mode 100644 index f3d2fddba3..0000000000 --- a/pymongo/asynchronous/periodic_executor.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Run a target function on a background thread.""" - -from __future__ import annotations - -import asyncio -import sys -import threading -import time -import weakref -from typing import Any, Optional - -from pymongo.lock import _ALock, _create_lock - -_IS_SYNC = False - - -class PeriodicExecutor: - def __init__( - self, - interval: float, - min_interval: float, - target: Any, - name: Optional[str] = None, - ): - """Run a target function periodically on a background thread. - - If the target's return value is false, the executor stops. - - :param interval: Seconds between calls to `target`. - :param min_interval: Minimum seconds between calls if `wake` is - called very often. - :param target: A function. - :param name: A name to give the underlying thread. - """ - # threading.Event and its internal condition variable are expensive - # in Python 2, see PYTHON-983. Use a boolean to know when to wake. - # The executor's design is constrained by several Python issues, see - # "periodic_executor.rst" in this repository. - self._event = False - self._interval = interval - self._min_interval = min_interval - self._target = target - self._stopped = False - self._thread: Optional[threading.Thread] = None - self._name = name - self._skip_sleep = False - self._thread_will_exit = False - self._lock = _ALock(_create_lock()) - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>" - - def _run_async(self) -> None: - # The default asyncio loop implementation on Windows - # has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240) - # We explicitly use a different loop implementation here to prevent that issue - if sys.platform == "win32": - loop = asyncio.SelectorEventLoop() - try: - loop.run_until_complete(self._run()) # type: ignore[func-returns-value] - finally: - loop.close() - else: - asyncio.run(self._run()) # type: ignore[func-returns-value] - - def open(self) -> None: - """Start. Multiple calls have no effect. - - Not safe to call from multiple threads at once. - """ - with self._lock: - if self._thread_will_exit: - # If the background thread has read self._stopped as True - # there is a chance that it has not yet exited. The call to - # join should not block indefinitely because there is no - # other work done outside the while loop in self._run. - try: - assert self._thread is not None - self._thread.join() - except ReferenceError: - # Thread terminated. - pass - self._thread_will_exit = False - self._stopped = False - started: Any = False - try: - started = self._thread and self._thread.is_alive() - except ReferenceError: - # Thread terminated. - pass - - if not started: - if _IS_SYNC: - thread = threading.Thread(target=self._run, name=self._name) - else: - thread = threading.Thread(target=self._run_async, name=self._name) - thread.daemon = True - self._thread = weakref.proxy(thread) - _register_executor(self) - # Mitigation to RuntimeError firing when thread starts on shutdown - # https://github.com/python/cpython/issues/114570 - try: - thread.start() - except RuntimeError as e: - if "interpreter shutdown" in str(e) or sys.is_finalizing(): - self._thread = None - return - raise - - def close(self, dummy: Any = None) -> None: - """Stop. To restart, call open(). - - The dummy parameter allows an executor's close method to be a weakref - callback; see monitor.py. - """ - self._stopped = True - - def join(self, timeout: Optional[int] = None) -> None: - if self._thread is not None: - try: - self._thread.join(timeout) - except (ReferenceError, RuntimeError): - # Thread already terminated, or not yet started. - pass - - def wake(self) -> None: - """Execute the target function soon.""" - self._event = True - - def update_interval(self, new_interval: int) -> None: - self._interval = new_interval - - def skip_sleep(self) -> None: - self._skip_sleep = True - - async def _should_stop(self) -> bool: - async with self._lock: - if self._stopped: - self._thread_will_exit = True - return True - return False - - async def _run(self) -> None: - while not await self._should_stop(): - try: - if not await self._target(): - self._stopped = True - break - except BaseException: - async with self._lock: - self._stopped = True - self._thread_will_exit = True - - raise - - if self._skip_sleep: - self._skip_sleep = False - else: - deadline = time.monotonic() + self._interval - while not self._stopped and time.monotonic() < deadline: - await asyncio.sleep(self._min_interval) - if self._event: - break # Early wake. - - self._event = False - - -# _EXECUTORS has a weakref to each running PeriodicExecutor. Once started, -# an executor is kept alive by a strong reference from its thread and perhaps -# from other objects. When the thread dies and all other referrers are freed, -# the executor is freed and removed from _EXECUTORS. If any threads are -# running when the interpreter begins to shut down, we try to halt and join -# them to avoid spurious errors. -_EXECUTORS = set() - - -def _register_executor(executor: PeriodicExecutor) -> None: - ref = weakref.ref(executor, _on_executor_deleted) - _EXECUTORS.add(ref) - - -def _on_executor_deleted(ref: weakref.ReferenceType[PeriodicExecutor]) -> None: - _EXECUTORS.remove(ref) - - -def _shutdown_executors() -> None: - if _EXECUTORS is None: - return - - # Copy the set. Stopping threads has the side effect of removing executors. - executors = list(_EXECUTORS) - - # First signal all executors to close... - for ref in executors: - executor = ref() - if executor: - executor.close() - - # ...then try to join them. - for ref in executors: - executor = ref() - if executor: - executor.join(1) - - executor = None diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index ca0cebd417..5dc5675a0a 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -23,7 +23,6 @@ import socket import ssl import sys -import threading import time import weakref from typing import ( @@ -65,7 +64,11 @@ _CertificateError, ) from pymongo.hello import Hello, HelloCompat -from pymongo.lock import _ACondition, _ALock, _create_lock +from pymongo.lock import ( + _async_cond_wait, + _async_create_condition, + _async_create_lock, +) from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, @@ -208,11 +211,6 @@ def _raise_connection_failure( raise AutoReconnect(msg) from error -async def _cond_wait(condition: _ACondition, deadline: Optional[float]) -> bool: - timeout = deadline - time.monotonic() if deadline else None - return await condition.wait(timeout) - - def _get_timeout_details(options: PoolOptions) -> dict[str, float]: details = {} timeout = _csot.get_timeout() @@ -706,6 +704,8 @@ def _close_conn(self) -> None: # shutdown. try: self.conn.close() + except asyncio.CancelledError: + raise except Exception: # noqa: S110 pass @@ -992,8 +992,8 @@ def __init__( # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - _lock = _create_lock() - self.lock = _ALock(_lock) + self.lock = _async_create_lock() + self._max_connecting_cond = _async_create_condition(self.lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1019,7 +1019,7 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = _ACondition(threading.Condition(_lock)) + self.size_cond = _async_create_condition(self.lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1027,7 +1027,7 @@ def __init__( # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = _ACondition(threading.Condition(_lock)) + self._max_connecting_cond = _async_create_condition(self.lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id @@ -1466,7 +1466,8 @@ async def _get_conn( async with self.size_cond: self._raise_if_not_ready(checkout_started_time, emit_event=True) while not (self.requests < self.max_pool_size): - if not await _cond_wait(self.size_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not await _async_cond_wait(self.size_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.requests < self.max_pool_size: @@ -1489,7 +1490,8 @@ async def _get_conn( async with self._max_connecting_cond: self._raise_if_not_ready(checkout_started_time, emit_event=False) while not (self.conns or self._pending < self._max_connecting): - if not await _cond_wait(self._max_connecting_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not await _async_cond_wait(self._max_connecting_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.conns or self._pending < self._max_connecting: diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index 82af4257ba..6d67710a7e 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -27,8 +27,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast -from pymongo import _csot, common, helpers_shared -from pymongo.asynchronous import periodic_executor +from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool from pymongo.asynchronous.monitor import SrvMonitor from pymongo.asynchronous.pool import Pool @@ -44,7 +43,11 @@ WriteError, ) from pymongo.hello import Hello -from pymongo.lock import _ACondition, _ALock, _create_lock +from pymongo.lock import ( + _async_cond_wait, + _async_create_condition, + _async_create_lock, +) from pymongo.logger import ( _SDAM_LOGGER, _SERVER_SELECTION_LOGGER, @@ -170,9 +173,10 @@ def __init__(self, topology_settings: TopologySettings): self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - _lock = _create_lock() - self._lock = _ALock(_lock) - self._condition = _ACondition(self._settings.condition_class(_lock)) + self._lock = _async_create_lock() + self._condition = _async_create_condition( + self._lock, self._settings.condition_class if _IS_SYNC else None + ) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None @@ -185,7 +189,7 @@ def __init__(self, topology_settings: TopologySettings): async def target() -> bool: return process_events_queue(weak) - executor = periodic_executor.PeriodicExecutor( + executor = periodic_executor.AsyncPeriodicExecutor( interval=common.EVENTS_QUEUE_FREQUENCY, min_interval=common.MIN_HEARTBEAT_INTERVAL, target=target, @@ -354,7 +358,7 @@ async def _select_servers_loop( # change, or for a timeout. We won't miss any changes that # came after our most recent apply_selector call, since we've # held the lock until now. - await self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) + await _async_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL) self._description.check_compatible() now = time.monotonic() server_descriptions = self._description.apply_selector( @@ -654,7 +658,7 @@ async def request_check_all(self, wait_time: int = 5) -> None: """Wake all monitors, wait for at least one to check its server.""" async with self._lock: self._request_check_all() - await self._condition.wait(wait_time) + await _async_cond_wait(self._condition, wait_time) def data_bearing_servers(self) -> list[ServerDescription]: """Return a list of all data-bearing servers. @@ -742,7 +746,7 @@ async def close(self) -> None: if self._publish_server or self._publish_tp: # Make sure the events executor thread is fully closed before publishing the remaining events self.__events_executor.close() - self.__events_executor.join(1) + await self.__events_executor.join(1) process_events_queue(weakref.ref(self._events)) # type: ignore[arg-type] @property diff --git a/pymongo/lock.py b/pymongo/lock.py index 0cbfb4a57e..6bf7138017 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -11,15 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Internal helpers for lock and condition coordination primitives.""" + from __future__ import annotations import asyncio -import collections import os +import sys import threading -import time import weakref -from typing import Any, Callable, Optional, TypeVar +from asyncio import wait_for +from typing import Any, Optional, TypeVar + +import pymongo._asyncio_lock _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") @@ -28,6 +33,15 @@ _T = TypeVar("_T") +# Needed to support 3.13 asyncio fixes (https://github.com/python/cpython/issues/112202) +# in older versions of Python +if sys.version_info >= (3, 13): + Lock = asyncio.Lock + Condition = asyncio.Condition +else: + Lock = pymongo._asyncio_lock.Lock + Condition = pymongo._asyncio_lock.Condition + def _create_lock() -> threading.Lock: """Represents a lock that is tracked upon instantiation using a WeakSet and @@ -39,6 +53,27 @@ def _create_lock() -> threading.Lock: return lock +def _async_create_lock() -> Lock: + """Represents an asyncio.Lock.""" + return Lock() + + +def _create_condition( + lock: threading.Lock, condition_class: Optional[Any] = None +) -> threading.Condition: + """Represents a threading.Condition.""" + if condition_class: + return condition_class(lock) + return threading.Condition(lock) + + +def _async_create_condition(lock: Lock, condition_class: Optional[Any] = None) -> Condition: + """Represents an asyncio.Condition.""" + if condition_class: + return condition_class(lock) + return Condition(lock) + + def _release_locks() -> None: # Completed the fork, reset all the locks in the child. for lock in _forkable_locks: @@ -46,202 +81,12 @@ def _release_locks() -> None: lock.release() -# Needed only for synchro.py compat. -def _Lock(lock: threading.Lock) -> threading.Lock: - return lock +async def _async_cond_wait(condition: Condition, timeout: Optional[float]) -> bool: + try: + return await wait_for(condition.wait(), timeout) + except asyncio.TimeoutError: + return False -class _ALock: - __slots__ = ("_lock",) - - def __init__(self, lock: threading.Lock) -> None: - self._lock = lock - - def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - return self._lock.acquire(blocking=blocking, timeout=timeout) - - async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - if timeout > 0: - tstart = time.monotonic() - while True: - acquired = self._lock.acquire(blocking=False) - if acquired: - return True - if timeout > 0 and (time.monotonic() - tstart) > timeout: - return False - if not blocking: - return False - await asyncio.sleep(0) - - def release(self) -> None: - self._lock.release() - - async def __aenter__(self) -> _ALock: - await self.a_acquire() - return self - - def __enter__(self) -> _ALock: - self._lock.acquire() - return self - - def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() - - async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() - - -def _safe_set_result(fut: asyncio.Future) -> None: - # Ensure the future hasn't been cancelled before calling set_result. - if not fut.done(): - fut.set_result(False) - - -class _ACondition: - __slots__ = ("_condition", "_waiters") - - def __init__(self, condition: threading.Condition) -> None: - self._condition = condition - self._waiters: collections.deque = collections.deque() - - async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - if timeout > 0: - tstart = time.monotonic() - while True: - acquired = self._condition.acquire(blocking=False) - if acquired: - return True - if timeout > 0 and (time.monotonic() - tstart) > timeout: - return False - if not blocking: - return False - await asyncio.sleep(0) - - async def wait(self, timeout: Optional[float] = None) -> bool: - """Wait until notified. - - If the calling task has not acquired the lock when this - method is called, a RuntimeError is raised. - - This method releases the underlying lock, and then blocks - until it is awakened by a notify() or notify_all() call for - the same condition variable in another task. Once - awakened, it re-acquires the lock and returns True. - - This method may return spuriously, - which is why the caller should always - re-check the state and be prepared to wait() again. - """ - loop = asyncio.get_running_loop() - fut = loop.create_future() - self._waiters.append((loop, fut)) - self.release() - try: - try: - try: - await asyncio.wait_for(fut, timeout) - return True - except asyncio.TimeoutError: - return False # Return false on timeout for sync pool compat. - finally: - # Must re-acquire lock even if wait is cancelled. - # We only catch CancelledError here, since we don't want any - # other (fatal) errors with the future to cause us to spin. - err = None - while True: - try: - await self.acquire() - break - except asyncio.exceptions.CancelledError as e: - err = e - - self._waiters.remove((loop, fut)) - if err is not None: - try: - raise err # Re-raise most recent exception instance. - finally: - err = None # Break reference cycles. - except BaseException: - # Any error raised out of here _may_ have occurred after this Task - # believed to have been successfully notified. - # Make sure to notify another Task instead. This may result - # in a "spurious wakeup", which is allowed as part of the - # Condition Variable protocol. - self.notify(1) - raise - - async def wait_for(self, predicate: Callable[[], _T]) -> _T: - """Wait until a predicate becomes true. - - The predicate should be a callable whose result will be - interpreted as a boolean value. The method will repeatedly - wait() until it evaluates to true. The final predicate value is - the return value. - """ - result = predicate() - while not result: - await self.wait() - result = predicate() - return result - - def notify(self, n: int = 1) -> None: - """By default, wake up one coroutine waiting on this condition, if any. - If the calling coroutine has not acquired the lock when this method - is called, a RuntimeError is raised. - - This method wakes up at most n of the coroutines waiting for the - condition variable; it is a no-op if no coroutines are waiting. - - Note: an awakened coroutine does not actually return from its - wait() call until it can reacquire the lock. Since notify() does - not release the lock, its caller should. - """ - idx = 0 - to_remove = [] - for loop, fut in self._waiters: - if idx >= n: - break - - if fut.done(): - continue - - try: - loop.call_soon_threadsafe(_safe_set_result, fut) - except RuntimeError: - # Loop was closed, ignore. - to_remove.append((loop, fut)) - continue - - idx += 1 - - for waiter in to_remove: - self._waiters.remove(waiter) - - def notify_all(self) -> None: - """Wake up all threads waiting on this condition. This method acts - like notify(), but wakes up all waiting threads instead of one. If the - calling thread has not acquired the lock when this method is called, - a RuntimeError is raised. - """ - self.notify(len(self._waiters)) - - def locked(self) -> bool: - """Only needed for tests in test_locks.""" - return self._condition._lock.locked() # type: ignore[attr-defined] - - def release(self) -> None: - self._condition.release() - - async def __aenter__(self) -> _ACondition: - await self.acquire() - return self - - def __enter__(self) -> _ACondition: - self._condition.acquire() - return self - - async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() - - def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() +def _cond_wait(condition: threading.Condition, timeout: Optional[float]) -> bool: + return condition.wait(timeout) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index aa16e85a07..6ab6db2f7d 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -29,6 +29,7 @@ ) from pymongo import _csot, ssl_support +from pymongo._asyncio_task import create_task from pymongo.errors import _OperationCancelled from pymongo.socket_checker import _errno_from_exception @@ -259,19 +260,20 @@ async def async_receive_data( sock.settimeout(0.0) loop = asyncio.get_event_loop() - cancellation_task = asyncio.create_task(_poll_cancellation(conn)) + cancellation_task = create_task(_poll_cancellation(conn)) try: if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): - read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] + read_task = create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type] else: - read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] + read_task = create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] tasks = [read_task, cancellation_task] done, pending = await asyncio.wait( tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED ) for task in pending: task.cancel() - await asyncio.wait(pending) + if pending: + await asyncio.wait(pending) if len(done) == 0: raise socket.timeout("timed out") if read_task in done: diff --git a/pymongo/synchronous/periodic_executor.py b/pymongo/periodic_executor.py similarity index 67% rename from pymongo/synchronous/periodic_executor.py rename to pymongo/periodic_executor.py index 525268b14b..2f89b91deb 100644 --- a/pymongo/synchronous/periodic_executor.py +++ b/pymongo/periodic_executor.py @@ -23,9 +23,102 @@ import weakref from typing import Any, Optional +from pymongo._asyncio_task import create_task from pymongo.lock import _create_lock -_IS_SYNC = True +_IS_SYNC = False + + +class AsyncPeriodicExecutor: + def __init__( + self, + interval: float, + min_interval: float, + target: Any, + name: Optional[str] = None, + ): + """Run a target function periodically on a background task. + + If the target's return value is false, the executor stops. + + :param interval: Seconds between calls to `target`. + :param min_interval: Minimum seconds between calls if `wake` is + called very often. + :param target: A function. + :param name: A name to give the underlying task. + """ + self._event = False + self._interval = interval + self._min_interval = min_interval + self._target = target + self._stopped = False + self._task: Optional[asyncio.Task] = None + self._name = name + self._skip_sleep = False + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>" + + def open(self) -> None: + """Start. Multiple calls have no effect.""" + self._stopped = False + + if self._task is None or ( + self._task.done() and not self._task.cancelled() and not self._task.cancelling() # type: ignore[unused-ignore, attr-defined] + ): + self._task = create_task(self._run(), name=self._name) + + def close(self, dummy: Any = None) -> None: + """Stop. To restart, call open(). + + The dummy parameter allows an executor's close method to be a weakref + callback; see monitor.py. + """ + self._stopped = True + + async def join(self, timeout: Optional[int] = None) -> None: + if self._task is not None: + try: + await asyncio.wait_for(self._task, timeout=timeout) # type-ignore: [arg-type] + except asyncio.TimeoutError: + # Task timed out + pass + except asyncio.exceptions.CancelledError: + # Task was already finished, or not yet started. + raise + + def wake(self) -> None: + """Execute the target function soon.""" + self._event = True + + def update_interval(self, new_interval: int) -> None: + self._interval = new_interval + + def skip_sleep(self) -> None: + self._skip_sleep = True + + async def _run(self) -> None: + while not self._stopped: + if self._task and self._task.cancelling(): # type: ignore[unused-ignore, attr-defined] + raise asyncio.CancelledError + try: + if not await self._target(): + self._stopped = True + break + except BaseException: + self._stopped = True + raise + + if self._skip_sleep: + self._skip_sleep = False + else: + deadline = time.monotonic() + self._interval + while not self._stopped and time.monotonic() < deadline: + await asyncio.sleep(self._min_interval) + if self._event: + break # Early wake. + + self._event = False class PeriodicExecutor: @@ -64,19 +157,6 @@ def __init__( def __repr__(self) -> str: return f"<{self.__class__.__name__}(name={self._name}) object at 0x{id(self):x}>" - def _run_async(self) -> None: - # The default asyncio loop implementation on Windows - # has issues with sharing sockets across loops (https://github.com/python/cpython/issues/122240) - # We explicitly use a different loop implementation here to prevent that issue - if sys.platform == "win32": - loop = asyncio.SelectorEventLoop() - try: - loop.run_until_complete(self._run()) # type: ignore[func-returns-value] - finally: - loop.close() - else: - asyncio.run(self._run()) # type: ignore[func-returns-value] - def open(self) -> None: """Start. Multiple calls have no effect. @@ -104,10 +184,7 @@ def open(self) -> None: pass if not started: - if _IS_SYNC: - thread = threading.Thread(target=self._run, name=self._name) - else: - thread = threading.Thread(target=self._run_async, name=self._name) + thread = threading.Thread(target=self._run, name=self._name) thread.daemon = True self._thread = weakref.proxy(thread) _register_executor(self) diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py index 625e8429eb..9f6e3f7cf0 100644 --- a/pymongo/synchronous/client_bulk.py +++ b/pymongo/synchronous/client_bulk.py @@ -474,7 +474,6 @@ def _process_results_cursor( if op_type == "delete": res = DeleteResult(doc, acknowledged=True) # type: ignore[assignment] full_result[f"{op_type}Results"][original_index] = res - except Exception as exc: # Attempt to close the cursor, then raise top-level error. if cmd_cursor.alive: diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 27a76cf91d..9a7637704f 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -77,7 +77,7 @@ class _ConnectionManager: def __init__(self, conn: Connection, more_to_come: bool): self.conn: Optional[Connection] = conn self.more_to_come = more_to_come - self._alock = _create_lock() + self._lock = _create_lock() def update_exhaust(self, more_to_come: bool) -> None: self.more_to_come = more_to_come diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index 506ff8bcba..09d0c0f2fd 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -15,6 +15,7 @@ """Support for explicit client-side field level encryption.""" from __future__ import annotations +import asyncio import contextlib import enum import socket @@ -111,6 +112,8 @@ def _wrap_encryption_errors() -> Iterator[None]: # BSON encoding/decoding errors are unrelated to encryption so # we should propagate them unchanged. raise + except asyncio.CancelledError: + raise except Exception as exc: raise EncryptionError(exc) from exc @@ -200,6 +203,8 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None: conn.close() except (PyMongoError, MongoCryptError): raise # Propagate pymongo errors directly. + except asyncio.CancelledError: + raise except Exception as error: # Wrap I/O errors in PyMongo exceptions. _raise_connection_failure((host, port), error) @@ -716,6 +721,8 @@ def create_encrypted_collection( database.create_collection(name=name, **kwargs), encrypted_fields, ) + except asyncio.CancelledError: + raise except Exception as exc: raise EncryptedCollectionError(exc, encrypted_fields) from exc diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 00c6203a94..a694a58c1e 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -32,6 +32,7 @@ """ from __future__ import annotations +import asyncio import contextlib import os import warnings @@ -58,7 +59,7 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, common, helpers_shared, uri_parser +from pymongo import _csot, common, helpers_shared, periodic_executor, uri_parser from pymongo.client_options import ClientOptions from pymongo.errors import ( AutoReconnect, @@ -74,7 +75,11 @@ WaitQueueTimeoutError, WriteConcernError, ) -from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks +from pymongo.lock import ( + _HAS_REGISTER_AT_FORK, + _create_lock, + _release_locks, +) from pymongo.logger import _CLIENT_LOGGER, _log_or_warn from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.monitoring import ConnectionClosedReason @@ -91,7 +96,7 @@ from pymongo.results import ClientBulkWriteResult from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import client_session, database, periodic_executor +from pymongo.synchronous import client_session, database from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.synchronous.client_session import _EmptyServerSession @@ -1716,7 +1721,7 @@ def _run_operation( address=address, ) - with operation.conn_mgr._alock: + with operation.conn_mgr._lock: with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return server.run_operation( @@ -1964,7 +1969,7 @@ def _close_cursor_now( try: if conn_mgr: - with conn_mgr._alock: + with conn_mgr._lock: # Cursor is pinned to LB outside of a transaction. assert address is not None assert conn_mgr.conn is not None @@ -2027,6 +2032,8 @@ def _process_kill_cursors(self) -> None: for address, cursor_id, conn_mgr in pinned_cursors: try: self._cleanup_cursor_lock(cursor_id, address, conn_mgr, None, False) + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it @@ -2041,6 +2048,8 @@ def _process_kill_cursors(self) -> None: for address, cursor_ids in address_to_cursor_ids.items(): try: self._kill_cursors(cursor_ids, address, topology, session=None) + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: raise @@ -2055,6 +2064,8 @@ def _process_periodic_tasks(self) -> None: try: self._process_kill_cursors() self._topology.update_pool() + except asyncio.CancelledError: + raise except Exception as exc: if isinstance(exc, InvalidOperation) and self._topology._closed: return diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index d02ad0a6fd..df4130d4ab 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -16,24 +16,24 @@ from __future__ import annotations +import asyncio import atexit import logging import time import weakref from typing import TYPE_CHECKING, Any, Mapping, Optional, cast -from pymongo import common +from pymongo import common, periodic_executor from pymongo._csot import MovingMinimum from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled from pymongo.hello import Hello from pymongo.lock import _create_lock from pymongo.logger import _SDAM_LOGGER, _debug_log, _SDAMStatusMessage +from pymongo.periodic_executor import _shutdown_executors from pymongo.pool_options import _is_faas from pymongo.read_preferences import MovingAverage from pymongo.server_description import ServerDescription from pymongo.srv_resolver import _SrvResolver -from pymongo.synchronous import periodic_executor -from pymongo.synchronous.periodic_executor import _shutdown_executors if TYPE_CHECKING: from pymongo.synchronous.pool import Connection, Pool, _CancellationContext @@ -238,6 +238,9 @@ def _run(self) -> None: except ReferenceError: # Topology was garbage-collected. self.close() + finally: + if self._executor._stopped: + self._rtt_monitor.close() def _check_server(self) -> ServerDescription: """Call hello or read the next streaming response. @@ -254,6 +257,8 @@ def _check_server(self) -> ServerDescription: details = cast(Mapping[str, Any], exc.details) self._topology.receive_cluster_time(details.get("$clusterTime")) raise + except asyncio.CancelledError: + raise except ReferenceError: raise except Exception as error: @@ -419,6 +424,8 @@ def _get_seedlist(self) -> Optional[list[tuple[str, Any]]]: if len(seedlist) == 0: # As per the spec: this should be treated as a failure. raise Exception + except asyncio.CancelledError: + raise except Exception: # As per the spec, upon encountering an error: # - An error must not be raised @@ -482,6 +489,8 @@ def _run(self) -> None: except ReferenceError: # Topology was garbage-collected. self.close() + except asyncio.CancelledError: + raise except Exception: self._pool.reset() @@ -536,4 +545,5 @@ def _shutdown_resources() -> None: shutdown() -atexit.register(_shutdown_resources) +if _IS_SYNC: + atexit.register(_shutdown_resources) diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 86baf15b9a..1a155c82d7 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -23,7 +23,6 @@ import socket import ssl import sys -import threading import time import weakref from typing import ( @@ -62,7 +61,11 @@ _CertificateError, ) from pymongo.hello import Hello, HelloCompat -from pymongo.lock import _create_lock, _Lock +from pymongo.lock import ( + _cond_wait, + _create_condition, + _create_lock, +) from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, @@ -208,11 +211,6 @@ def _raise_connection_failure( raise AutoReconnect(msg) from error -def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool: - timeout = deadline - time.monotonic() if deadline else None - return condition.wait(timeout) - - def _get_timeout_details(options: PoolOptions) -> dict[str, float]: details = {} timeout = _csot.get_timeout() @@ -704,6 +702,8 @@ def _close_conn(self) -> None: # shutdown. try: self.conn.close() + except asyncio.CancelledError: + raise except Exception: # noqa: S110 pass @@ -988,8 +988,8 @@ def __init__( # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - _lock = _create_lock() - self.lock = _Lock(_lock) + self.lock = _create_lock() + self._max_connecting_cond = _create_condition(self.lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1015,7 +1015,7 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = threading.Condition(_lock) + self.size_cond = _create_condition(self.lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1023,7 +1023,7 @@ def __init__( # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = threading.Condition(_lock) + self._max_connecting_cond = _create_condition(self.lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id @@ -1460,7 +1460,8 @@ def _get_conn( with self.size_cond: self._raise_if_not_ready(checkout_started_time, emit_event=True) while not (self.requests < self.max_pool_size): - if not _cond_wait(self.size_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not _cond_wait(self.size_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.requests < self.max_pool_size: @@ -1483,7 +1484,8 @@ def _get_conn( with self._max_connecting_cond: self._raise_if_not_ready(checkout_started_time, emit_event=False) while not (self.conns or self._pending < self._max_connecting): - if not _cond_wait(self._max_connecting_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not _cond_wait(self._max_connecting_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.conns or self._pending < self._max_connecting: diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index a350c1702e..b03269ae43 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -27,7 +27,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast -from pymongo import _csot, common, helpers_shared +from pymongo import _csot, common, helpers_shared, periodic_executor from pymongo.errors import ( ConnectionFailure, InvalidOperation, @@ -39,7 +39,11 @@ WriteError, ) from pymongo.hello import Hello -from pymongo.lock import _create_lock, _Lock +from pymongo.lock import ( + _cond_wait, + _create_condition, + _create_lock, +) from pymongo.logger import ( _SDAM_LOGGER, _SERVER_SELECTION_LOGGER, @@ -56,7 +60,6 @@ secondary_server_selector, writable_server_selector, ) -from pymongo.synchronous import periodic_executor from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool from pymongo.synchronous.monitor import SrvMonitor from pymongo.synchronous.pool import Pool @@ -170,9 +173,10 @@ def __init__(self, topology_settings: TopologySettings): self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - _lock = _create_lock() - self._lock = _Lock(_lock) - self._condition = self._settings.condition_class(_lock) + self._lock = _create_lock() + self._condition = _create_condition( + self._lock, self._settings.condition_class if _IS_SYNC else None + ) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None @@ -354,7 +358,7 @@ def _select_servers_loop( # change, or for a timeout. We won't miss any changes that # came after our most recent apply_selector call, since we've # held the lock until now. - self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) + _cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL) self._description.check_compatible() now = time.monotonic() server_descriptions = self._description.apply_selector( @@ -652,7 +656,7 @@ def request_check_all(self, wait_time: int = 5) -> None: """Wake all monitors, wait for at least one to check its server.""" with self._lock: self._request_check_all() - self._condition.wait(wait_time) + _cond_wait(self._condition, wait_time) def data_bearing_servers(self) -> list[ServerDescription]: """Return a list of all data-bearing servers. diff --git a/test/__init__.py b/test/__init__.py index fd33fde293..d3a63db2d5 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -17,6 +17,7 @@ import asyncio import gc +import logging import multiprocessing import os import signal @@ -25,6 +26,7 @@ import sys import threading import time +import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -191,6 +193,8 @@ def _connect(self, host, port, **kwargs): client.close() def _init_client(self): + self.mongoses = [] + self.connection_attempts = [] self.client = self._connect(host, port) if self.client is not None: # Return early when connected to dataLake as mongohoused does not @@ -860,6 +864,16 @@ def max_message_size_bytes(self): client_context = ClientContext() +def reset_client_context(): + if _IS_SYNC: + # sync tests don't need to reset a client context + return + elif client_context.client is not None: + client_context.client.close() + client_context.client = None + client_context._init_client() + + class PyMongoTestCase(unittest.TestCase): def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) @@ -1106,26 +1120,10 @@ def enable_replication(self, client): class UnitTest(PyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod - def _setup_class(cls): + def setUp(self) -> None: pass - @classmethod - def _tearDown_class(cls): + def tearDown(self) -> None: pass @@ -1136,37 +1134,20 @@ class IntegrationTest(PyMongoTestCase): db: Database credentials: Dict[str, str] - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @client_context.require_connection - def _setup_class(cls): - if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): + def setUp(self) -> None: + if not _IS_SYNC: + reset_client_context() + if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") - if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): + if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): raise SkipTest("this test does not support serverless") - cls.client = client_context.client - cls.db = cls.client.pymongo_test + self.client = client_context.client + self.db = self.client.pymongo_test if client_context.auth_enabled: - cls.credentials = {"username": db_user, "password": db_pwd} + self.credentials = {"username": db_user, "password": db_pwd} else: - cls.credentials = {} - - @classmethod - def _tearDown_class(cls): - pass + self.credentials = {} def cleanup_colls(self, *collections): """Cleanup collections faster than drop_collection.""" @@ -1192,37 +1173,14 @@ class MockClientTest(UnitTest): # MockClients tests that use replicaSet, directConnection=True, pass # multiple seed addresses, or wait for heartbeat events are incompatible # with loadBalanced=True. - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @client_context.require_no_load_balancer - def _setup_class(cls): - pass - - @classmethod - def _tearDown_class(cls): - pass - - def setUp(self): + def setUp(self) -> None: super().setUp() self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) - self.client_knobs.enable() - def tearDown(self): + def tearDown(self) -> None: self.client_knobs.disable() super().tearDown() @@ -1253,7 +1211,6 @@ def teardown(): c.drop_database("pymongo_test_mike") c.drop_database("pymongo_test_bernie") c.close() - print_running_clients() diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 0579828c49..73e2824742 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -17,6 +17,7 @@ import asyncio import gc +import logging import multiprocessing import os import signal @@ -25,6 +26,7 @@ import sys import threading import time +import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -191,6 +193,8 @@ async def _connect(self, host, port, **kwargs): await client.close() async def _init_client(self): + self.mongoses = [] + self.connection_attempts = [] self.client = await self._connect(host, port) if self.client is not None: # Return early when connected to dataLake as mongohoused does not @@ -862,6 +866,16 @@ async def max_message_size_bytes(self): async_client_context = AsyncClientContext() +async def reset_client_context(): + if _IS_SYNC: + # sync tests don't need to reset a client context + return + elif async_client_context.client is not None: + await async_client_context.client.close() + async_client_context.client = None + await async_client_context._init_client() + + class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) @@ -1124,26 +1138,10 @@ async def enable_replication(self, client): class AsyncUnitTest(AsyncPyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod - async def _setup_class(cls): + async def asyncSetUp(self) -> None: pass - @classmethod - async def _tearDown_class(cls): + async def asyncTearDown(self) -> None: pass @@ -1154,37 +1152,20 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): db: AsyncDatabase credentials: Dict[str, str] - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @async_client_context.require_connection - async def _setup_class(cls): - if async_client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): + async def asyncSetUp(self) -> None: + if not _IS_SYNC: + await reset_client_context() + if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") - if async_client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): + if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): raise SkipTest("this test does not support serverless") - cls.client = async_client_context.client - cls.db = cls.client.pymongo_test + self.client = async_client_context.client + self.db = self.client.pymongo_test if async_client_context.auth_enabled: - cls.credentials = {"username": db_user, "password": db_pwd} + self.credentials = {"username": db_user, "password": db_pwd} else: - cls.credentials = {} - - @classmethod - async def _tearDown_class(cls): - pass + self.credentials = {} async def cleanup_colls(self, *collections): """Cleanup collections faster than drop_collection.""" @@ -1210,39 +1191,16 @@ class AsyncMockClientTest(AsyncUnitTest): # MockClients tests that use replicaSet, directConnection=True, pass # multiple seed addresses, or wait for heartbeat events are incompatible # with loadBalanced=True. - @classmethod - def setUpClass(cls): - if _IS_SYNC: - cls._setup_class() - else: - asyncio.run(cls._setup_class()) - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls._tearDown_class() - else: - asyncio.run(cls._tearDown_class()) - - @classmethod @async_client_context.require_no_load_balancer - async def _setup_class(cls): - pass - - @classmethod - async def _tearDown_class(cls): - pass - - def setUp(self): - super().setUp() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) - self.client_knobs.enable() - def tearDown(self): + async def asyncTearDown(self) -> None: self.client_knobs.disable() - super().tearDown() + await super().asyncTearDown() async def async_setup(): @@ -1271,7 +1229,6 @@ async def async_teardown(): await c.drop_database("pymongo_test_mike") await c.drop_database("pymongo_test_bernie") await c.close() - print_running_clients() diff --git a/test/asynchronous/conftest.py b/test/asynchronous/conftest.py index e443dff6c0..a27a9f213d 100644 --- a/test/asynchronous/conftest.py +++ b/test/asynchronous/conftest.py @@ -22,7 +22,7 @@ def event_loop_policy(): return asyncio.get_event_loop_policy() -@pytest_asyncio.fixture(scope="session", autouse=True) +@pytest_asyncio.fixture(scope="package", autouse=True) async def test_setup_and_teardown(): await async_setup() yield diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index c9ff167b43..7191a412c1 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -42,15 +42,11 @@ class AsyncBulkTestBase(AsyncIntegrationTest): coll: AsyncCollection coll_w0: AsyncCollection - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.coll = cls.db.test - cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0)) - async def asyncSetUp(self): - super().setUp() + await super().asyncSetUp() + self.coll = self.db.test await self.coll.drop() + self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0)) def assertEqualResponse(self, expected, actual): """Compare response from bulk.execute() to expected response.""" @@ -787,14 +783,10 @@ async def test_large_inserts_unordered(self): class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase): - @classmethod @async_client_context.require_auth @async_client_context.require_no_api_version - async def _setup_class(cls): - await super()._setup_class() - async def asyncSetUp(self): - super().setUp() + await super().asyncSetUp() await async_client_context.create_user(self.db.name, "readonly", "pw", ["read"]) await self.db.command( "createRole", @@ -937,21 +929,19 @@ class AsyncTestBulkWriteConcern(AsyncBulkTestBase): w: Optional[int] secondary: AsyncMongoClient - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.w = async_client_context.w - cls.secondary = None - if cls.w is not None and cls.w > 1: + async def asyncSetUp(self): + await super().asyncSetUp() + self.w = async_client_context.w + self.secondary = None + if self.w is not None and self.w > 1: for member in (await async_client_context.hello)["hosts"]: if member != (await async_client_context.hello)["primary"]: - cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member)) + self.secondary = await self.async_single_client(*partition_node(member)) break - @classmethod - async def async_tearDownClass(cls): - if cls.secondary: - await cls.secondary.close() + async def asyncTearDown(self): + if self.secondary: + await self.secondary.close() async def cause_wtimeout(self, requests, ordered): if not async_client_context.test_commands_enabled: diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 8e16fe7528..08da00cc1e 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -836,18 +836,16 @@ async def test_split_large_change(self): class TestClusterAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin): dbs: list - @classmethod @async_client_context.require_version_min(4, 0, 0, -1) @async_client_context.require_change_streams - async def _setup_class(cls): - await super()._setup_class() - cls.dbs = [cls.db, cls.client.pymongo_test_2] + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.dbs = [self.db, self.client.pymongo_test_2] - @classmethod - async def _tearDown_class(cls): - for db in cls.dbs: - await cls.client.drop_database(db) - await super()._tearDown_class() + async def asyncTearDown(self): + for db in self.dbs: + await self.client.drop_database(db) + await super().asyncTearDown() async def change_stream_with_client(self, client, *args, **kwargs): return await client.watch(*args, **kwargs) @@ -898,11 +896,10 @@ async def test_full_pipeline(self): class TestAsyncDatabaseAsyncChangeStream(TestAsyncChangeStreamBase, APITestsMixin): - @classmethod @async_client_context.require_version_min(4, 0, 0, -1) @async_client_context.require_change_streams - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() async def change_stream_with_client(self, client, *args, **kwargs): return await client[self.db.name].watch(*args, **kwargs) @@ -988,12 +985,9 @@ async def test_isolation(self): class TestAsyncCollectionAsyncChangeStream( TestAsyncChangeStreamBase, APITestsMixin, ProseSpecTestsMixin ): - @classmethod @async_client_context.require_change_streams - async def _setup_class(cls): - await super()._setup_class() - async def asyncSetUp(self): + await super().asyncSetUp() # Use a new collection for each test. await self.watched_collection().drop() await self.watched_collection().insert_one({}) @@ -1133,20 +1127,11 @@ class TestAllLegacyScenarios(AsyncIntegrationTest): RUN_ON_LOAD_BALANCER = True listener: AllowListEventListener - @classmethod @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() - cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - await super()._tearDown_class() - - def asyncSetUp(self): - super().asyncSetUp() + async def asyncSetUp(self): + await super().asyncSetUp() + self.listener = AllowListEventListener("aggregate", "getMore") + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) self.listener.reset() async def asyncSetUpCluster(self, scenario_dict): diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 590154b857..db232386ee 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -73,7 +73,6 @@ is_greenthread_patched, lazy_client_trial, one, - wait_until, ) import bson @@ -131,16 +130,11 @@ class AsyncClientUnitTest(AsyncUnitTest): client: AsyncMongoClient - @classmethod - async def _setup_class(cls): - cls.client = await cls.unmanaged_async_rs_or_single_client( + async def asyncSetUp(self) -> None: + self.client = await self.async_rs_or_single_client( connect=False, serverSelectionTimeoutMS=100 ) - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): self._caplog = caplog @@ -693,8 +687,8 @@ async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): # When the reaper runs at the same time as the get_socket, two # connections could be created and checked into the pool. self.assertGreaterEqual(len(server._pool.conns), 1) - wait_until(lambda: conn not in server._pool.conns, "remove stale socket") - wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") + await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket") + await async_wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): @@ -710,8 +704,8 @@ async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): # When the reaper runs at the same time as the get_socket, # maxPoolSize=1 should prevent two connections from being created. self.assertEqual(1, len(server._pool.conns)) - wait_until(lambda: conn not in server._pool.conns, "remove stale socket") - wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") + await async_wait_until(lambda: conn not in server._pool.conns, "remove stale socket") + await async_wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") async def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): @@ -727,7 +721,7 @@ async def test_max_idle_time_reaper_removes_stale(self): async with server._pool.checkout() as conn_two: pass self.assertIs(conn_one, conn_two) - wait_until( + await async_wait_until( lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) @@ -745,7 +739,7 @@ async def test_min_pool_size(self): server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) - wait_until( + await async_wait_until( lambda: len(server._pool.conns) == 10, "pool initialized with 10 connections", ) @@ -753,7 +747,7 @@ async def test_min_pool_size(self): # Assert that if a socket is closed, a new one takes its place async with server._pool.checkout() as conn: conn.close_conn(None) - wait_until( + await async_wait_until( lambda: len(server._pool.conns) == 10, "a closed socket gets replaced from the pool", ) @@ -939,8 +933,10 @@ async def test_repr(self): async with eval(the_repr) as client_two: self.assertEqual(client_two, client) - def test_getters(self): - wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes") + async def test_getters(self): + await async_wait_until( + lambda: async_client_context.nodes == self.client.nodes, "find all nodes" + ) async def test_list_databases(self): cmd_docs = (await self.client.admin.command("listDatabases"))["databases"] @@ -1065,14 +1061,21 @@ async def test_uri_connect_option(self): self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. - kc_thread = client._kill_cursors_executor._thread - self.assertFalse(kc_thread and kc_thread.is_alive()) - + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertFalse(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertFalse(kc_task and not kc_task.done()) # Using the client should open topology and start the thread. await client.admin.command("ping") self.assertTrue(client._topology._opened) - kc_thread = client._kill_cursors_executor._thread - self.assertTrue(kc_thread and kc_thread.is_alive()) + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertTrue(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertTrue(kc_task and not kc_task.done()) async def test_close_does_not_open_servers(self): client = await self.async_rs_client(connect=False) @@ -1277,6 +1280,7 @@ async def get_x(db): async def test_server_selection_timeout(self): client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + await client.close() client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False) @@ -1289,18 +1293,22 @@ async def test_server_selection_timeout(self): self.assertRaises( ConfigurationError, AsyncMongoClient, serverSelectionTimeoutMS=None, connect=False ) + await client.close() client = AsyncMongoClient( "mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False ) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + await client.close() client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) self.assertAlmostEqual(0, client.options.server_selection_timeout) + await client.close() # Test invalid timeout in URI ignored and set to default. client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) + await client.close() client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) @@ -1608,7 +1616,7 @@ def init(self, *args): await async_client_context.port, ) await self.async_single_client(uri, event_listeners=[listener]) - wait_until( + await async_wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) @@ -1766,16 +1774,16 @@ async def test_background_connections_do_not_hold_locks(self): pool = await async_get_pool(client) original_connect = pool.connect - def stall_connect(*args, **kwargs): - time.sleep(2) - return original_connect(*args, **kwargs) + async def stall_connect(*args, **kwargs): + await asyncio.sleep(2) + return await original_connect(*args, **kwargs) pool.connect = stall_connect # Un-patch Pool.connect to break the cyclic reference. self.addCleanup(delattr, pool, "connect") # Wait for the background thread to start creating connections - wait_until(lambda: len(pool.conns) > 1, "start creating connections") + await async_wait_until(lambda: len(pool.conns) > 1, "start creating connections") # Assert that application operations do not block. for _ in range(10): @@ -1858,7 +1866,7 @@ async def test_process_periodic_tasks(self): await client.close() # Add cursor to kill cursors queue del cursor - wait_until( + await async_wait_until( lambda: client._kill_cursors_queue, "waited for cursor to be added to queue", ) @@ -2232,7 +2240,7 @@ async def test_exhaust_getmore_network_error(self): await cursor.to_list() self.assertTrue(conn.closed) - wait_until( + await async_wait_until( lambda: len(client._kill_cursors_queue) == 0, "waited for all killCursor requests to complete", ) @@ -2403,7 +2411,7 @@ async def test_discover_primary(self): ) self.addAsyncCleanup(c.close) - wait_until(lambda: len(c.nodes) == 3, "connect") + await async_wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(await c.address, ("a", 1)) # Fail over. @@ -2430,7 +2438,7 @@ async def test_reconnect(self): ) self.addAsyncCleanup(c.close) - wait_until(lambda: len(c.nodes) == 3, "connect") + await async_wait_until(lambda: len(c.nodes) == 3, "connect") # Total failure. c.kill_host("a:1") @@ -2472,7 +2480,7 @@ async def _test_network_error(self, operation_callback): c.set_wire_version_range("a:1", 2, MIN_SUPPORTED_WIRE_VERSION) c.set_wire_version_range("b:2", 2, MIN_SUPPORTED_WIRE_VERSION + 1) await (await c._get_topology()).select_servers(writable_server_selector, _Op.TEST) - wait_until(lambda: len(c.nodes) == 2, "connect") + await async_wait_until(lambda: len(c.nodes) == 2, "connect") c.kill_host("a:1") @@ -2544,11 +2552,11 @@ async def test_rs_client_does_not_maintain_pool_to_arbiters(self): ) self.addAsyncCleanup(c.close) - wait_until(lambda: len(c.nodes) == 3, "connect") + await async_wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(await c.address, ("a", 1)) self.assertEqual(await c.arbiters, {("c", 3)}) # Assert that we create 2 and only 2 pooled connections. - listener.wait_for_event(monitoring.ConnectionReadyEvent, 2) + await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 2) self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 2) # Assert that we do not create connections to arbiters. arbiter = c._topology.get_server_by_address(("c", 3)) @@ -2574,10 +2582,10 @@ async def test_direct_client_maintains_pool_to_arbiter(self): ) self.addAsyncCleanup(c.close) - wait_until(lambda: len(c.nodes) == 1, "connect") + await async_wait_until(lambda: len(c.nodes) == 1, "connect") self.assertEqual(await c.address, ("c", 3)) # Assert that we create 1 pooled connection. - listener.wait_for_event(monitoring.ConnectionReadyEvent, 1) + await listener.async_wait_for_event(monitoring.ConnectionReadyEvent, 1) self.assertEqual(listener.event_count(monitoring.ConnectionCreatedEvent), 1) arbiter = c._topology.get_server_by_address(("c", 3)) self.assertEqual(len(arbiter.pool.conns), 1) diff --git a/test/asynchronous/test_collation.py b/test/asynchronous/test_collation.py index d95f4c9917..d7fd85b168 100644 --- a/test/asynchronous/test_collation.py +++ b/test/asynchronous/test_collation.py @@ -97,28 +97,21 @@ class TestCollation(AsyncIntegrationTest): warn_context: Any collation: Collation - @classmethod @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() - cls.listener = OvertCommandListener() - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - cls.db = cls.client.pymongo_test - cls.collation = Collation("en_US") - cls.warn_context = warnings.catch_warnings() - cls.warn_context.__enter__() - warnings.simplefilter("ignore", DeprecationWarning) - - @classmethod - async def _tearDown_class(cls): - cls.warn_context.__exit__() - cls.warn_context = None - await cls.client.close() - await super()._tearDown_class() - - def tearDown(self): + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.listener = OvertCommandListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test + self.collation = Collation("en_US") + self.warn_context = warnings.catch_warnings() + self.warn_context.__enter__() + + async def asyncTearDown(self) -> None: + self.warn_context.__exit__() + self.warn_context = None self.listener.reset() - super().tearDown() + await super().asyncTearDown() def last_command_started(self): return self.listener.started_events[-1].command diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index db52bad4ac..528919f63c 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -40,7 +40,6 @@ async_get_pool, async_is_mongos, async_wait_until, - wait_until, ) from bson import encode @@ -88,14 +87,10 @@ class TestCollectionNoConnect(AsyncUnitTest): db: AsyncDatabase client: AsyncMongoClient - @classmethod - async def _setup_class(cls): - cls.client = AsyncMongoClient(connect=False) - cls.db = cls.client.pymongo_test - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.client = self.simple_client(connect=False) + self.db = self.client.pymongo_test def test_collection(self): self.assertRaises(TypeError, AsyncCollection, self.db, 5) @@ -165,27 +160,14 @@ def test_iteration(self): class AsyncTestCollection(AsyncIntegrationTest): w: int - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.w = async_client_context.w # type: ignore - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine] - else: - asyncio.run(cls.async_tearDownClass()) - - @classmethod - async def async_tearDownClass(cls): - await cls.db.drop_collection("test_large_limit") - async def asyncSetUp(self): - await self.db.test.drop() + await super().asyncSetUp() + self.w = async_client_context.w # type: ignore async def asyncTearDown(self): await self.db.test.drop() + await self.db.drop_collection("test_large_limit") + await super().asyncTearDown() @contextlib.contextmanager def write_concern_collection(self): @@ -1023,7 +1005,10 @@ async def test_replace_bypass_document_validation(self): await db.test.insert_one({"y": 1}, bypass_document_validation=True) await db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) - await async_wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") + async def predicate(): + return await db_w0.test.find_one({"x": 1}) + + await async_wait_until(predicate, "find w:0 replaced document") async def test_update_bypass_document_validation(self): db = self.db @@ -1871,7 +1856,7 @@ async def test_exhaust(self): await cur.close() cur = None # Wait until the background thread returns the socket. - wait_until(lambda: pool.active_sockets == 0, "return socket") + await async_wait_until(lambda: pool.active_sockets == 0, "return socket") # The socket should be discarded. self.assertEqual(0, len(pool.conns)) diff --git a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py index 289cf49751..bc9638b443 100644 --- a/test/asynchronous/test_connections_survive_primary_stepdown_spec.py +++ b/test/asynchronous/test_connections_survive_primary_stepdown_spec.py @@ -19,7 +19,12 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + async_client_context, + reset_client_context, + unittest, +) from test.asynchronous.helpers import async_repl_set_step_down from test.utils import ( CMAPListener, @@ -39,29 +44,19 @@ class TestAsyncConnectionsSurvivePrimaryStepDown(AsyncIntegrationTest): listener: CMAPListener coll: AsyncCollection - @classmethod @async_client_context.require_replica_set - async def _setup_class(cls): - await super()._setup_class() - cls.listener = CMAPListener() - cls.client = await cls.unmanaged_async_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 + async def asyncSetUp(self): + self.listener = CMAPListener() + self.client = await self.async_rs_or_single_client( + event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500 ) # Ensure connections to all servers in replica set. This is to test # that the is_writable flag is properly updated for connections that # survive a replica set election. - await async_ensure_all_connected(cls.client) - cls.listener.reset() - - cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority")) - cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - - async def asyncSetUp(self): + await async_ensure_all_connected(self.client) + self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority")) + self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority")) # Note that all ops use same write-concern as self.db (majority). await self.db.drop_collection("step-down") await self.db.create_collection("step-down") diff --git a/test/asynchronous/test_create_entities.py b/test/asynchronous/test_create_entities.py index cb2ec63f4c..1f68cf6ddc 100644 --- a/test/asynchronous/test_create_entities.py +++ b/test/asynchronous/test_create_entities.py @@ -56,6 +56,9 @@ async def test_store_events_as_entities(self): self.assertGreater(len(final_entity_map["events1"]), 0) for event in final_entity_map["events1"]: self.assertIn("PoolCreatedEvent", event["name"]) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + await client.close() async def test_store_all_others_as_entities(self): self.scenario_runner = UnifiedSpecTestMixinV1() @@ -122,6 +125,9 @@ async def test_store_all_others_as_entities(self): self.assertEqual(entity_map["failures"], []) self.assertEqual(entity_map["successes"], 2) self.assertEqual(entity_map["iterations"], 5) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + await client.close() if __name__ == "__main__": diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 787da3d957..d216479451 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -34,9 +34,9 @@ AllowListEventListener, EventListener, OvertCommandListener, + async_wait_until, delay, ignore_deprecations, - wait_until, ) from bson import decode_all @@ -1324,8 +1324,8 @@ async def test_timeout_kills_cursor_asynchronously(self): with self.assertRaises(ExecutionTimeout): await cursor.next() - def assertCursorKilled(): - wait_until( + async def assertCursorKilled(): + await async_wait_until( lambda: len(listener.succeeded_events), "find successful killCursors command", ) @@ -1335,7 +1335,7 @@ def assertCursorKilled(): self.assertEqual(1, len(listener.succeeded_events)) self.assertEqual("killCursors", listener.succeeded_events[0].command_name) - assertCursorKilled() + await assertCursorKilled() listener.reset() cursor = await coll.aggregate([], batchSize=1) @@ -1345,7 +1345,7 @@ def assertCursorKilled(): with self.assertRaises(ExecutionTimeout): await cursor.next() - assertCursorKilled() + await assertCursorKilled() def test_delete_not_initialized(self): # Creating a cursor with invalid arguments will not run __init__ @@ -1647,10 +1647,6 @@ async def test_monitoring(self): class TestRawBatchCommandCursor(AsyncIntegrationTest): - @classmethod - async def _setup_class(cls): - await super()._setup_class() - async def test_aggregate_raw(self): c = self.db.test await c.drop() diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index 61369c8542..b5a5960420 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -717,7 +717,8 @@ def test_with_options(self): class TestDatabaseAggregation(AsyncIntegrationTest): - def setUp(self): + async def asyncSetUp(self): + await super().asyncSetUp() self.pipeline: List[Mapping[str, Any]] = [ {"$listLocalSessions": {}}, {"$limit": 1}, diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 767b3ecf0a..048db2d501 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -211,11 +211,10 @@ async def test_kwargs(self): class AsyncEncryptionIntegrationTest(AsyncIntegrationTest): """Base class for encryption integration tests.""" - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @async_client_context.require_version_min(4, 2, -1) - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() def assertEncrypted(self, val): self.assertIsInstance(val, Binary) @@ -430,10 +429,9 @@ async def test_upsert_uuid_standard_encrypt(self): class TestClientMaxWireVersion(AsyncIntegrationTest): - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() @async_client_context.require_version_max(4, 0, 99) async def test_raise_max_wire_version_error(self): @@ -818,17 +816,16 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest): "local": None, } - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - async def _setup_class(cls): - await super()._setup_class() - cls.listener = OvertCommandListener() - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - await cls.client.db.coll.drop() - cls.vault = await create_key_vault(cls.client.keyvault.datakeys) + async def asyncSetUp(self): + await super().asyncSetUp() + self.listener = OvertCommandListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) + await self.client.db.coll.drop() + self.vault = await create_key_vault(self.client.keyvault.datakeys) # Configure the encrypted field via the local schema_map option. schemas = { @@ -846,25 +843,22 @@ async def _setup_class(cls): } } opts = AutoEncryptionOpts( - cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS + self.KMS_PROVIDERS, + "keyvault.datakeys", + schema_map=schemas, + kms_tls_options=KMS_TLS_OPTS, ) - cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( + self.client_encrypted = await self.async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = cls.unmanaged_create_client_encryption( - cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS + self.client_encryption = self.create_client_encryption( + self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) - - @classmethod - async def _tearDown_class(cls): - await cls.vault.drop() - await cls.client.close() - await cls.client_encrypted.close() - await cls.client_encryption.close() - - def setUp(self): self.listener.reset() + async def asyncTearDown(self) -> None: + await self.vault.drop() + async def run_test(self, provider_name): # Create data key. master_key: Any = self.MASTER_KEYS[provider_name] @@ -1011,10 +1005,9 @@ async def test_views_are_prohibited(self): class TestCorpus(AsyncEncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() @staticmethod def kms_providers(): @@ -1188,12 +1181,11 @@ class TestBsonSizeBatches(AsyncEncryptionIntegrationTest): client_encrypted: AsyncMongoClient listener: OvertCommandListener - @classmethod - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() db = async_client_context.client.db - cls.coll = db.coll - await cls.coll.drop() + self.coll = db.coll + await self.coll.drop() # Configure the encrypted 'db.coll' collection via jsonSchema. json_schema = json_data("limits", "limits-schema.json") await db.create_collection( @@ -1211,17 +1203,14 @@ async def _setup_class(cls): await coll.insert_one(json_data("limits", "limits-key.json")) opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") - cls.listener = OvertCommandListener() - cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( - auto_encryption_opts=opts, event_listeners=[cls.listener] + self.listener = OvertCommandListener() + self.client_encrypted = await self.async_rs_or_single_client( + auto_encryption_opts=opts, event_listeners=[self.listener] ) - cls.coll_encrypted = cls.client_encrypted.db.coll + self.coll_encrypted = self.client_encrypted.db.coll - @classmethod - async def _tearDown_class(cls): - await cls.coll_encrypted.drop() - await cls.client_encrypted.close() - await super()._tearDown_class() + async def asyncTearDown(self) -> None: + await self.coll_encrypted.drop() async def test_01_insert_succeeds_under_2MiB(self): doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB} @@ -1245,7 +1234,9 @@ async def test_03_bulk_batch_split(self): doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB} self.listener.reset() await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) - self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) + self.assertEqual( + len([c for c in self.listener.started_command_names() if c == "insert"]), 2 + ) async def test_04_bulk_batch_split(self): limits_doc = json_data("limits", "limits-doc.json") @@ -1255,7 +1246,9 @@ async def test_04_bulk_batch_split(self): doc2.update(limits_doc) self.listener.reset() await self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) - self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) + self.assertEqual( + len([c for c in self.listener.started_command_names() if c == "insert"]), 2 + ) async def test_05_insert_succeeds_just_under_16MiB(self): doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)} @@ -1285,15 +1278,12 @@ async def test_06_insert_fails_over_16MiB(self): class TestCustomEndpoint(AsyncEncryptionIntegrationTest): """Prose tests for creating data keys with a custom endpoint.""" - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - async def _setup_class(cls): - await super()._setup_class() - - def setUp(self): + async def asyncSetUp(self): + await super().asyncSetUp() kms_providers = { "aws": AWS_CREDS, "azure": AZURE_CREDS, @@ -1322,10 +1312,6 @@ def setUp(self): self._kmip_host_error = None self._invalid_host_error = None - async def asyncTearDown(self): - await self.client_encryption.close() - await self.client_encryption_invalid.close() - async def run_test_expected_success(self, provider_name, master_key): data_key_id = await self.client_encryption.create_data_key( provider_name, master_key=master_key @@ -1500,18 +1486,18 @@ class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest): KEYVAULT_COLL = "datakeys" client: AsyncMongoClient - async def asyncSetUp(self): + async def _setup(self): keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL) await create_key_vault(keyvault, self.DEK) async def _test_explicit(self, expectation): + await self._setup() client_encryption = self.create_client_encryption( self.KMS_PROVIDER_MAP, # type: ignore[arg-type] ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), async_client_context.client, OPTS, ) - self.addAsyncCleanup(client_encryption.close) ciphertext = await client_encryption.encrypt( "string0", @@ -1523,6 +1509,7 @@ async def _test_explicit(self, expectation): self.assertEqual(await client_encryption.decrypt(ciphertext), "string0") async def _test_automatic(self, expectation_extjson, payload): + await self._setup() encrypted_db = "db" encrypted_coll = "coll" keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]) @@ -1537,7 +1524,6 @@ async def _test_automatic(self, expectation_extjson, payload): client = await self.async_rs_or_single_client( auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] ) - self.addAsyncCleanup(client.aclose) coll = client.get_database(encrypted_db).get_collection( encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority") @@ -1559,13 +1545,12 @@ async def _test_automatic(self, expectation_extjson, payload): class TestAzureEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") - async def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} - cls.DEK = json_data(BASE, "custom", "azure-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - await super()._setup_class() + async def asyncSetUp(self): + self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} + self.DEK = json_data(BASE, "custom", "azure-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + await super().asyncSetUp() async def test_explicit(self): return await self._test_explicit( @@ -1585,13 +1570,12 @@ async def test_automatic(self): class TestGCPEncryption(AzureGCPEncryptionTestMixin, AsyncEncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") - async def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} - cls.DEK = json_data(BASE, "custom", "gcp-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - await super()._setup_class() + async def asyncSetUp(self): + self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} + self.DEK = json_data(BASE, "custom", "gcp-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + await super().asyncSetUp() async def test_explicit(self): return await self._test_explicit( @@ -1613,6 +1597,7 @@ async def test_automatic(self): # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests class TestDeadlockProse(AsyncEncryptionIntegrationTest): async def asyncSetUp(self): + await super().asyncSetUp() self.client_test = await self.async_rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" ) @@ -1645,7 +1630,6 @@ async def asyncSetUp(self): self.ciphertext = await client_encryption.encrypt( "string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local" ) - await client_encryption.close() self.client_listener = OvertCommandListener() self.topology_listener = TopologyEventListener() @@ -1840,6 +1824,7 @@ async def test_case_8(self): # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events class TestDecryptProse(AsyncEncryptionIntegrationTest): async def asyncSetUp(self): + await super().asyncSetUp() self.client = async_client_context.client await self.client.db.drop_collection("decryption_events") await create_key_vault(self.client.keyvault.datakeys) @@ -2275,6 +2260,7 @@ async def test_06_named_kms_providers_apply_tls_options_kmip(self): # https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames class TestUniqueIndexOnKeyAltNamesProse(AsyncEncryptionIntegrationTest): async def asyncSetUp(self): + await super().asyncSetUp() self.client = async_client_context.client await create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} @@ -2624,8 +2610,6 @@ async def AsyncMongoClient(**kwargs): assert isinstance(res["encrypted_indexed"], Binary) assert isinstance(res["encrypted_unindexed"], Binary) - await client_encryption.close() - # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption class TestRangeQueryProse(AsyncEncryptionIntegrationTest): @@ -3089,17 +3073,11 @@ class TestNoSessionsSupport(AsyncEncryptionIntegrationTest): mongocryptd_client: AsyncMongoClient MONGOCRYPTD_PORT = 27020 - @classmethod @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") - async def _setup_class(cls): - await super()._setup_class() - start_mongocryptd(cls.MONGOCRYPTD_PORT) - - @classmethod - async def _tearDown_class(cls): - await super()._tearDown_class() - async def asyncSetUp(self) -> None: + await super().asyncSetUp() + start_mongocryptd(self.MONGOCRYPTD_PORT) + self.listener = OvertCommandListener() self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py index 54fcd3abf6..affdacde91 100644 --- a/test/asynchronous/test_grid_file.py +++ b/test/asynchronous/test_grid_file.py @@ -97,6 +97,7 @@ def test_grid_in_custom_opts(self): class AsyncTestGridFile(AsyncIntegrationTest): async def asyncSetUp(self): + await super().asyncSetUp() await self.cleanup_colls(self.db.fs.files, self.db.fs.chunks) async def test_basic(self): diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py index e0e7f2fc8d..e5a0adfee6 100644 --- a/test/asynchronous/test_locks.py +++ b/test/asynchronous/test_locks.py @@ -16,498 +16,447 @@ import asyncio import sys -import threading import unittest +from pymongo.lock import _async_create_condition, _async_create_lock + sys.path[0:0] = [""] -from pymongo.lock import _ACondition +if sys.version_info < (3, 13): + # Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py + # Includes tests for: + # - https://github.com/python/cpython/issues/111693 + # - https://github.com/python/cpython/issues/112202 + class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): + async def test_wait(self): + cond = _async_create_condition(_async_create_lock()) + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait(): + result.append(1) + return True + async def c2(result): + await cond.acquire() + if await cond.wait(): + result.append(2) + return True -# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py -# Includes tests for: -# - https://github.com/python/cpython/issues/111693 -# - https://github.com/python/cpython/issues/112202 -class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): - async def test_wait(self): - cond = _ACondition(threading.Condition(threading.Lock())) - result = [] + async def c3(result): + await cond.acquire() + if await cond.wait(): + result.append(3) + return True - async def c1(result): - await cond.acquire() - if await cond.wait(): - result.append(1) - return True + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) - async def c2(result): - await cond.acquire() - if await cond.wait(): - result.append(2) - return True + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertFalse(cond.locked()) - async def c3(result): - await cond.acquire() - if await cond.wait(): - result.append(3) - return True - - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) - t3 = asyncio.create_task(c3(result)) - - await asyncio.sleep(0) - self.assertEqual([], result) - self.assertFalse(cond.locked()) - - self.assertTrue(await cond.acquire()) - cond.notify() - await asyncio.sleep(0) - self.assertEqual([], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1], result) - self.assertTrue(cond.locked()) - - cond.notify(2) - await asyncio.sleep(0) - self.assertEqual([1], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1, 2], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1, 2, 3], result) - self.assertTrue(cond.locked()) - - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - self.assertTrue(t3.done()) - self.assertTrue(t3.result()) - - async def test_wait_cancel(self): - cond = _ACondition(threading.Condition(threading.Lock())) - await cond.acquire() - - wait = asyncio.create_task(cond.wait()) - asyncio.get_running_loop().call_soon(wait.cancel) - with self.assertRaises(asyncio.CancelledError): - await wait - self.assertFalse(cond._waiters) - self.assertTrue(cond.locked()) - - async def test_wait_cancel_contested(self): - cond = _ACondition(threading.Condition(threading.Lock())) - - await cond.acquire() - self.assertTrue(cond.locked()) - - wait_task = asyncio.create_task(cond.wait()) - await asyncio.sleep(0) - self.assertFalse(cond.locked()) - - # Notify, but contest the lock before cancelling - await cond.acquire() - self.assertTrue(cond.locked()) - cond.notify() - asyncio.get_running_loop().call_soon(wait_task.cancel) - asyncio.get_running_loop().call_soon(cond.release) - - try: - await wait_task - except asyncio.CancelledError: - # Should not happen, since no cancellation points - pass - - self.assertTrue(cond.locked()) - - async def test_wait_cancel_after_notify(self): - # See bpo-32841 - waited = False - - cond = _ACondition(threading.Condition(threading.Lock())) - - async def wait_on_cond(): - nonlocal waited - async with cond: - waited = True # Make sure this area was reached - await cond.wait() + self.assertTrue(await cond.acquire()) + cond.notify() + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertTrue(cond.locked()) - waiter = asyncio.create_task(wait_on_cond()) - await asyncio.sleep(0) # Start waiting + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) - await cond.acquire() - cond.notify() - await asyncio.sleep(0) # Get to acquire() - waiter.cancel() - await asyncio.sleep(0) # Activate cancellation - cond.release() - await asyncio.sleep(0) # Cancellation should occur + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) - self.assertTrue(waiter.cancelled()) - self.assertTrue(waited) + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) - async def test_wait_unacquired(self): - cond = _ACondition(threading.Condition(threading.Lock())) - with self.assertRaises(RuntimeError): - await cond.wait() + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) - async def test_wait_for(self): - cond = _ACondition(threading.Condition(threading.Lock())) - presult = False + async def test_wait_cancel(self): + cond = _async_create_condition(_async_create_lock()) + await cond.acquire() - def predicate(): - return presult + wait = asyncio.create_task(cond.wait()) + asyncio.get_running_loop().call_soon(wait.cancel) + with self.assertRaises(asyncio.CancelledError): + await wait + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) - result = [] + async def test_wait_cancel_contested(self): + cond = _async_create_condition(_async_create_lock()) - async def c1(result): await cond.acquire() - if await cond.wait_for(predicate): - result.append(1) - cond.release() - return True + self.assertTrue(cond.locked()) - t = asyncio.create_task(c1(result)) + wait_task = asyncio.create_task(cond.wait()) + await asyncio.sleep(0) + self.assertFalse(cond.locked()) - await asyncio.sleep(0) - self.assertEqual([], result) + # Notify, but contest the lock before cancelling + await cond.acquire() + self.assertTrue(cond.locked()) + cond.notify() + asyncio.get_running_loop().call_soon(wait_task.cancel) + asyncio.get_running_loop().call_soon(cond.release) - await cond.acquire() - cond.notify() - cond.release() - await asyncio.sleep(0) - self.assertEqual([], result) + try: + await wait_task + except asyncio.CancelledError: + # Should not happen, since no cancellation points + pass - presult = True - await cond.acquire() - cond.notify() - cond.release() - await asyncio.sleep(0) - self.assertEqual([1], result) + self.assertTrue(cond.locked()) - self.assertTrue(t.done()) - self.assertTrue(t.result()) + async def test_wait_cancel_after_notify(self): + # See bpo-32841 + waited = False - async def test_wait_for_unacquired(self): - cond = _ACondition(threading.Condition(threading.Lock())) + cond = _async_create_condition(_async_create_lock()) - # predicate can return true immediately - res = await cond.wait_for(lambda: [1, 2, 3]) - self.assertEqual([1, 2, 3], res) + async def wait_on_cond(): + nonlocal waited + async with cond: + waited = True # Make sure this area was reached + await cond.wait() - with self.assertRaises(RuntimeError): - await cond.wait_for(lambda: False) + waiter = asyncio.create_task(wait_on_cond()) + await asyncio.sleep(0) # Start waiting - async def test_notify(self): - cond = _ACondition(threading.Condition(threading.Lock())) - result = [] + await cond.acquire() + cond.notify() + await asyncio.sleep(0) # Get to acquire() + waiter.cancel() + await asyncio.sleep(0) # Activate cancellation + cond.release() + await asyncio.sleep(0) # Cancellation should occur + + self.assertTrue(waiter.cancelled()) + self.assertTrue(waited) + + async def test_wait_unacquired(self): + cond = _async_create_condition(_async_create_lock()) + with self.assertRaises(RuntimeError): + await cond.wait() - async def c1(result): - async with cond: - if await cond.wait(): - result.append(1) - return True + async def test_wait_for(self): + cond = _async_create_condition(_async_create_lock()) + presult = False - async def c2(result): - async with cond: - if await cond.wait(): - result.append(2) - return True + def predicate(): + return presult - async def c3(result): - async with cond: - if await cond.wait(): - result.append(3) + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait_for(predicate): + result.append(1) + cond.release() return True - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) - t3 = asyncio.create_task(c3(result)) + t = asyncio.create_task(c1(result)) - await asyncio.sleep(0) - self.assertEqual([], result) + await asyncio.sleep(0) + self.assertEqual([], result) - async with cond: - cond.notify(1) - await asyncio.sleep(1) - self.assertEqual([1], result) + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([], result) - async with cond: - cond.notify(1) - cond.notify(2048) - await asyncio.sleep(1) - self.assertEqual([1, 2, 3], result) + presult = True + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - self.assertTrue(t3.done()) - self.assertTrue(t3.result()) + self.assertTrue(t.done()) + self.assertTrue(t.result()) - async def test_notify_all(self): - cond = _ACondition(threading.Condition(threading.Lock())) + async def test_wait_for_unacquired(self): + cond = _async_create_condition(_async_create_lock()) - result = [] + # predicate can return true immediately + res = await cond.wait_for(lambda: [1, 2, 3]) + self.assertEqual([1, 2, 3], res) - async def c1(result): - async with cond: - if await cond.wait(): - result.append(1) - return True + with self.assertRaises(RuntimeError): + await cond.wait_for(lambda: False) - async def c2(result): - async with cond: - if await cond.wait(): - result.append(2) - return True + async def test_notify(self): + cond = _async_create_condition(_async_create_lock()) + result = [] - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True - await asyncio.sleep(0) - self.assertEqual([], result) + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True - async with cond: - cond.notify_all() - await asyncio.sleep(1) - self.assertEqual([1, 2], result) + async def c3(result): + async with cond: + if await cond.wait(): + result.append(3) + return True - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) - async def test_context_manager(self): - cond = _ACondition(threading.Condition(threading.Lock())) - self.assertFalse(cond.locked()) - async with cond: - self.assertTrue(cond.locked()) - self.assertFalse(cond.locked()) - - async def test_timeout_in_block(self): - condition = _ACondition(threading.Condition(threading.Lock())) - async with condition: - with self.assertRaises(asyncio.TimeoutError): - await asyncio.wait_for(condition.wait(), timeout=0.5) - - @unittest.skipIf( - sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" - ) - async def test_cancelled_error_wakeup(self): - # Test that a cancelled error, received when awaiting wakeup, - # will be re-raised un-modified. - wake = False - raised = None - cond = _ACondition(threading.Condition(threading.Lock())) - - async def func(): - nonlocal raised - async with cond: - with self.assertRaises(asyncio.CancelledError) as err: - await cond.wait_for(lambda: wake) - raised = err.exception - raise raised - - task = asyncio.create_task(func()) - await asyncio.sleep(0) - # Task is waiting on the condition, cancel it there. - task.cancel(msg="foo") # type: ignore[call-arg] - with self.assertRaises(asyncio.CancelledError) as err: - await task - self.assertEqual(err.exception.args, ("foo",)) - # We should have got the _same_ exception instance as the one - # originally raised. - self.assertIs(err.exception, raised) - - @unittest.skipIf( - sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" - ) - async def test_cancelled_error_re_aquire(self): - # Test that a cancelled error, received when re-aquiring lock, - # will be re-raised un-modified. - wake = False - raised = None - cond = _ACondition(threading.Condition(threading.Lock())) - - async def func(): - nonlocal raised - async with cond: - with self.assertRaises(asyncio.CancelledError) as err: - await cond.wait_for(lambda: wake) - raised = err.exception - raise raised - - task = asyncio.create_task(func()) - await asyncio.sleep(0) - # Task is waiting on the condition - await cond.acquire() - wake = True - cond.notify() - await asyncio.sleep(0) - # Task is now trying to re-acquire the lock, cancel it there. - task.cancel(msg="foo") # type: ignore[call-arg] - cond.release() - with self.assertRaises(asyncio.CancelledError) as err: - await task - self.assertEqual(err.exception.args, ("foo",)) - # We should have got the _same_ exception instance as the one - # originally raised. - self.assertIs(err.exception, raised) - - @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") - async def test_cancelled_wakeup(self): - # Test that a task cancelled at the "same" time as it is woken - # up as part of a Condition.notify() does not result in a lost wakeup. - # This test simulates a cancel while the target task is awaiting initial - # wakeup on the wakeup queue. - condition = _ACondition(threading.Condition(threading.Lock())) - state = 0 - - async def consumer(): - nonlocal state - async with condition: - while True: - await condition.wait_for(lambda: state != 0) - if state < 0: - return - state -= 1 - - # create two consumers - c = [asyncio.create_task(consumer()) for _ in range(2)] - # wait for them to settle - await asyncio.sleep(0.1) - async with condition: - # produce one item and wake up one - state += 1 - condition.notify(1) - - # Cancel it while it is awaiting to be run. - # This cancellation could come from the outside - c[0].cancel() - - # now wait for the item to be consumed - # if it doesn't means that our "notify" didn"t take hold. - # because it raced with a cancel() - try: - async with asyncio.timeout(1): - await condition.wait_for(lambda: state == 0) - except TimeoutError: - pass - self.assertEqual(state, 0) - - # clean up - state = -1 - condition.notify_all() - await c[1] - - @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") - async def test_cancelled_wakeup_relock(self): - # Test that a task cancelled at the "same" time as it is woken - # up as part of a Condition.notify() does not result in a lost wakeup. - # This test simulates a cancel while the target task is acquiring the lock - # again. - condition = _ACondition(threading.Condition(threading.Lock())) - state = 0 - - async def consumer(): - nonlocal state - async with condition: - while True: - await condition.wait_for(lambda: state != 0) - if state < 0: - return - state -= 1 - - # create two consumers - c = [asyncio.create_task(consumer()) for _ in range(2)] - # wait for them to settle - await asyncio.sleep(0.1) - async with condition: - # produce one item and wake up one - state += 1 - condition.notify(1) - - # now we sleep for a bit. This allows the target task to wake up and - # settle on re-aquiring the lock await asyncio.sleep(0) + self.assertEqual([], result) - # Cancel it while awaiting the lock - # This cancel could come the outside. - c[0].cancel() + async with cond: + cond.notify(1) + await asyncio.sleep(1) + self.assertEqual([1], result) - # now wait for the item to be consumed - # if it doesn't means that our "notify" didn"t take hold. - # because it raced with a cancel() - try: - async with asyncio.timeout(1): - await condition.wait_for(lambda: state == 0) - except TimeoutError: - pass - self.assertEqual(state, 0) + async with cond: + cond.notify(1) + cond.notify(2048) + await asyncio.sleep(1) + self.assertEqual([1, 2, 3], result) - # clean up - state = -1 - condition.notify_all() - await c[1] + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + async def test_notify_all(self): + cond = _async_create_condition(_async_create_lock()) -class TestCondition(unittest.IsolatedAsyncioTestCase): - async def test_multiple_loops_notify(self): - cond = _ACondition(threading.Condition(threading.Lock())) + result = [] - def tmain(cond): - async def atmain(cond): - await asyncio.sleep(1) + async def c1(result): async with cond: - cond.notify(1) - - asyncio.run(atmain(cond)) - - t = threading.Thread(target=tmain, args=(cond,)) - t.start() + if await cond.wait(): + result.append(1) + return True - async with cond: - self.assertTrue(await cond.wait(30)) - t.join() - - async def test_multiple_loops_notify_all(self): - cond = _ACondition(threading.Condition(threading.Lock())) - results = [] - - def tmain(cond, results): - async def atmain(cond, results): - await asyncio.sleep(1) + async def c2(result): async with cond: - res = await cond.wait(30) - results.append(res) - - asyncio.run(atmain(cond, results)) + if await cond.wait(): + result.append(2) + return True - nthreads = 5 - threads = [] - for _ in range(nthreads): - threads.append(threading.Thread(target=tmain, args=(cond, results))) - for t in threads: - t.start() + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) - await asyncio.sleep(2) - async with cond: - cond.notify_all() + await asyncio.sleep(0) + self.assertEqual([], result) - for t in threads: - t.join() + async with cond: + cond.notify_all() + await asyncio.sleep(1) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + async def test_context_manager(self): + cond = _async_create_condition(_async_create_lock()) + self.assertFalse(cond.locked()) + async with cond: + self.assertTrue(cond.locked()) + self.assertFalse(cond.locked()) - self.assertEqual(results, [True] * nthreads) + async def test_timeout_in_block(self): + condition = _async_create_condition(_async_create_lock()) + async with condition: + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(condition.wait(), timeout=0.5) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_wakeup(self): + # Test that a cancelled error, received when awaiting wakeup, + # will be re-raised un-modified. + wake = False + raised = None + cond = _async_create_condition(_async_create_lock()) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_re_aquire(self): + # Test that a cancelled error, received when re-aquiring lock, + # will be re-raised un-modified. + wake = False + raised = None + cond = _async_create_condition(_async_create_lock()) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised -if __name__ == "__main__": - unittest.main() + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition + await cond.acquire() + wake = True + cond.notify() + await asyncio.sleep(0) + # Task is now trying to re-acquire the lock, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + cond.release() + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is awaiting initial + # wakeup on the wakeup queue. + condition = _async_create_condition(_async_create_lock()) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0.1) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # Cancel it while it is awaiting to be run. + # This cancellation could come from the outside + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup_relock(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is acquiring the lock + # again. + condition = _async_create_condition(_async_create_lock()) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0.1) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # now we sleep for a bit. This allows the target task to wake up and + # settle on re-aquiring the lock + await asyncio.sleep(0) + + # Cancel it while awaiting the lock + # This cancel could come the outside. + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_monitoring.py b/test/asynchronous/test_monitoring.py index b0c86ab54e..eaad60beac 100644 --- a/test/asynchronous/test_monitoring.py +++ b/test/asynchronous/test_monitoring.py @@ -52,22 +52,16 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest): listener: EventListener @classmethod - @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() + def setUpClass(cls) -> None: cls.listener = OvertCommandListener() - cls.client = await cls.unmanaged_async_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False - ) - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - await super()._tearDown_class() - - async def asyncTearDown(self): + @async_client_context.require_connection + async def asyncSetUp(self) -> None: + await super().asyncSetUp() self.listener.reset() - await super().asyncTearDown() + self.client = await self.async_rs_or_single_client( + event_listeners=[self.listener], retryWrites=False + ) async def test_started_simple(self): await self.client.pymongo_test.command("ping") @@ -1140,26 +1134,23 @@ class AsyncTestGlobalListener(AsyncIntegrationTest): saved_listeners: Any @classmethod - @async_client_context.require_connection - async def _setup_class(cls): - await super()._setup_class() + def setUpClass(cls) -> None: cls.listener = OvertCommandListener() # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = await cls.unmanaged_async_single_client() - # Get one (authenticated) socket in the pool. - await cls.client.pymongo_test.command("ping") - - @classmethod - async def _tearDown_class(cls): - monitoring._LISTENERS = cls.saved_listeners - await cls.client.close() - await super()._tearDown_class() + @async_client_context.require_connection async def asyncSetUp(self): await super().asyncSetUp() self.listener.reset() + self.client = await self.async_single_client() + # Get one (authenticated) socket in the pool. + await self.client.pymongo_test.command("ping") + + @classmethod + def tearDownClass(cls): + monitoring._LISTENERS = cls.saved_listeners async def test_simple(self): await self.client.pymongo_test.command("ping") diff --git a/test/asynchronous/test_retryable_writes.py b/test/asynchronous/test_retryable_writes.py index ca2f0a5422..738ce04192 100644 --- a/test/asynchronous/test_retryable_writes.py +++ b/test/asynchronous/test_retryable_writes.py @@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(AsyncIntegrationTest): RUN_ON_SERVERLESS = True deprecation_filter: DeprecationFilter - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.deprecation_filter = DeprecationFilter() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.deprecation_filter = DeprecationFilter() - @classmethod - async def _tearDown_class(cls): - cls.deprecation_filter.stop() - await super()._tearDown_class() + async def asyncTearDown(self) -> None: + self.deprecation_filter.stop() class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): knobs: client_knobs - @classmethod - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.client = await cls.unmanaged_async_rs_or_single_client(retryWrites=True) - cls.db = cls.client.pymongo_test + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.client = await self.async_rs_or_single_client(retryWrites=True) + self.db = self.client.pymongo_test - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - await cls.client.close() - await super()._tearDown_class() + async def asyncTearDown(self) -> None: + self.knobs.disable() @async_client_context.require_no_standalone async def test_actionable_error_message(self): @@ -180,26 +173,18 @@ class TestRetryableWrites(IgnoreDeprecationsTest): listener: OvertCommandListener knobs: client_knobs - @classmethod @async_client_context.require_no_mmap - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self) -> None: + await super().asyncSetUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.listener = OvertCommandListener() - cls.client = await cls.unmanaged_async_rs_or_single_client( - retryWrites=True, event_listeners=[cls.listener] + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.listener = OvertCommandListener() + self.client = await self.async_rs_or_single_client( + retryWrites=True, event_listeners=[self.listener] ) - cls.db = cls.client.pymongo_test + self.db = self.client.pymongo_test - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - await cls.client.close() - await super()._tearDown_class() - - async def asyncSetUp(self): if async_client_context.is_rs and async_client_context.test_commands_enabled: await self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")]) @@ -210,6 +195,7 @@ async def asyncTearDown(self): await self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) ) + self.knobs.disable() async def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() @@ -438,13 +424,12 @@ class TestWriteConcernError(AsyncIntegrationTest): RUN_ON_SERVERLESS = True fail_insert: dict - @classmethod @async_client_context.require_replica_set @async_client_context.require_no_mmap @async_client_context.require_failCommand_fail_point - async def _setup_class(cls): - await super()._setup_class() - cls.fail_insert = { + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.fail_insert = { "configureFailPoint": "failCommand", "mode": {"times": 2}, "data": { diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index b432621798..42bc253b56 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -38,7 +38,6 @@ ExceptionCatchingThread, OvertCommandListener, async_wait_until, - wait_until, ) from bson import DBRef @@ -83,36 +82,27 @@ class TestSession(AsyncIntegrationTest): client2: AsyncMongoClient sensitive_commands: Set[str] - @classmethod @async_client_context.require_sessions - async def _setup_class(cls): - await super()._setup_class() + async def asyncSetUp(self): + await super().asyncSetUp() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = await cls.unmanaged_async_rs_or_single_client() + self.client2 = await self.async_rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". - cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() + self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() monitoring._SENSITIVE_COMMANDS.clear() - @classmethod - async def _tearDown_class(cls): - monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) - await cls.client2.close() - await super()._tearDown_class() - - async def asyncSetUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() self.client = await self.async_rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) - self.addAsyncCleanup(self.client.close) self.db = self.client.pymongo_test self.initial_lsids = {s["id"] for s in session_ids(self.client)} async def asyncTearDown(self): - """All sessions used in the test must be returned to the pool.""" + monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands) await self.client.drop_database("pymongo_test") used_lsids = self.initial_lsids.copy() for event in self.session_checker_listener.started_events: @@ -122,6 +112,8 @@ async def asyncTearDown(self): current_lsids = {s["id"] for s in session_ids(self.client)} self.assertLessEqual(used_lsids, current_lsids) + await super().asyncTearDown() + async def _test_ops(self, client, *ops): listener = client.options.event_listeners[0] @@ -833,18 +825,11 @@ class TestCausalConsistency(AsyncUnitTest): listener: SessionTestListener client: AsyncMongoClient - @classmethod - async def _setup_class(cls): - cls.listener = SessionTestListener() - cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - async def _tearDown_class(cls): - await cls.client.close() - @async_client_context.require_sessions async def asyncSetUp(self): await super().asyncSetUp() + self.listener = SessionTestListener() + self.client = await self.async_rs_or_single_client(event_listeners=[self.listener]) @async_client_context.require_no_standalone async def test_core(self): diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index b5d0686417..d11d0a9776 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -26,7 +26,7 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.utils import ( OvertCommandListener, - wait_until, + async_wait_until, ) from typing import List @@ -162,7 +162,7 @@ async def test_unpin_for_next_transaction(self): client = await self.async_rs_client( async_client_context.mongos_seeds(), localThresholdMS=1000 ) - wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") + await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. await coll.insert_one({}) @@ -191,7 +191,7 @@ async def test_unpin_for_non_transaction_operation(self): client = await self.async_rs_client( async_client_context.mongos_seeds(), localThresholdMS=1000 ) - wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") + await async_wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. await coll.insert_one({}) @@ -403,21 +403,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): class TestTransactionsConvenientAPI(AsyncTransactionsBase): - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.mongos_clients = [] + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.mongos_clients = [] if async_client_context.supports_transactions(): for address in async_client_context.mongoses: - cls.mongos_clients.append( - await cls.unmanaged_async_single_client("{}:{}".format(*address)) - ) - - @classmethod - async def _tearDown_class(cls): - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() + self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address))) async def _set_fail_point(self, client, command_args): cmd = {"configureFailPoint": "failCommand"} diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index db5ed81e24..b18b09383e 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -50,6 +50,7 @@ ) from test.utils import ( async_get_pool, + async_wait_until, camel_to_snake, camel_to_snake_args, parse_spec_options, @@ -304,7 +305,6 @@ async def _create_entity(self, entity_spec, uri=None): kwargs["h"] = uri client = await self.test.async_rs_or_single_client(**kwargs) self[spec["id"]] = client - self.test.addAsyncCleanup(client.close) return elif entity_type == "database": client = self[spec["client"]] @@ -479,54 +479,47 @@ async def insert_initial_data(self, initial_data): await db.create_collection(coll_name, write_concern=wc, **opts) @classmethod - async def _setup_class(cls): + def setUpClass(cls) -> None: + # Speed up the tests by decreasing the heartbeat frequency. + cls.knobs = client_knobs( + heartbeat_frequency=0.1, + min_heartbeat_interval=0.1, + kill_cursor_frequency=0.1, + events_queue_frequency=0.1, + ) + cls.knobs.enable() + + @classmethod + def tearDownClass(cls) -> None: + cls.knobs.disable() + + async def asyncSetUp(self): # super call creates internal client cls.client - await super()._setup_class() + await super().asyncSetUp() # process file-level runOnRequirements - run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) - if not await cls.should_run_on(run_on_spec): - raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied") + run_on_spec = self.TEST_SPEC.get("runOnRequirements", []) + if not await self.should_run_on(run_on_spec): + raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied") # add any special-casing for skipping tests here if async_client_context.storage_engine == "mmapv1": - if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str( - cls.TEST_PATH + if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str( + self.TEST_PATH ): raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") # Handle mongos_clients for transactions tests. - cls.mongos_clients = [] + self.mongos_clients = [] if ( async_client_context.supports_transactions() and not async_client_context.load_balancer and not async_client_context.serverless ): for address in async_client_context.mongoses: - cls.mongos_clients.append( - await cls.unmanaged_async_single_client("{}:{}".format(*address)) - ) + self.mongos_clients.append(await self.async_single_client("{}:{}".format(*address))) - # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs( - heartbeat_frequency=0.1, - min_heartbeat_interval=0.1, - kill_cursor_frequency=0.1, - events_queue_frequency=0.1, - ) - cls.knobs.enable() - - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() - - async def asyncSetUp(self): - await super().asyncSetUp() # process schemaVersion # note: we check major schema version during class generation - # note: we do this here because we cannot run assertions in setUpClass version = Version.from_string(self.TEST_SPEC["schemaVersion"]) self.assertLessEqual( version, @@ -1036,7 +1029,6 @@ async def _testOperation_targetedFailPoint(self, spec): ) client = await self.async_single_client("{}:{}".format(*session._pinned_address)) - self.addAsyncCleanup(client.close) await self.__set_fail_point(client=client, command_args=spec["failPoint"]) async def _testOperation_createEntities(self, spec): @@ -1137,13 +1129,13 @@ def _testOperation_assertEventCount(self, spec): client, event, count = spec["client"], spec["event"], spec["count"] self.assertEqual(self._event_count(client, event), count, f"expected {count} not {event!r}") - def _testOperation_waitForEvent(self, spec): + async def _testOperation_waitForEvent(self, spec): """Run the waitForEvent test operation. Wait for a number of events to be published, or fail. """ client, event, count = spec["client"], spec["event"], spec["count"] - wait_until( + await async_wait_until( lambda: self._event_count(client, event) >= count, f"find {count} {event} event(s)", ) diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index f27f52ec2c..b79e5258b5 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -249,30 +249,22 @@ class AsyncSpecRunner(AsyncIntegrationTest): knobs: client_knobs listener: EventListener - @classmethod - async def _setup_class(cls): - await super()._setup_class() - cls.mongos_clients = [] + async def asyncSetUp(self) -> None: + await super().asyncSetUp() + self.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - - @classmethod - async def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() - - def setUp(self): - super().setUp() + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() self.targets = {} self.listener = None # type: ignore self.pool_listener = None self.server_listener = None self.maxDiff = None + async def asyncTearDown(self) -> None: + self.knobs.disable() + async def _set_fail_point(self, client, command_args): cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) @@ -700,8 +692,6 @@ async def run_scenario(self, scenario_def, test): self.listener = listener self.pool_listener = pool_listener self.server_listener = server_listener - # Close the client explicitly to avoid having too many threads open. - self.addAsyncCleanup(client.close) # Create session0 and session1. sessions = {} diff --git a/test/conftest.py b/test/conftest.py index a3d954c7c3..91fad28d0a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -20,7 +20,7 @@ def event_loop_policy(): return asyncio.get_event_loop_policy() -@pytest.fixture(scope="session", autouse=True) +@pytest.fixture(scope="package", autouse=True) def test_setup_and_teardown(): setup() yield diff --git a/test/test_bulk.py b/test/test_bulk.py index ea2b803804..6d29ff510a 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -42,15 +42,11 @@ class BulkTestBase(IntegrationTest): coll: Collection coll_w0: Collection - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.coll = cls.db.test - cls.coll_w0 = cls.coll.with_options(write_concern=WriteConcern(w=0)) - def setUp(self): super().setUp() + self.coll = self.db.test self.coll.drop() + self.coll_w0 = self.coll.with_options(write_concern=WriteConcern(w=0)) def assertEqualResponse(self, expected, actual): """Compare response from bulk.execute() to expected response.""" @@ -785,12 +781,8 @@ def test_large_inserts_unordered(self): class BulkAuthorizationTestBase(BulkTestBase): - @classmethod @client_context.require_auth @client_context.require_no_api_version - def _setup_class(cls): - super()._setup_class() - def setUp(self): super().setUp() client_context.create_user(self.db.name, "readonly", "pw", ["read"]) @@ -935,21 +927,19 @@ class TestBulkWriteConcern(BulkTestBase): w: Optional[int] secondary: MongoClient - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.w = client_context.w - cls.secondary = None - if cls.w is not None and cls.w > 1: + def setUp(self): + super().setUp() + self.w = client_context.w + self.secondary = None + if self.w is not None and self.w > 1: for member in (client_context.hello)["hosts"]: if member != (client_context.hello)["primary"]: - cls.secondary = cls.unmanaged_single_client(*partition_node(member)) + self.secondary = self.single_client(*partition_node(member)) break - @classmethod - def async_tearDownClass(cls): - if cls.secondary: - cls.secondary.close() + def tearDown(self): + if self.secondary: + self.secondary.close() def cause_wtimeout(self, requests, ordered): if not client_context.test_commands_enabled: diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 3a107122b7..4ed21f55cf 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -820,18 +820,16 @@ def test_split_large_change(self): class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin): dbs: list - @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_change_streams - def _setup_class(cls): - super()._setup_class() - cls.dbs = [cls.db, cls.client.pymongo_test_2] + def setUp(self) -> None: + super().setUp() + self.dbs = [self.db, self.client.pymongo_test_2] - @classmethod - def _tearDown_class(cls): - for db in cls.dbs: - cls.client.drop_database(db) - super()._tearDown_class() + def tearDown(self): + for db in self.dbs: + self.client.drop_database(db) + super().tearDown() def change_stream_with_client(self, client, *args, **kwargs): return client.watch(*args, **kwargs) @@ -882,11 +880,10 @@ def test_full_pipeline(self): class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin): - @classmethod @client_context.require_version_min(4, 0, 0, -1) @client_context.require_change_streams - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() def change_stream_with_client(self, client, *args, **kwargs): return client[self.db.name].watch(*args, **kwargs) @@ -968,12 +965,9 @@ def test_isolation(self): class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin): - @classmethod @client_context.require_change_streams - def _setup_class(cls): - super()._setup_class() - def setUp(self): + super().setUp() # Use a new collection for each test. self.watched_collection().drop() self.watched_collection().insert_one({}) @@ -1111,20 +1105,11 @@ class TestAllLegacyScenarios(IntegrationTest): RUN_ON_LOAD_BALANCER = True listener: AllowListEventListener - @classmethod @client_context.require_connection - def _setup_class(cls): - super()._setup_class() - cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - def _tearDown_class(cls): - cls.client.close() - super()._tearDown_class() - def setUp(self): super().setUp() + self.listener = AllowListEventListener("aggregate", "getMore") + self.client = self.rs_or_single_client(event_listeners=[self.listener]) self.listener.reset() def setUpCluster(self, scenario_dict): diff --git a/test/test_client.py b/test/test_client.py index 5bbb5bd751..5ec425f312 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -129,13 +129,8 @@ class ClientUnitTest(UnitTest): client: MongoClient - @classmethod - def _setup_class(cls): - cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) - - @classmethod - def _tearDown_class(cls): - cls.client.close() + def setUp(self) -> None: + self.client = self.rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): @@ -1039,14 +1034,21 @@ def test_uri_connect_option(self): self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. - kc_thread = client._kill_cursors_executor._thread - self.assertFalse(kc_thread and kc_thread.is_alive()) - + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertFalse(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertFalse(kc_task and not kc_task.done()) # Using the client should open topology and start the thread. client.admin.command("ping") self.assertTrue(client._topology._opened) - kc_thread = client._kill_cursors_executor._thread - self.assertTrue(kc_thread and kc_thread.is_alive()) + if _IS_SYNC: + kc_thread = client._kill_cursors_executor._thread + self.assertTrue(kc_thread and kc_thread.is_alive()) + else: + kc_task = client._kill_cursors_executor._task + self.assertTrue(kc_task and not kc_task.done()) def test_close_does_not_open_servers(self): client = self.rs_client(connect=False) @@ -1241,6 +1243,7 @@ def get_x(db): def test_server_selection_timeout(self): client = MongoClient(serverSelectionTimeoutMS=100, connect=False) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + client.close() client = MongoClient(serverSelectionTimeoutMS=0, connect=False) @@ -1251,16 +1254,20 @@ def test_server_selection_timeout(self): self.assertRaises( ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False ) + client.close() client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) + client.close() client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) self.assertAlmostEqual(0, client.options.server_selection_timeout) + client.close() # Test invalid timeout in URI ignored and set to default. client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) + client.close() client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) diff --git a/test/test_collation.py b/test/test_collation.py index b878df2fb4..06436f0638 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -97,26 +97,19 @@ class TestCollation(IntegrationTest): warn_context: Any collation: Collation - @classmethod @client_context.require_connection - def _setup_class(cls): - super()._setup_class() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - cls.db = cls.client.pymongo_test - cls.collation = Collation("en_US") - cls.warn_context = warnings.catch_warnings() - cls.warn_context.__enter__() - warnings.simplefilter("ignore", DeprecationWarning) - - @classmethod - def _tearDown_class(cls): - cls.warn_context.__exit__() - cls.warn_context = None - cls.client.close() - super()._tearDown_class() - - def tearDown(self): + def setUp(self) -> None: + super().setUp() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test + self.collation = Collation("en_US") + self.warn_context = warnings.catch_warnings() + self.warn_context.__enter__() + + def tearDown(self) -> None: + self.warn_context.__exit__() + self.warn_context = None self.listener.reset() super().tearDown() diff --git a/test/test_collection.py b/test/test_collection.py index 84a900d45b..af524bba47 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -87,14 +87,10 @@ class TestCollectionNoConnect(UnitTest): db: Database client: MongoClient - @classmethod - def _setup_class(cls): - cls.client = MongoClient(connect=False) - cls.db = cls.client.pymongo_test - - @classmethod - def _tearDown_class(cls): - cls.client.close() + def setUp(self) -> None: + super().setUp() + self.client = self.simple_client(connect=False) + self.db = self.client.pymongo_test def test_collection(self): self.assertRaises(TypeError, Collection, self.db, 5) @@ -164,27 +160,14 @@ def test_iteration(self): class TestCollection(IntegrationTest): w: int - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.w = client_context.w # type: ignore - - @classmethod - def tearDownClass(cls): - if _IS_SYNC: - cls.db.drop_collection("test_large_limit") # type: ignore[unused-coroutine] - else: - asyncio.run(cls.async_tearDownClass()) - - @classmethod - def async_tearDownClass(cls): - cls.db.drop_collection("test_large_limit") - def setUp(self): - self.db.test.drop() + super().setUp() + self.w = client_context.w # type: ignore def tearDown(self): self.db.test.drop() + self.db.drop_collection("test_large_limit") + super().tearDown() @contextlib.contextmanager def write_concern_collection(self): @@ -1010,7 +993,10 @@ def test_replace_bypass_document_validation(self): db.test.insert_one({"y": 1}, bypass_document_validation=True) db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) - wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") + def predicate(): + return db_w0.test.find_one({"x": 1}) + + wait_until(predicate, "find w:0 replaced document") def test_update_bypass_document_validation(self): db = self.db diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 54cc4e0482..84ef6decd5 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -19,7 +19,12 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import ( + IntegrationTest, + client_context, + reset_client_context, + unittest, +) from test.helpers import repl_set_step_down from test.utils import ( CMAPListener, @@ -39,29 +44,19 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): listener: CMAPListener coll: Collection - @classmethod @client_context.require_replica_set - def _setup_class(cls): - super()._setup_class() - cls.listener = CMAPListener() - cls.client = cls.unmanaged_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 + def setUp(self): + self.listener = CMAPListener() + self.client = self.rs_or_single_client( + event_listeners=[self.listener], retryWrites=False, heartbeatFrequencyMS=500 ) # Ensure connections to all servers in replica set. This is to test # that the is_writable flag is properly updated for connections that # survive a replica set election. - ensure_all_connected(cls.client) - cls.listener.reset() - - cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority")) - cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) - - @classmethod - def _tearDown_class(cls): - cls.client.close() - - def setUp(self): + ensure_all_connected(self.client) + self.db = self.client.get_database("step-down", write_concern=WriteConcern("majority")) + self.coll = self.db.get_collection("step-down", write_concern=WriteConcern("majority")) # Note that all ops use same write-concern as self.db (majority). self.db.drop_collection("step-down") self.db.create_collection("step-down") diff --git a/test/test_create_entities.py b/test/test_create_entities.py index ad75fe5702..9d77a08eee 100644 --- a/test/test_create_entities.py +++ b/test/test_create_entities.py @@ -56,6 +56,9 @@ def test_store_events_as_entities(self): self.assertGreater(len(final_entity_map["events1"]), 0) for event in final_entity_map["events1"]: self.assertIn("PoolCreatedEvent", event["name"]) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + client.close() def test_store_all_others_as_entities(self): self.scenario_runner = UnifiedSpecTestMixinV1() @@ -122,6 +125,9 @@ def test_store_all_others_as_entities(self): self.assertEqual(entity_map["failures"], []) self.assertEqual(entity_map["successes"], 2) self.assertEqual(entity_map["iterations"], 5) + if self.scenario_runner.mongos_clients: + for client in self.scenario_runner.mongos_clients: + client.close() if __name__ == "__main__": diff --git a/test/test_cursor.py b/test/test_cursor.py index 9eac0f1c49..bcc7ed75f1 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1636,10 +1636,6 @@ def test_monitoring(self): class TestRawBatchCommandCursor(IntegrationTest): - @classmethod - def _setup_class(cls): - super()._setup_class() - def test_aggregate_raw(self): c = self.db.test c.drop() diff --git a/test/test_custom_types.py b/test/test_custom_types.py index abaa820cb7..6771ea25f9 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -633,6 +633,7 @@ class MyType(pytype): # type: ignore class TestCollectionWCustomType(IntegrationTest): def setUp(self): + super().setUp() self.db.test.drop() def tearDown(self): @@ -754,6 +755,7 @@ def test_find_one_and__w_custom_type_decoder(self): class TestGridFileCustomType(IntegrationTest): def setUp(self): + super().setUp() self.db.drop_collection("fs.files") self.db.drop_collection("fs.chunks") @@ -917,11 +919,10 @@ def run_test(doc_cls): class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): - @classmethod @client_context.require_change_streams - def setUpClass(cls): - super().setUpClass() - cls.db.test.delete_many({}) + def setUp(self): + super().setUp() + self.db.test.delete_many({}) def tearDown(self): self.input_target.drop() @@ -935,12 +936,11 @@ def create_targets(self, *args, **kwargs): class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): - @classmethod @client_context.require_version_min(4, 0, 0) @client_context.require_change_streams - def setUpClass(cls): - super().setUpClass() - cls.db.test.delete_many({}) + def setUp(self): + super().setUp() + self.db.test.delete_many({}) def tearDown(self): self.input_target.drop() @@ -954,12 +954,11 @@ def create_targets(self, *args, **kwargs): class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): - @classmethod @client_context.require_version_min(4, 0, 0) @client_context.require_change_streams - def setUpClass(cls): - super().setUpClass() - cls.db.test.delete_many({}) + def setUp(self): + super().setUp() + self.db.test.delete_many({}) def tearDown(self): self.input_target.drop() diff --git a/test/test_database.py b/test/test_database.py index 4973ed0134..5e854c941d 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -709,6 +709,7 @@ def test_with_options(self): class TestDatabaseAggregation(IntegrationTest): def setUp(self): + super().setUp() self.pipeline: List[Mapping[str, Any]] = [ {"$listLocalSessions": {}}, {"$limit": 1}, diff --git a/test/test_encryption.py b/test/test_encryption.py index 0806f91a06..cb8bcb74d6 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -211,11 +211,10 @@ def test_kwargs(self): class EncryptionIntegrationTest(IntegrationTest): """Base class for encryption integration tests.""" - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @client_context.require_version_min(4, 2, -1) - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() def assertEncrypted(self, val): self.assertIsInstance(val, Binary) @@ -430,10 +429,9 @@ def test_upsert_uuid_standard_encrypt(self): class TestClientMaxWireVersion(IntegrationTest): - @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() @client_context.require_version_max(4, 0, 99) def test_raise_max_wire_version_error(self): @@ -816,17 +814,16 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): "local": None, } - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - def _setup_class(cls): - super()._setup_class() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - cls.client.db.coll.drop() - cls.vault = create_key_vault(cls.client.keyvault.datakeys) + def setUp(self): + super().setUp() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) + self.client.db.coll.drop() + self.vault = create_key_vault(self.client.keyvault.datakeys) # Configure the encrypted field via the local schema_map option. schemas = { @@ -844,25 +841,22 @@ def _setup_class(cls): } } opts = AutoEncryptionOpts( - cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS + self.KMS_PROVIDERS, + "keyvault.datakeys", + schema_map=schemas, + kms_tls_options=KMS_TLS_OPTS, ) - cls.client_encrypted = cls.unmanaged_rs_or_single_client( + self.client_encrypted = self.rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = cls.unmanaged_create_client_encryption( - cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS + self.client_encryption = self.create_client_encryption( + self.KMS_PROVIDERS, "keyvault.datakeys", self.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) - - @classmethod - def _tearDown_class(cls): - cls.vault.drop() - cls.client.close() - cls.client_encrypted.close() - cls.client_encryption.close() - - def setUp(self): self.listener.reset() + def tearDown(self) -> None: + self.vault.drop() + def run_test(self, provider_name): # Create data key. master_key: Any = self.MASTER_KEYS[provider_name] @@ -1007,10 +1001,9 @@ def test_views_are_prohibited(self): class TestCorpus(EncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() @staticmethod def kms_providers(): @@ -1184,12 +1177,11 @@ class TestBsonSizeBatches(EncryptionIntegrationTest): client_encrypted: MongoClient listener: OvertCommandListener - @classmethod - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() db = client_context.client.db - cls.coll = db.coll - cls.coll.drop() + self.coll = db.coll + self.coll.drop() # Configure the encrypted 'db.coll' collection via jsonSchema. json_schema = json_data("limits", "limits-schema.json") db.create_collection( @@ -1207,17 +1199,14 @@ def _setup_class(cls): coll.insert_one(json_data("limits", "limits-key.json")) opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") - cls.listener = OvertCommandListener() - cls.client_encrypted = cls.unmanaged_rs_or_single_client( - auto_encryption_opts=opts, event_listeners=[cls.listener] + self.listener = OvertCommandListener() + self.client_encrypted = self.rs_or_single_client( + auto_encryption_opts=opts, event_listeners=[self.listener] ) - cls.coll_encrypted = cls.client_encrypted.db.coll + self.coll_encrypted = self.client_encrypted.db.coll - @classmethod - def _tearDown_class(cls): - cls.coll_encrypted.drop() - cls.client_encrypted.close() - super()._tearDown_class() + def tearDown(self) -> None: + self.coll_encrypted.drop() def test_01_insert_succeeds_under_2MiB(self): doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB} @@ -1241,7 +1230,9 @@ def test_03_bulk_batch_split(self): doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB} self.listener.reset() self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) - self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) + self.assertEqual( + len([c for c in self.listener.started_command_names() if c == "insert"]), 2 + ) def test_04_bulk_batch_split(self): limits_doc = json_data("limits", "limits-doc.json") @@ -1251,7 +1242,9 @@ def test_04_bulk_batch_split(self): doc2.update(limits_doc) self.listener.reset() self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) - self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) + self.assertEqual( + len([c for c in self.listener.started_command_names() if c == "insert"]), 2 + ) def test_05_insert_succeeds_just_under_16MiB(self): doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)} @@ -1281,15 +1274,12 @@ def test_06_insert_fails_over_16MiB(self): class TestCustomEndpoint(EncryptionIntegrationTest): """Prose tests for creating data keys with a custom endpoint.""" - @classmethod @unittest.skipUnless( any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), "No environment credentials are set", ) - def _setup_class(cls): - super()._setup_class() - def setUp(self): + super().setUp() kms_providers = { "aws": AWS_CREDS, "azure": AZURE_CREDS, @@ -1318,10 +1308,6 @@ def setUp(self): self._kmip_host_error = None self._invalid_host_error = None - def tearDown(self): - self.client_encryption.close() - self.client_encryption_invalid.close() - def run_test_expected_success(self, provider_name, master_key): data_key_id = self.client_encryption.create_data_key(provider_name, master_key=master_key) encrypted = self.client_encryption.encrypt( @@ -1494,18 +1480,18 @@ class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest): KEYVAULT_COLL = "datakeys" client: MongoClient - def setUp(self): + def _setup(self): keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL) create_key_vault(keyvault, self.DEK) def _test_explicit(self, expectation): + self._setup() client_encryption = self.create_client_encryption( self.KMS_PROVIDER_MAP, # type: ignore[arg-type] ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), client_context.client, OPTS, ) - self.addCleanup(client_encryption.close) ciphertext = client_encryption.encrypt( "string0", @@ -1517,6 +1503,7 @@ def _test_explicit(self, expectation): self.assertEqual(client_encryption.decrypt(ciphertext), "string0") def _test_automatic(self, expectation_extjson, payload): + self._setup() encrypted_db = "db" encrypted_coll = "coll" keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]) @@ -1531,7 +1518,6 @@ def _test_automatic(self, expectation_extjson, payload): client = self.rs_or_single_client( auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] ) - self.addCleanup(client.close) coll = client.get_database(encrypted_db).get_collection( encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority") @@ -1553,13 +1539,12 @@ def _test_automatic(self, expectation_extjson, payload): class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") - def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} - cls.DEK = json_data(BASE, "custom", "azure-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - super()._setup_class() + def setUp(self): + self.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} + self.DEK = json_data(BASE, "custom", "azure-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + super().setUp() def test_explicit(self): return self._test_explicit( @@ -1579,13 +1564,12 @@ def test_automatic(self): class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): - @classmethod @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") - def _setup_class(cls): - cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} - cls.DEK = json_data(BASE, "custom", "gcp-dek.json") - cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") - super()._setup_class() + def setUp(self): + self.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} + self.DEK = json_data(BASE, "custom", "gcp-dek.json") + self.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") + super().setUp() def test_explicit(self): return self._test_explicit( @@ -1607,6 +1591,7 @@ def test_automatic(self): # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#deadlock-tests class TestDeadlockProse(EncryptionIntegrationTest): def setUp(self): + super().setUp() self.client_test = self.rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" ) @@ -1637,7 +1622,6 @@ def setUp(self): self.ciphertext = client_encryption.encrypt( "string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local" ) - client_encryption.close() self.client_listener = OvertCommandListener() self.topology_listener = TopologyEventListener() @@ -1832,6 +1816,7 @@ def test_case_8(self): # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#14-decryption-events class TestDecryptProse(EncryptionIntegrationTest): def setUp(self): + super().setUp() self.client = client_context.client self.client.db.drop_collection("decryption_events") create_key_vault(self.client.keyvault.datakeys) @@ -2267,6 +2252,7 @@ def test_06_named_kms_providers_apply_tls_options_kmip(self): # https://github.com/mongodb/specifications/blob/50e26fe/source/client-side-encryption/tests/README.md#unique-index-on-keyaltnames class TestUniqueIndexOnKeyAltNamesProse(EncryptionIntegrationTest): def setUp(self): + super().setUp() self.client = client_context.client create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} @@ -2608,8 +2594,6 @@ def MongoClient(**kwargs): assert isinstance(res["encrypted_indexed"], Binary) assert isinstance(res["encrypted_unindexed"], Binary) - client_encryption.close() - # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#22-range-explicit-encryption class TestRangeQueryProse(EncryptionIntegrationTest): @@ -3071,17 +3055,11 @@ class TestNoSessionsSupport(EncryptionIntegrationTest): mongocryptd_client: MongoClient MONGOCRYPTD_PORT = 27020 - @classmethod @unittest.skipIf(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is installed") - def _setup_class(cls): - super()._setup_class() - start_mongocryptd(cls.MONGOCRYPTD_PORT) - - @classmethod - def _tearDown_class(cls): - super()._tearDown_class() - def setUp(self) -> None: + super().setUp() + start_mongocryptd(self.MONGOCRYPTD_PORT) + self.listener = OvertCommandListener() self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] diff --git a/test/test_examples.py b/test/test_examples.py index ebf1d784a3..7f98226e7a 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -33,19 +33,14 @@ class TestSampleShellCommands(IntegrationTest): - @classmethod - def setUpClass(cls): - super().setUpClass() - # Run once before any tests run. - cls.db.inventory.drop() - - @classmethod - def tearDownClass(cls): - cls.client.drop_database("pymongo_test") + def setUp(self): + super().setUp() + self.db.inventory.drop() def tearDown(self): # Run after every test. self.db.inventory.drop() + self.client.drop_database("pymongo_test") def test_first_three_examples(self): db = self.db diff --git a/test/test_grid_file.py b/test/test_grid_file.py index c35efccef5..6534bc11bf 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -97,6 +97,7 @@ def test_grid_in_custom_opts(self): class TestGridFile(IntegrationTest): def setUp(self): + super().setUp() self.cleanup_colls(self.db.fs.files, self.db.fs.chunks) def test_basic(self): diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 549dc0b204..a36109f399 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -75,9 +75,9 @@ def run(self): class TestGridfsNoConnect(unittest.TestCase): db: Database - @classmethod - def setUpClass(cls): - cls.db = MongoClient(connect=False).pymongo_test + def setUp(self): + super().setUp() + self.db = MongoClient(connect=False).pymongo_test def test_gridfs(self): self.assertRaises(TypeError, gridfs.GridFS, "foo") @@ -88,13 +88,10 @@ class TestGridfs(IntegrationTest): fs: gridfs.GridFS alt: gridfs.GridFS - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.fs = gridfs.GridFS(cls.db) - cls.alt = gridfs.GridFS(cls.db, "alt") - def setUp(self): + super().setUp() + self.fs = gridfs.GridFS(self.db) + self.alt = gridfs.GridFS(self.db, "alt") self.cleanup_colls( self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks ) @@ -509,10 +506,9 @@ def test_md5(self): class TestGridfsReplicaSet(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() + def setUp(self): + super().setUp() @classmethod def tearDownClass(cls): diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 28adb7051a..04c7427350 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -79,13 +79,10 @@ class TestGridfs(IntegrationTest): fs: gridfs.GridFSBucket alt: gridfs.GridFSBucket - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.fs = gridfs.GridFSBucket(cls.db) - cls.alt = gridfs.GridFSBucket(cls.db, bucket_name="alt") - def setUp(self): + super().setUp() + self.fs = gridfs.GridFSBucket(self.db) + self.alt = gridfs.GridFSBucket(self.db, bucket_name="alt") self.cleanup_colls( self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks ) @@ -479,10 +476,9 @@ def test_md5(self): class TestGridfsBucketReplicaSet(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) - def setUpClass(cls): - super().setUpClass() + def setUp(self): + super().setUp() @classmethod def tearDownClass(cls): diff --git a/test/test_monitor.py b/test/test_monitor.py index f8e9443fae..a704f3d8cb 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -29,7 +29,7 @@ wait_until, ) -from pymongo.synchronous.periodic_executor import _EXECUTORS +from pymongo.periodic_executor import _EXECUTORS def unregistered(ref): diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 75fe5c987a..670558c0a0 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -52,22 +52,14 @@ class TestCommandMonitoring(IntegrationTest): listener: EventListener @classmethod - @client_context.require_connection - def _setup_class(cls): - super()._setup_class() + def setUpClass(cls) -> None: cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client( - event_listeners=[cls.listener], retryWrites=False - ) - @classmethod - def _tearDown_class(cls): - cls.client.close() - super()._tearDown_class() - - def tearDown(self): + @client_context.require_connection + def setUp(self) -> None: + super().setUp() self.listener.reset() - super().tearDown() + self.client = self.rs_or_single_client(event_listeners=[self.listener], retryWrites=False) def test_started_simple(self): self.client.pymongo_test.command("ping") @@ -1140,26 +1132,23 @@ class TestGlobalListener(IntegrationTest): saved_listeners: Any @classmethod - @client_context.require_connection - def _setup_class(cls): - super()._setup_class() + def setUpClass(cls) -> None: cls.listener = OvertCommandListener() # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = cls.unmanaged_single_client() - # Get one (authenticated) socket in the pool. - cls.client.pymongo_test.command("ping") - - @classmethod - def _tearDown_class(cls): - monitoring._LISTENERS = cls.saved_listeners - cls.client.close() - super()._tearDown_class() + @client_context.require_connection def setUp(self): super().setUp() self.listener.reset() + self.client = self.single_client() + # Get one (authenticated) socket in the pool. + self.client.pymongo_test.command("ping") + + @classmethod + def tearDownClass(cls): + monitoring._LISTENERS = cls.saved_listeners def test_simple(self): self.client.pymongo_test.command("ping") diff --git a/test/test_read_concern.py b/test/test_read_concern.py index ea9ce49a30..f7c0901422 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -31,24 +31,16 @@ class TestReadConcern(IntegrationTest): listener: OvertCommandListener - @classmethod @client_context.require_connection - def setUpClass(cls): - super().setUpClass() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - cls.db = cls.client.pymongo_test + def setUp(self): + super().setUp() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) + self.db = self.client.pymongo_test client_context.client.pymongo_test.create_collection("coll") - @classmethod - def tearDownClass(cls): - cls.client.close() - client_context.client.pymongo_test.drop_collection("coll") - super().tearDownClass() - def tearDown(self): - self.listener.reset() - super().tearDown() + client_context.client.pymongo_test.drop_collection("coll") def test_read_concern(self): rc = ReadConcern() diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 74f3c23e51..07bd1db0ba 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -132,34 +132,27 @@ class IgnoreDeprecationsTest(IntegrationTest): RUN_ON_SERVERLESS = True deprecation_filter: DeprecationFilter - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.deprecation_filter = DeprecationFilter() + def setUp(self) -> None: + super().setUp() + self.deprecation_filter = DeprecationFilter() - @classmethod - def _tearDown_class(cls): - cls.deprecation_filter.stop() - super()._tearDown_class() + def tearDown(self) -> None: + self.deprecation_filter.stop() class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): knobs: client_knobs - @classmethod - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True) - cls.db = cls.client.pymongo_test + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.client = self.rs_or_single_client(retryWrites=True) + self.db = self.client.pymongo_test - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - cls.client.close() - super()._tearDown_class() + def tearDown(self) -> None: + self.knobs.disable() @client_context.require_no_standalone def test_actionable_error_message(self): @@ -180,26 +173,16 @@ class TestRetryableWrites(IgnoreDeprecationsTest): listener: OvertCommandListener knobs: client_knobs - @classmethod @client_context.require_no_mmap - def _setup_class(cls): - super()._setup_class() + def setUp(self) -> None: + super().setUp() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - cls.listener = OvertCommandListener() - cls.client = cls.unmanaged_rs_or_single_client( - retryWrites=True, event_listeners=[cls.listener] - ) - cls.db = cls.client.pymongo_test - - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - cls.client.close() - super()._tearDown_class() + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() + self.listener = OvertCommandListener() + self.client = self.rs_or_single_client(retryWrites=True, event_listeners=[self.listener]) + self.db = self.client.pymongo_test - def setUp(self): if client_context.is_rs and client_context.test_commands_enabled: self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")]) @@ -210,6 +193,7 @@ def tearDown(self): self.client.admin.command( SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) ) + self.knobs.disable() def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() @@ -438,13 +422,12 @@ class TestWriteConcernError(IntegrationTest): RUN_ON_SERVERLESS = True fail_insert: dict - @classmethod @client_context.require_replica_set @client_context.require_no_mmap @client_context.require_failCommand_fail_point - def _setup_class(cls): - super()._setup_class() - cls.fail_insert = { + def setUp(self) -> None: + super().setUp() + self.fail_insert = { "configureFailPoint": "failCommand", "mode": {"times": 2}, "data": { diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index 81b208d511..6b808b159d 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -270,7 +270,7 @@ class TestSdamMonitoring(IntegrationTest): @classmethod @client_context.require_failCommand_fail_point def setUpClass(cls): - super().setUpClass() + super().setUp(cls) # Speed up the tests by decreasing the event publish frequency. cls.knobs = client_knobs( events_queue_frequency=0.1, heartbeat_frequency=0.1, min_heartbeat_interval=0.1 diff --git a/test/test_session.py b/test/test_session.py index d0bbb075a8..634efa11c0 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -82,36 +82,27 @@ class TestSession(IntegrationTest): client2: MongoClient sensitive_commands: Set[str] - @classmethod @client_context.require_sessions - def _setup_class(cls): - super()._setup_class() + def setUp(self): + super().setUp() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = cls.unmanaged_rs_or_single_client() + self.client2 = self.rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". - cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() + self.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() monitoring._SENSITIVE_COMMANDS.clear() - @classmethod - def _tearDown_class(cls): - monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) - cls.client2.close() - super()._tearDown_class() - - def setUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() self.client = self.rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) - self.addCleanup(self.client.close) self.db = self.client.pymongo_test self.initial_lsids = {s["id"] for s in session_ids(self.client)} def tearDown(self): - """All sessions used in the test must be returned to the pool.""" + monitoring._SENSITIVE_COMMANDS.update(self.sensitive_commands) self.client.drop_database("pymongo_test") used_lsids = self.initial_lsids.copy() for event in self.session_checker_listener.started_events: @@ -121,6 +112,8 @@ def tearDown(self): current_lsids = {s["id"] for s in session_ids(self.client)} self.assertLessEqual(used_lsids, current_lsids) + super().tearDown() + def _test_ops(self, client, *ops): listener = client.options.event_listeners[0] @@ -832,18 +825,11 @@ class TestCausalConsistency(UnitTest): listener: SessionTestListener client: MongoClient - @classmethod - def _setup_class(cls): - cls.listener = SessionTestListener() - cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) - - @classmethod - def _tearDown_class(cls): - cls.client.close() - @client_context.require_sessions def setUp(self): super().setUp() + self.listener = SessionTestListener() + self.client = self.rs_or_single_client(event_listeners=[self.listener]) @client_context.require_no_standalone def test_core(self): diff --git a/test/test_threads.py b/test/test_threads.py index b3dadbb1a3..3e469e28fe 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -105,6 +105,7 @@ def run(self): class TestThreads(IntegrationTest): def setUp(self): + super().setUp() self.db = self.client.pymongo_test def test_threading(self): diff --git a/test/test_transactions.py b/test/test_transactions.py index 3cecbe9d38..949b88e60b 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -395,19 +395,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): class TestTransactionsConvenientAPI(TransactionsBase): - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.mongos_clients = [] + def setUp(self) -> None: + super().setUp() + self.mongos_clients = [] if client_context.supports_transactions(): for address in client_context.mongoses: - cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) - - @classmethod - def _tearDown_class(cls): - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() + self.mongos_clients.append(self.single_client("{}:{}".format(*address))) def _set_fail_point(self, client, command_args): cmd = {"configureFailPoint": "failCommand"} diff --git a/test/test_typing.py b/test/test_typing.py index 441707616e..bfe4d032c1 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -114,10 +114,9 @@ def test_mypy_failures(self) -> None: class TestPymongo(IntegrationTest): coll: Collection - @classmethod - def setUpClass(cls): - super().setUpClass() - cls.coll = cls.client.test.test + def setUp(self): + super().setUp() + self.coll = self.client.test.test def test_insert_find(self) -> None: doc = {"my": "doc"} diff --git a/test/unified_format.py b/test/unified_format.py index 3489a8ac84..5cb268a29d 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -304,7 +304,6 @@ def _create_entity(self, entity_spec, uri=None): kwargs["h"] = uri client = self.test.rs_or_single_client(**kwargs) self[spec["id"]] = client - self.test.addCleanup(client.close) return elif entity_type == "database": client = self[spec["client"]] @@ -479,52 +478,47 @@ def insert_initial_data(self, initial_data): db.create_collection(coll_name, write_concern=wc, **opts) @classmethod - def _setup_class(cls): + def setUpClass(cls) -> None: + # Speed up the tests by decreasing the heartbeat frequency. + cls.knobs = client_knobs( + heartbeat_frequency=0.1, + min_heartbeat_interval=0.1, + kill_cursor_frequency=0.1, + events_queue_frequency=0.1, + ) + cls.knobs.enable() + + @classmethod + def tearDownClass(cls) -> None: + cls.knobs.disable() + + def setUp(self): # super call creates internal client cls.client - super()._setup_class() + super().setUp() # process file-level runOnRequirements - run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) - if not cls.should_run_on(run_on_spec): - raise unittest.SkipTest(f"{cls.__name__} runOnRequirements not satisfied") + run_on_spec = self.TEST_SPEC.get("runOnRequirements", []) + if not self.should_run_on(run_on_spec): + raise unittest.SkipTest(f"{self.__class__.__name__} runOnRequirements not satisfied") # add any special-casing for skipping tests here if client_context.storage_engine == "mmapv1": - if "retryable-writes" in cls.TEST_SPEC["description"] or "retryable_writes" in str( - cls.TEST_PATH + if "retryable-writes" in self.TEST_SPEC["description"] or "retryable_writes" in str( + self.TEST_PATH ): raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") # Handle mongos_clients for transactions tests. - cls.mongos_clients = [] + self.mongos_clients = [] if ( client_context.supports_transactions() and not client_context.load_balancer and not client_context.serverless ): for address in client_context.mongoses: - cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) + self.mongos_clients.append(self.single_client("{}:{}".format(*address))) - # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs( - heartbeat_frequency=0.1, - min_heartbeat_interval=0.1, - kill_cursor_frequency=0.1, - events_queue_frequency=0.1, - ) - cls.knobs.enable() - - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() - - def setUp(self): - super().setUp() # process schemaVersion # note: we check major schema version during class generation - # note: we do this here because we cannot run assertions in setUpClass version = Version.from_string(self.TEST_SPEC["schemaVersion"]) self.assertLessEqual( version, @@ -1026,7 +1020,6 @@ def _testOperation_targetedFailPoint(self, spec): ) client = self.single_client("{}:{}".format(*session._pinned_address)) - self.addCleanup(client.close) self.__set_fail_point(client=client, command_args=spec["failPoint"]) def _testOperation_createEntities(self, spec): diff --git a/test/utils.py b/test/utils.py index 9b326e5d73..69154bc63b 100644 --- a/test/utils.py +++ b/test/utils.py @@ -99,6 +99,12 @@ def wait_for_event(self, event, count): """Wait for a number of events to be published, or fail.""" wait_until(lambda: self.event_count(event) >= count, f"find {count} {event} event(s)") + async def async_wait_for_event(self, event, count): + """Wait for a number of events to be published, or fail.""" + await async_wait_until( + lambda: self.event_count(event) >= count, f"find {count} {event} event(s)" + ) + class CMAPListener(BaseListener, monitoring.ConnectionPoolListener): def connection_created(self, event): @@ -644,7 +650,10 @@ async def async_wait_until(predicate, success_description, timeout=10): start = time.time() interval = min(float(timeout) / 100, 0.1) while True: - retval = await predicate() + if iscoroutinefunction(predicate): + retval = await predicate() + else: + retval = predicate() if retval: return retval diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 8b2679d776..4508502cd0 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -249,30 +249,22 @@ class SpecRunner(IntegrationTest): knobs: client_knobs listener: EventListener - @classmethod - def _setup_class(cls): - super()._setup_class() - cls.mongos_clients = [] + def setUp(self) -> None: + super().setUp() + self.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) - cls.knobs.enable() - - @classmethod - def _tearDown_class(cls): - cls.knobs.disable() - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() - - def setUp(self): - super().setUp() + self.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) + self.knobs.enable() self.targets = {} self.listener = None # type: ignore self.pool_listener = None self.server_listener = None self.maxDiff = None + def tearDown(self) -> None: + self.knobs.disable() + def _set_fail_point(self, client, command_args): cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) @@ -697,8 +689,6 @@ def run_scenario(self, scenario_def, test): self.listener = listener self.pool_listener = pool_listener self.server_listener = server_listener - # Close the client explicitly to avoid having too many threads open. - self.addCleanup(client.close) # Create session0 and session1. sessions = {} diff --git a/tools/synchro.py b/tools/synchro.py index 0a7109c6d4..47617365f4 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -110,6 +110,13 @@ "async_set_fail_point": "set_fail_point", "async_ensure_all_connected": "ensure_all_connected", "async_repl_set_step_down": "repl_set_step_down", + "AsyncPeriodicExecutor": "PeriodicExecutor", + "async_wait_for_event": "wait_for_event", + "pymongo_server_monitor_task": "pymongo_server_monitor_thread", + "pymongo_server_rtt_task": "pymongo_server_rtt_thread", + "_async_create_lock": "_create_lock", + "_async_create_condition": "_create_condition", + "_async_cond_wait": "_cond_wait", } docstring_replacements: dict[tuple[str, str], str] = { @@ -130,8 +137,6 @@ ".. warning:: This API is currently in beta, meaning the classes, methods, and behaviors described within may change before the full release." } -type_replacements = {"_Condition": "threading.Condition"} - import_replacements = {"test.synchronous": "test"} _pymongo_base = "./pymongo/asynchronous/" @@ -234,8 +239,6 @@ def process_files(files: list[str]) -> None: lines = translate_async_sleeps(lines) if file in docstring_translate_files: lines = translate_docstrings(lines) - translate_locks(lines) - translate_types(lines) if file in sync_test_files: translate_imports(lines) f.seek(0) @@ -269,34 +272,6 @@ def translate_coroutine_types(lines: list[str]) -> list[str]: return lines -def translate_locks(lines: list[str]) -> list[str]: - lock_lines = [line for line in lines if "_Lock(" in line] - cond_lines = [line for line in lines if "_Condition(" in line] - for line in lock_lines: - res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line) - if res: - old = res[0] - index = lines.index(line) - lines[index] = line.replace(old, res[1]) - for line in cond_lines: - res = re.search(r"_Condition\(([^()]*\([^()]*\))\)", line) - if res: - old = res[0] - index = lines.index(line) - lines[index] = line.replace(old, res[1]) - - return lines - - -def translate_types(lines: list[str]) -> list[str]: - for k, v in type_replacements.items(): - matches = [line for line in lines if k in line and "import" not in line] - for line in matches: - index = lines.index(line) - lines[index] = line.replace(k, v) - return lines - - def translate_imports(lines: list[str]) -> list[str]: for k, v in import_replacements.items(): matches = [line for line in lines if k in line and "import" in line]