diff --git a/THIRD-PARTY-NOTICES b/THIRD-PARTY-NOTICES index 0b9fc738ed..7e20a6f2bd 100644 --- a/THIRD-PARTY-NOTICES +++ b/THIRD-PARTY-NOTICES @@ -71,3 +71,62 @@ OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +3) License Notice for async_lock.py +----------------------------------------- + +1. This LICENSE AGREEMENT is between the Python Software Foundation +("PSF"), and the Individual or Organization ("Licensee") accessing and +otherwise using this software ("Python") in source or binary form and +its associated documentation. + +2. Subject to the terms and conditions of this License Agreement, PSF hereby +grants Licensee a nonexclusive, royalty-free, world-wide license to reproduce, +analyze, test, perform and/or display publicly, prepare derivative works, +distribute, and otherwise use Python alone or in any derivative version, +provided, however, that PSF's License Agreement and PSF's notice of copyright, +i.e., "Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved" +are retained in Python alone or in any derivative version prepared by Licensee. + +3. In the event Licensee prepares a derivative work that is based on +or incorporates Python or any part thereof, and wants to make +the derivative work available to others as provided herein, then +Licensee hereby agrees to include in any such work a brief summary of +the changes made to Python. + +4. PSF is making Python available to Licensee on an "AS IS" +basis. PSF MAKES NO REPRESENTATIONS OR WARRANTIES, EXPRESS OR +IMPLIED. BY WAY OF EXAMPLE, BUT NOT LIMITATION, PSF MAKES NO AND +DISCLAIMS ANY REPRESENTATION OR WARRANTY OF MERCHANTABILITY OR FITNESS +FOR ANY PARTICULAR PURPOSE OR THAT THE USE OF PYTHON WILL NOT +INFRINGE ANY THIRD PARTY RIGHTS. + +5. PSF SHALL NOT BE LIABLE TO LICENSEE OR ANY OTHER USERS OF PYTHON +FOR ANY INCIDENTAL, SPECIAL, OR CONSEQUENTIAL DAMAGES OR LOSS AS +A RESULT OF MODIFYING, DISTRIBUTING, OR OTHERWISE USING PYTHON, +OR ANY DERIVATIVE THEREOF, EVEN IF ADVISED OF THE POSSIBILITY THEREOF. + +6. This License Agreement will automatically terminate upon a material +breach of its terms and conditions. + +7. Nothing in this License Agreement shall be deemed to create any +relationship of agency, partnership, or joint venture between PSF and +Licensee. This License Agreement does not grant permission to use PSF +trademarks or trade name in a trademark sense to endorse or promote +products or services of Licensee, or any third party. + +8. By copying, installing or otherwise using Python, Licensee +agrees to be bound by the terms and conditions of this License +Agreement. + +Permission to use, copy, modify, and/or distribute this software for any +purpose with or without fee is hereby granted. + +THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR +OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +PERFORMANCE OF THIS SOFTWARE. diff --git a/pymongo/_asyncio_lock.py b/pymongo/_asyncio_lock.py new file mode 100644 index 0000000000..669b0f63a7 --- /dev/null +++ b/pymongo/_asyncio_lock.py @@ -0,0 +1,309 @@ +# Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved + +"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py +to port 3.13 fixes to older versions of Python. +Can be removed once we drop Python 3.12 support.""" + +from __future__ import annotations + +import collections +import threading +from asyncio import events, exceptions +from typing import Any, Coroutine, Optional + +_global_lock = threading.Lock() + + +class _LoopBoundMixin: + _loop = None + + def _get_loop(self) -> Any: + loop = events._get_running_loop() + + if self._loop is None: + with _global_lock: + if self._loop is None: + self._loop = loop + if loop is not self._loop: + raise RuntimeError(f"{self!r} is bound to a different event loop") + return loop + + +class _ContextManagerMixin: + async def __aenter__(self) -> None: + await self.acquire() # type: ignore[attr-defined] + # We have no use for the "as ..." clause in the with + # statement for locks. + return + + async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: + self.release() # type: ignore[attr-defined] + + +class Lock(_ContextManagerMixin, _LoopBoundMixin): + """Primitive lock objects. + + A primitive lock is a synchronization primitive that is not owned + by a particular task when locked. A primitive lock is in one + of two states, 'locked' or 'unlocked'. + + It is created in the unlocked state. It has two basic methods, + acquire() and release(). When the state is unlocked, acquire() + changes the state to locked and returns immediately. When the + state is locked, acquire() blocks until a call to release() in + another task changes it to unlocked, then the acquire() call + resets it to locked and returns. The release() method should only + be called in the locked state; it changes the state to unlocked + and returns immediately. If an attempt is made to release an + unlocked lock, a RuntimeError will be raised. + + When more than one task is blocked in acquire() waiting for + the state to turn to unlocked, only one task proceeds when a + release() call resets the state to unlocked; successive release() + calls will unblock tasks in FIFO order. + + Locks also support the asynchronous context management protocol. + 'async with lock' statement should be used. + + Usage: + + lock = Lock() + ... + await lock.acquire() + try: + ... + finally: + lock.release() + + Context manager usage: + + lock = Lock() + ... + async with lock: + ... + + Lock objects can be tested for locking state: + + if not lock.locked(): + await lock.acquire() + else: + # lock is acquired + ... + + """ + + def __init__(self) -> None: + self._waiters: Optional[collections.deque] = None + self._locked = False + + def __repr__(self) -> str: + res = super().__repr__() + extra = "locked" if self._locked else "unlocked" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + def locked(self) -> bool: + """Return True if lock is acquired.""" + return self._locked + + async def acquire(self) -> bool: + """Acquire a lock. + + This method blocks until the lock is unlocked, then sets it to + locked and returns True. + """ + # Implement fair scheduling, where thread always waits + # its turn. Jumping the queue if all are cancelled is an optimization. + if not self._locked and ( + self._waiters is None or all(w.cancelled() for w in self._waiters) + ): + self._locked = True + return True + + if self._waiters is None: + self._waiters = collections.deque() + fut = self._get_loop().create_future() + self._waiters.append(fut) + + try: + try: + await fut + finally: + self._waiters.remove(fut) + except exceptions.CancelledError: + # Currently the only exception designed be able to occur here. + + # Ensure the lock invariant: If lock is not claimed (or about + # to be claimed by us) and there is a Task in waiters, + # ensure that the Task at the head will run. + if not self._locked: + self._wake_up_first() + raise + + # assert self._locked is False + self._locked = True + return True + + def release(self) -> None: + """Release a lock. + + When the lock is locked, reset it to unlocked, and return. + If any other tasks are blocked waiting for the lock to become + unlocked, allow exactly one of them to proceed. + + When invoked on an unlocked lock, a RuntimeError is raised. + + There is no return value. + """ + if self._locked: + self._locked = False + self._wake_up_first() + else: + raise RuntimeError("Lock is not acquired.") + + def _wake_up_first(self) -> None: + """Ensure that the first waiter will wake up.""" + if not self._waiters: + return + try: + fut = next(iter(self._waiters)) + except StopIteration: + return + + # .done() means that the waiter is already set to wake up. + if not fut.done(): + fut.set_result(True) + + +class Condition(_ContextManagerMixin, _LoopBoundMixin): + """Asynchronous equivalent to threading.Condition. + + This class implements condition variable objects. A condition variable + allows one or more tasks to wait until they are notified by another + task. + + A new Lock object is created and used as the underlying lock. + """ + + def __init__(self, lock: Optional[Lock] = None) -> None: + if lock is None: + lock = Lock() + + self._lock = lock + # Export the lock's locked(), acquire() and release() methods. + self.locked = lock.locked + self.acquire = lock.acquire + self.release = lock.release + + self._waiters: collections.deque = collections.deque() + + def __repr__(self) -> str: + res = super().__repr__() + extra = "locked" if self.locked() else "unlocked" + if self._waiters: + extra = f"{extra}, waiters:{len(self._waiters)}" + return f"<{res[1:-1]} [{extra}]>" + + async def wait(self) -> bool: + """Wait until notified. + + If the calling task has not acquired the lock when this + method is called, a RuntimeError is raised. + + This method releases the underlying lock, and then blocks + until it is awakened by a notify() or notify_all() call for + the same condition variable in another task. Once + awakened, it re-acquires the lock and returns True. + + This method may return spuriously, + which is why the caller should always + re-check the state and be prepared to wait() again. + """ + if not self.locked(): + raise RuntimeError("cannot wait on un-acquired lock") + + fut = self._get_loop().create_future() + self.release() + try: + try: + self._waiters.append(fut) + try: + await fut + return True + finally: + self._waiters.remove(fut) + + finally: + # Must re-acquire lock even if wait is cancelled. + # We only catch CancelledError here, since we don't want any + # other (fatal) errors with the future to cause us to spin. + err = None + while True: + try: + await self.acquire() + break + except exceptions.CancelledError as e: + err = e + + if err is not None: + try: + raise err # Re-raise most recent exception instance. + finally: + err = None # Break reference cycles. + except BaseException: + # Any error raised out of here _may_ have occurred after this Task + # believed to have been successfully notified. + # Make sure to notify another Task instead. This may result + # in a "spurious wakeup", which is allowed as part of the + # Condition Variable protocol. + self._notify(1) + raise + + async def wait_for(self, predicate: Any) -> Coroutine: + """Wait until a predicate becomes true. + + The predicate should be a callable whose result will be + interpreted as a boolean value. The method will repeatedly + wait() until it evaluates to true. The final predicate value is + the return value. + """ + result = predicate() + while not result: + await self.wait() + result = predicate() + return result + + def notify(self, n: int = 1) -> None: + """By default, wake up one task waiting on this condition, if any. + If the calling task has not acquired the lock when this method + is called, a RuntimeError is raised. + + This method wakes up n of the tasks waiting for the condition + variable; if fewer than n are waiting, they are all awoken. + + Note: an awakened task does not actually return from its + wait() call until it can reacquire the lock. Since notify() does + not release the lock, its caller should. + """ + if not self.locked(): + raise RuntimeError("cannot notify on un-acquired lock") + self._notify(n) + + def _notify(self, n: int) -> None: + idx = 0 + for fut in self._waiters: + if idx >= n: + break + + if not fut.done(): + idx += 1 + fut.set_result(False) + + def notify_all(self) -> None: + """Wake up all tasks waiting on this condition. This method acts + like notify(), but wakes up all waiting tasks instead of one. If the + calling task has not acquired the lock when this method is called, + a RuntimeError is raised. + """ + self.notify(len(self._waiters)) diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 4b4bb52a8e..7d7ae4a5db 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -45,7 +45,7 @@ ) from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure -from pymongo.lock import _ALock, _create_lock +from pymongo.lock import _async_create_lock from pymongo.message import ( _CursorAddress, _GetMore, @@ -77,7 +77,7 @@ class _ConnectionManager: def __init__(self, conn: AsyncConnection, more_to_come: bool): self.conn: Optional[AsyncConnection] = conn self.more_to_come = more_to_come - self._alock = _ALock(_create_lock()) + self._lock = _async_create_lock() def update_exhaust(self, more_to_come: bool) -> None: self.more_to_come = more_to_come diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index eae2b0df4c..a33246a24b 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -82,7 +82,11 @@ WaitQueueTimeoutError, WriteConcernError, ) -from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks +from pymongo.lock import ( + _HAS_REGISTER_AT_FORK, + _async_create_lock, + _release_locks, +) from pymongo.logger import _CLIENT_LOGGER, _log_or_warn from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.monitoring import ConnectionClosedReason @@ -842,7 +846,7 @@ def __init__( self._options = options = ClientOptions(username, password, dbase, opts, _IS_SYNC) self._default_database_name = dbase - self._lock = _ALock(_create_lock()) + self._lock = _async_create_lock() self._kill_cursors_queue: list = [] self._event_listeners = options.pool_options._event_listeners @@ -1721,7 +1725,7 @@ async def _run_operation( address=address, ) - async with operation.conn_mgr._alock: + async with operation.conn_mgr._lock: async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return await server.run_operation( @@ -1969,7 +1973,7 @@ async def _close_cursor_now( try: if conn_mgr: - async with conn_mgr._alock: + async with conn_mgr._lock: # Cursor is pinned to LB outside of a transaction. assert address is not None assert conn_mgr.conn is not None diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index a9f02d650a..2fe9579aef 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -23,7 +23,6 @@ import socket import ssl import sys -import threading import time import weakref from typing import ( @@ -65,7 +64,11 @@ _CertificateError, ) from pymongo.hello import Hello, HelloCompat -from pymongo.lock import _ACondition, _ALock, _create_lock +from pymongo.lock import ( + _async_cond_wait, + _async_create_condition, + _async_create_lock, +) from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, @@ -208,11 +211,6 @@ def _raise_connection_failure( raise AutoReconnect(msg) from error -async def _cond_wait(condition: _ACondition, deadline: Optional[float]) -> bool: - timeout = deadline - time.monotonic() if deadline else None - return await condition.wait(timeout) - - def _get_timeout_details(options: PoolOptions) -> dict[str, float]: details = {} timeout = _csot.get_timeout() @@ -992,8 +990,8 @@ def __init__( # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - _lock = _create_lock() - self.lock = _ALock(_lock) + self.lock = _async_create_lock() + self._max_connecting_cond = _async_create_condition(self.lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1019,7 +1017,7 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = _ACondition(threading.Condition(_lock)) + self.size_cond = _async_create_condition(self.lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1027,7 +1025,7 @@ def __init__( # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = _ACondition(threading.Condition(_lock)) + self._max_connecting_cond = _async_create_condition(self.lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id @@ -1456,7 +1454,8 @@ async def _get_conn( async with self.size_cond: self._raise_if_not_ready(checkout_started_time, emit_event=True) while not (self.requests < self.max_pool_size): - if not await _cond_wait(self.size_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not await _async_cond_wait(self.size_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.requests < self.max_pool_size: @@ -1479,7 +1478,8 @@ async def _get_conn( async with self._max_connecting_cond: self._raise_if_not_ready(checkout_started_time, emit_event=False) while not (self.conns or self._pending < self._max_connecting): - if not await _cond_wait(self._max_connecting_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not await _async_cond_wait(self._max_connecting_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.conns or self._pending < self._max_connecting: diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index f0cb56cbf1..6d67710a7e 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -43,7 +43,11 @@ WriteError, ) from pymongo.hello import Hello -from pymongo.lock import _ACondition, _ALock, _create_lock +from pymongo.lock import ( + _async_cond_wait, + _async_create_condition, + _async_create_lock, +) from pymongo.logger import ( _SDAM_LOGGER, _SERVER_SELECTION_LOGGER, @@ -169,9 +173,10 @@ def __init__(self, topology_settings: TopologySettings): self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - _lock = _create_lock() - self._lock = _ALock(_lock) - self._condition = _ACondition(self._settings.condition_class(_lock)) + self._lock = _async_create_lock() + self._condition = _async_create_condition( + self._lock, self._settings.condition_class if _IS_SYNC else None + ) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None @@ -353,7 +358,7 @@ async def _select_servers_loop( # change, or for a timeout. We won't miss any changes that # came after our most recent apply_selector call, since we've # held the lock until now. - await self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) + await _async_cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL) self._description.check_compatible() now = time.monotonic() server_descriptions = self._description.apply_selector( @@ -653,7 +658,7 @@ async def request_check_all(self, wait_time: int = 5) -> None: """Wake all monitors, wait for at least one to check its server.""" async with self._lock: self._request_check_all() - await self._condition.wait(wait_time) + await _async_cond_wait(self._condition, wait_time) def data_bearing_servers(self) -> list[ServerDescription]: """Return a list of all data-bearing servers. diff --git a/pymongo/lock.py b/pymongo/lock.py index 0cbfb4a57e..6bf7138017 100644 --- a/pymongo/lock.py +++ b/pymongo/lock.py @@ -11,15 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +"""Internal helpers for lock and condition coordination primitives.""" + from __future__ import annotations import asyncio -import collections import os +import sys import threading -import time import weakref -from typing import Any, Callable, Optional, TypeVar +from asyncio import wait_for +from typing import Any, Optional, TypeVar + +import pymongo._asyncio_lock _HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork") @@ -28,6 +33,15 @@ _T = TypeVar("_T") +# Needed to support 3.13 asyncio fixes (https://github.com/python/cpython/issues/112202) +# in older versions of Python +if sys.version_info >= (3, 13): + Lock = asyncio.Lock + Condition = asyncio.Condition +else: + Lock = pymongo._asyncio_lock.Lock + Condition = pymongo._asyncio_lock.Condition + def _create_lock() -> threading.Lock: """Represents a lock that is tracked upon instantiation using a WeakSet and @@ -39,6 +53,27 @@ def _create_lock() -> threading.Lock: return lock +def _async_create_lock() -> Lock: + """Represents an asyncio.Lock.""" + return Lock() + + +def _create_condition( + lock: threading.Lock, condition_class: Optional[Any] = None +) -> threading.Condition: + """Represents a threading.Condition.""" + if condition_class: + return condition_class(lock) + return threading.Condition(lock) + + +def _async_create_condition(lock: Lock, condition_class: Optional[Any] = None) -> Condition: + """Represents an asyncio.Condition.""" + if condition_class: + return condition_class(lock) + return Condition(lock) + + def _release_locks() -> None: # Completed the fork, reset all the locks in the child. for lock in _forkable_locks: @@ -46,202 +81,12 @@ def _release_locks() -> None: lock.release() -# Needed only for synchro.py compat. -def _Lock(lock: threading.Lock) -> threading.Lock: - return lock +async def _async_cond_wait(condition: Condition, timeout: Optional[float]) -> bool: + try: + return await wait_for(condition.wait(), timeout) + except asyncio.TimeoutError: + return False -class _ALock: - __slots__ = ("_lock",) - - def __init__(self, lock: threading.Lock) -> None: - self._lock = lock - - def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - return self._lock.acquire(blocking=blocking, timeout=timeout) - - async def a_acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - if timeout > 0: - tstart = time.monotonic() - while True: - acquired = self._lock.acquire(blocking=False) - if acquired: - return True - if timeout > 0 and (time.monotonic() - tstart) > timeout: - return False - if not blocking: - return False - await asyncio.sleep(0) - - def release(self) -> None: - self._lock.release() - - async def __aenter__(self) -> _ALock: - await self.a_acquire() - return self - - def __enter__(self) -> _ALock: - self._lock.acquire() - return self - - def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() - - async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() - - -def _safe_set_result(fut: asyncio.Future) -> None: - # Ensure the future hasn't been cancelled before calling set_result. - if not fut.done(): - fut.set_result(False) - - -class _ACondition: - __slots__ = ("_condition", "_waiters") - - def __init__(self, condition: threading.Condition) -> None: - self._condition = condition - self._waiters: collections.deque = collections.deque() - - async def acquire(self, blocking: bool = True, timeout: float = -1) -> bool: - if timeout > 0: - tstart = time.monotonic() - while True: - acquired = self._condition.acquire(blocking=False) - if acquired: - return True - if timeout > 0 and (time.monotonic() - tstart) > timeout: - return False - if not blocking: - return False - await asyncio.sleep(0) - - async def wait(self, timeout: Optional[float] = None) -> bool: - """Wait until notified. - - If the calling task has not acquired the lock when this - method is called, a RuntimeError is raised. - - This method releases the underlying lock, and then blocks - until it is awakened by a notify() or notify_all() call for - the same condition variable in another task. Once - awakened, it re-acquires the lock and returns True. - - This method may return spuriously, - which is why the caller should always - re-check the state and be prepared to wait() again. - """ - loop = asyncio.get_running_loop() - fut = loop.create_future() - self._waiters.append((loop, fut)) - self.release() - try: - try: - try: - await asyncio.wait_for(fut, timeout) - return True - except asyncio.TimeoutError: - return False # Return false on timeout for sync pool compat. - finally: - # Must re-acquire lock even if wait is cancelled. - # We only catch CancelledError here, since we don't want any - # other (fatal) errors with the future to cause us to spin. - err = None - while True: - try: - await self.acquire() - break - except asyncio.exceptions.CancelledError as e: - err = e - - self._waiters.remove((loop, fut)) - if err is not None: - try: - raise err # Re-raise most recent exception instance. - finally: - err = None # Break reference cycles. - except BaseException: - # Any error raised out of here _may_ have occurred after this Task - # believed to have been successfully notified. - # Make sure to notify another Task instead. This may result - # in a "spurious wakeup", which is allowed as part of the - # Condition Variable protocol. - self.notify(1) - raise - - async def wait_for(self, predicate: Callable[[], _T]) -> _T: - """Wait until a predicate becomes true. - - The predicate should be a callable whose result will be - interpreted as a boolean value. The method will repeatedly - wait() until it evaluates to true. The final predicate value is - the return value. - """ - result = predicate() - while not result: - await self.wait() - result = predicate() - return result - - def notify(self, n: int = 1) -> None: - """By default, wake up one coroutine waiting on this condition, if any. - If the calling coroutine has not acquired the lock when this method - is called, a RuntimeError is raised. - - This method wakes up at most n of the coroutines waiting for the - condition variable; it is a no-op if no coroutines are waiting. - - Note: an awakened coroutine does not actually return from its - wait() call until it can reacquire the lock. Since notify() does - not release the lock, its caller should. - """ - idx = 0 - to_remove = [] - for loop, fut in self._waiters: - if idx >= n: - break - - if fut.done(): - continue - - try: - loop.call_soon_threadsafe(_safe_set_result, fut) - except RuntimeError: - # Loop was closed, ignore. - to_remove.append((loop, fut)) - continue - - idx += 1 - - for waiter in to_remove: - self._waiters.remove(waiter) - - def notify_all(self) -> None: - """Wake up all threads waiting on this condition. This method acts - like notify(), but wakes up all waiting threads instead of one. If the - calling thread has not acquired the lock when this method is called, - a RuntimeError is raised. - """ - self.notify(len(self._waiters)) - - def locked(self) -> bool: - """Only needed for tests in test_locks.""" - return self._condition._lock.locked() # type: ignore[attr-defined] - - def release(self) -> None: - self._condition.release() - - async def __aenter__(self) -> _ACondition: - await self.acquire() - return self - - def __enter__(self) -> _ACondition: - self._condition.acquire() - return self - - async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() - - def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: - self.release() +def _cond_wait(condition: threading.Condition, timeout: Optional[float]) -> bool: + return condition.wait(timeout) diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 27a76cf91d..9a7637704f 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -77,7 +77,7 @@ class _ConnectionManager: def __init__(self, conn: Connection, more_to_come: bool): self.conn: Optional[Connection] = conn self.more_to_come = more_to_come - self._alock = _create_lock() + self._lock = _create_lock() def update_exhaust(self, more_to_come: bool) -> None: self.more_to_come = more_to_come diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 7eab5e74f1..eb363f82f5 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -74,7 +74,11 @@ WaitQueueTimeoutError, WriteConcernError, ) -from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks +from pymongo.lock import ( + _HAS_REGISTER_AT_FORK, + _create_lock, + _release_locks, +) from pymongo.logger import _CLIENT_LOGGER, _log_or_warn from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.monitoring import ConnectionClosedReason @@ -1715,7 +1719,7 @@ def _run_operation( address=address, ) - with operation.conn_mgr._alock: + with operation.conn_mgr._lock: with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return server.run_operation( @@ -1963,7 +1967,7 @@ def _close_cursor_now( try: if conn_mgr: - with conn_mgr._alock: + with conn_mgr._lock: # Cursor is pinned to LB outside of a transaction. assert address is not None assert conn_mgr.conn is not None diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index eb007a3471..6ac7b4eca9 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -23,7 +23,6 @@ import socket import ssl import sys -import threading import time import weakref from typing import ( @@ -62,7 +61,11 @@ _CertificateError, ) from pymongo.hello import Hello, HelloCompat -from pymongo.lock import _create_lock, _Lock +from pymongo.lock import ( + _cond_wait, + _create_condition, + _create_lock, +) from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, @@ -208,11 +211,6 @@ def _raise_connection_failure( raise AutoReconnect(msg) from error -def _cond_wait(condition: threading.Condition, deadline: Optional[float]) -> bool: - timeout = deadline - time.monotonic() if deadline else None - return condition.wait(timeout) - - def _get_timeout_details(options: PoolOptions) -> dict[str, float]: details = {} timeout = _csot.get_timeout() @@ -988,8 +986,8 @@ def __init__( # from the right side. self.conns: collections.deque = collections.deque() self.active_contexts: set[_CancellationContext] = set() - _lock = _create_lock() - self.lock = _Lock(_lock) + self.lock = _create_lock() + self._max_connecting_cond = _create_condition(self.lock) self.active_sockets = 0 # Monotonically increasing connection ID required for CMAP Events. self.next_connection_id = 1 @@ -1015,7 +1013,7 @@ def __init__( # The first portion of the wait queue. # Enforces: maxPoolSize # Also used for: clearing the wait queue - self.size_cond = threading.Condition(_lock) + self.size_cond = _create_condition(self.lock) self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: @@ -1023,7 +1021,7 @@ def __init__( # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue - self._max_connecting_cond = threading.Condition(_lock) + self._max_connecting_cond = _create_condition(self.lock) self._max_connecting = self.opts.max_connecting self._pending = 0 self._client_id = client_id @@ -1450,7 +1448,8 @@ def _get_conn( with self.size_cond: self._raise_if_not_ready(checkout_started_time, emit_event=True) while not (self.requests < self.max_pool_size): - if not _cond_wait(self.size_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not _cond_wait(self.size_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.requests < self.max_pool_size: @@ -1473,7 +1472,8 @@ def _get_conn( with self._max_connecting_cond: self._raise_if_not_ready(checkout_started_time, emit_event=False) while not (self.conns or self._pending < self._max_connecting): - if not _cond_wait(self._max_connecting_cond, deadline): + timeout = deadline - time.monotonic() if deadline else None + if not _cond_wait(self._max_connecting_cond, timeout): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. if self.conns or self._pending < self._max_connecting: diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index e34de6bc50..b03269ae43 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -39,7 +39,11 @@ WriteError, ) from pymongo.hello import Hello -from pymongo.lock import _create_lock, _Lock +from pymongo.lock import ( + _cond_wait, + _create_condition, + _create_lock, +) from pymongo.logger import ( _SDAM_LOGGER, _SERVER_SELECTION_LOGGER, @@ -169,9 +173,10 @@ def __init__(self, topology_settings: TopologySettings): self._seed_addresses = list(topology_description.server_descriptions()) self._opened = False self._closed = False - _lock = _create_lock() - self._lock = _Lock(_lock) - self._condition = self._settings.condition_class(_lock) + self._lock = _create_lock() + self._condition = _create_condition( + self._lock, self._settings.condition_class if _IS_SYNC else None + ) self._servers: dict[_Address, Server] = {} self._pid: Optional[int] = None self._max_cluster_time: Optional[ClusterTime] = None @@ -353,7 +358,7 @@ def _select_servers_loop( # change, or for a timeout. We won't miss any changes that # came after our most recent apply_selector call, since we've # held the lock until now. - self._condition.wait(common.MIN_HEARTBEAT_INTERVAL) + _cond_wait(self._condition, common.MIN_HEARTBEAT_INTERVAL) self._description.check_compatible() now = time.monotonic() server_descriptions = self._description.apply_selector( @@ -651,7 +656,7 @@ def request_check_all(self, wait_time: int = 5) -> None: """Wake all monitors, wait for at least one to check its server.""" with self._lock: self._request_check_all() - self._condition.wait(wait_time) + _cond_wait(self._condition, wait_time) def data_bearing_servers(self) -> list[ServerDescription]: """Return a list of all data-bearing servers. diff --git a/test/__init__.py b/test/__init__.py index 940518c2c5..c1944f5870 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1131,7 +1131,7 @@ class IntegrationTest(PyMongoTestCase): @client_context.require_connection def setUp(self) -> None: - if not _IS_SYNC: + if not _IS_SYNC and client_context.client is not None: reset_client_context() if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 8d1e3e1911..9ca5a32ffc 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -1149,7 +1149,7 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): @async_client_context.require_connection async def asyncSetUp(self) -> None: - if not _IS_SYNC: + if not _IS_SYNC and async_client_context.client is not None: await reset_client_context() if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") diff --git a/test/asynchronous/test_locks.py b/test/asynchronous/test_locks.py index e0e7f2fc8d..e5a0adfee6 100644 --- a/test/asynchronous/test_locks.py +++ b/test/asynchronous/test_locks.py @@ -16,498 +16,447 @@ import asyncio import sys -import threading import unittest +from pymongo.lock import _async_create_condition, _async_create_lock + sys.path[0:0] = [""] -from pymongo.lock import _ACondition +if sys.version_info < (3, 13): + # Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py + # Includes tests for: + # - https://github.com/python/cpython/issues/111693 + # - https://github.com/python/cpython/issues/112202 + class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): + async def test_wait(self): + cond = _async_create_condition(_async_create_lock()) + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait(): + result.append(1) + return True + async def c2(result): + await cond.acquire() + if await cond.wait(): + result.append(2) + return True -# Tests adapted from: https://github.com/python/cpython/blob/v3.13.0rc2/Lib/test/test_asyncio/test_locks.py -# Includes tests for: -# - https://github.com/python/cpython/issues/111693 -# - https://github.com/python/cpython/issues/112202 -class TestConditionStdlib(unittest.IsolatedAsyncioTestCase): - async def test_wait(self): - cond = _ACondition(threading.Condition(threading.Lock())) - result = [] + async def c3(result): + await cond.acquire() + if await cond.wait(): + result.append(3) + return True - async def c1(result): - await cond.acquire() - if await cond.wait(): - result.append(1) - return True + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) - async def c2(result): - await cond.acquire() - if await cond.wait(): - result.append(2) - return True + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertFalse(cond.locked()) - async def c3(result): - await cond.acquire() - if await cond.wait(): - result.append(3) - return True - - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) - t3 = asyncio.create_task(c3(result)) - - await asyncio.sleep(0) - self.assertEqual([], result) - self.assertFalse(cond.locked()) - - self.assertTrue(await cond.acquire()) - cond.notify() - await asyncio.sleep(0) - self.assertEqual([], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1], result) - self.assertTrue(cond.locked()) - - cond.notify(2) - await asyncio.sleep(0) - self.assertEqual([1], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1, 2], result) - self.assertTrue(cond.locked()) - - cond.release() - await asyncio.sleep(0) - self.assertEqual([1, 2, 3], result) - self.assertTrue(cond.locked()) - - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - self.assertTrue(t3.done()) - self.assertTrue(t3.result()) - - async def test_wait_cancel(self): - cond = _ACondition(threading.Condition(threading.Lock())) - await cond.acquire() - - wait = asyncio.create_task(cond.wait()) - asyncio.get_running_loop().call_soon(wait.cancel) - with self.assertRaises(asyncio.CancelledError): - await wait - self.assertFalse(cond._waiters) - self.assertTrue(cond.locked()) - - async def test_wait_cancel_contested(self): - cond = _ACondition(threading.Condition(threading.Lock())) - - await cond.acquire() - self.assertTrue(cond.locked()) - - wait_task = asyncio.create_task(cond.wait()) - await asyncio.sleep(0) - self.assertFalse(cond.locked()) - - # Notify, but contest the lock before cancelling - await cond.acquire() - self.assertTrue(cond.locked()) - cond.notify() - asyncio.get_running_loop().call_soon(wait_task.cancel) - asyncio.get_running_loop().call_soon(cond.release) - - try: - await wait_task - except asyncio.CancelledError: - # Should not happen, since no cancellation points - pass - - self.assertTrue(cond.locked()) - - async def test_wait_cancel_after_notify(self): - # See bpo-32841 - waited = False - - cond = _ACondition(threading.Condition(threading.Lock())) - - async def wait_on_cond(): - nonlocal waited - async with cond: - waited = True # Make sure this area was reached - await cond.wait() + self.assertTrue(await cond.acquire()) + cond.notify() + await asyncio.sleep(0) + self.assertEqual([], result) + self.assertTrue(cond.locked()) - waiter = asyncio.create_task(wait_on_cond()) - await asyncio.sleep(0) # Start waiting + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) + + cond.notify(2) + await asyncio.sleep(0) + self.assertEqual([1], result) + self.assertTrue(cond.locked()) - await cond.acquire() - cond.notify() - await asyncio.sleep(0) # Get to acquire() - waiter.cancel() - await asyncio.sleep(0) # Activate cancellation - cond.release() - await asyncio.sleep(0) # Cancellation should occur + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2], result) + self.assertTrue(cond.locked()) - self.assertTrue(waiter.cancelled()) - self.assertTrue(waited) + cond.release() + await asyncio.sleep(0) + self.assertEqual([1, 2, 3], result) + self.assertTrue(cond.locked()) - async def test_wait_unacquired(self): - cond = _ACondition(threading.Condition(threading.Lock())) - with self.assertRaises(RuntimeError): - await cond.wait() + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) - async def test_wait_for(self): - cond = _ACondition(threading.Condition(threading.Lock())) - presult = False + async def test_wait_cancel(self): + cond = _async_create_condition(_async_create_lock()) + await cond.acquire() - def predicate(): - return presult + wait = asyncio.create_task(cond.wait()) + asyncio.get_running_loop().call_soon(wait.cancel) + with self.assertRaises(asyncio.CancelledError): + await wait + self.assertFalse(cond._waiters) + self.assertTrue(cond.locked()) - result = [] + async def test_wait_cancel_contested(self): + cond = _async_create_condition(_async_create_lock()) - async def c1(result): await cond.acquire() - if await cond.wait_for(predicate): - result.append(1) - cond.release() - return True + self.assertTrue(cond.locked()) - t = asyncio.create_task(c1(result)) + wait_task = asyncio.create_task(cond.wait()) + await asyncio.sleep(0) + self.assertFalse(cond.locked()) - await asyncio.sleep(0) - self.assertEqual([], result) + # Notify, but contest the lock before cancelling + await cond.acquire() + self.assertTrue(cond.locked()) + cond.notify() + asyncio.get_running_loop().call_soon(wait_task.cancel) + asyncio.get_running_loop().call_soon(cond.release) - await cond.acquire() - cond.notify() - cond.release() - await asyncio.sleep(0) - self.assertEqual([], result) + try: + await wait_task + except asyncio.CancelledError: + # Should not happen, since no cancellation points + pass - presult = True - await cond.acquire() - cond.notify() - cond.release() - await asyncio.sleep(0) - self.assertEqual([1], result) + self.assertTrue(cond.locked()) - self.assertTrue(t.done()) - self.assertTrue(t.result()) + async def test_wait_cancel_after_notify(self): + # See bpo-32841 + waited = False - async def test_wait_for_unacquired(self): - cond = _ACondition(threading.Condition(threading.Lock())) + cond = _async_create_condition(_async_create_lock()) - # predicate can return true immediately - res = await cond.wait_for(lambda: [1, 2, 3]) - self.assertEqual([1, 2, 3], res) + async def wait_on_cond(): + nonlocal waited + async with cond: + waited = True # Make sure this area was reached + await cond.wait() - with self.assertRaises(RuntimeError): - await cond.wait_for(lambda: False) + waiter = asyncio.create_task(wait_on_cond()) + await asyncio.sleep(0) # Start waiting - async def test_notify(self): - cond = _ACondition(threading.Condition(threading.Lock())) - result = [] + await cond.acquire() + cond.notify() + await asyncio.sleep(0) # Get to acquire() + waiter.cancel() + await asyncio.sleep(0) # Activate cancellation + cond.release() + await asyncio.sleep(0) # Cancellation should occur + + self.assertTrue(waiter.cancelled()) + self.assertTrue(waited) + + async def test_wait_unacquired(self): + cond = _async_create_condition(_async_create_lock()) + with self.assertRaises(RuntimeError): + await cond.wait() - async def c1(result): - async with cond: - if await cond.wait(): - result.append(1) - return True + async def test_wait_for(self): + cond = _async_create_condition(_async_create_lock()) + presult = False - async def c2(result): - async with cond: - if await cond.wait(): - result.append(2) - return True + def predicate(): + return presult - async def c3(result): - async with cond: - if await cond.wait(): - result.append(3) + result = [] + + async def c1(result): + await cond.acquire() + if await cond.wait_for(predicate): + result.append(1) + cond.release() return True - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) - t3 = asyncio.create_task(c3(result)) + t = asyncio.create_task(c1(result)) - await asyncio.sleep(0) - self.assertEqual([], result) + await asyncio.sleep(0) + self.assertEqual([], result) - async with cond: - cond.notify(1) - await asyncio.sleep(1) - self.assertEqual([1], result) + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([], result) - async with cond: - cond.notify(1) - cond.notify(2048) - await asyncio.sleep(1) - self.assertEqual([1, 2, 3], result) + presult = True + await cond.acquire() + cond.notify() + cond.release() + await asyncio.sleep(0) + self.assertEqual([1], result) - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) - self.assertTrue(t3.done()) - self.assertTrue(t3.result()) + self.assertTrue(t.done()) + self.assertTrue(t.result()) - async def test_notify_all(self): - cond = _ACondition(threading.Condition(threading.Lock())) + async def test_wait_for_unacquired(self): + cond = _async_create_condition(_async_create_lock()) - result = [] + # predicate can return true immediately + res = await cond.wait_for(lambda: [1, 2, 3]) + self.assertEqual([1, 2, 3], res) - async def c1(result): - async with cond: - if await cond.wait(): - result.append(1) - return True + with self.assertRaises(RuntimeError): + await cond.wait_for(lambda: False) - async def c2(result): - async with cond: - if await cond.wait(): - result.append(2) - return True + async def test_notify(self): + cond = _async_create_condition(_async_create_lock()) + result = [] - t1 = asyncio.create_task(c1(result)) - t2 = asyncio.create_task(c2(result)) + async def c1(result): + async with cond: + if await cond.wait(): + result.append(1) + return True - await asyncio.sleep(0) - self.assertEqual([], result) + async def c2(result): + async with cond: + if await cond.wait(): + result.append(2) + return True - async with cond: - cond.notify_all() - await asyncio.sleep(1) - self.assertEqual([1, 2], result) + async def c3(result): + async with cond: + if await cond.wait(): + result.append(3) + return True - self.assertTrue(t1.done()) - self.assertTrue(t1.result()) - self.assertTrue(t2.done()) - self.assertTrue(t2.result()) + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) + t3 = asyncio.create_task(c3(result)) - async def test_context_manager(self): - cond = _ACondition(threading.Condition(threading.Lock())) - self.assertFalse(cond.locked()) - async with cond: - self.assertTrue(cond.locked()) - self.assertFalse(cond.locked()) - - async def test_timeout_in_block(self): - condition = _ACondition(threading.Condition(threading.Lock())) - async with condition: - with self.assertRaises(asyncio.TimeoutError): - await asyncio.wait_for(condition.wait(), timeout=0.5) - - @unittest.skipIf( - sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" - ) - async def test_cancelled_error_wakeup(self): - # Test that a cancelled error, received when awaiting wakeup, - # will be re-raised un-modified. - wake = False - raised = None - cond = _ACondition(threading.Condition(threading.Lock())) - - async def func(): - nonlocal raised - async with cond: - with self.assertRaises(asyncio.CancelledError) as err: - await cond.wait_for(lambda: wake) - raised = err.exception - raise raised - - task = asyncio.create_task(func()) - await asyncio.sleep(0) - # Task is waiting on the condition, cancel it there. - task.cancel(msg="foo") # type: ignore[call-arg] - with self.assertRaises(asyncio.CancelledError) as err: - await task - self.assertEqual(err.exception.args, ("foo",)) - # We should have got the _same_ exception instance as the one - # originally raised. - self.assertIs(err.exception, raised) - - @unittest.skipIf( - sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" - ) - async def test_cancelled_error_re_aquire(self): - # Test that a cancelled error, received when re-aquiring lock, - # will be re-raised un-modified. - wake = False - raised = None - cond = _ACondition(threading.Condition(threading.Lock())) - - async def func(): - nonlocal raised - async with cond: - with self.assertRaises(asyncio.CancelledError) as err: - await cond.wait_for(lambda: wake) - raised = err.exception - raise raised - - task = asyncio.create_task(func()) - await asyncio.sleep(0) - # Task is waiting on the condition - await cond.acquire() - wake = True - cond.notify() - await asyncio.sleep(0) - # Task is now trying to re-acquire the lock, cancel it there. - task.cancel(msg="foo") # type: ignore[call-arg] - cond.release() - with self.assertRaises(asyncio.CancelledError) as err: - await task - self.assertEqual(err.exception.args, ("foo",)) - # We should have got the _same_ exception instance as the one - # originally raised. - self.assertIs(err.exception, raised) - - @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") - async def test_cancelled_wakeup(self): - # Test that a task cancelled at the "same" time as it is woken - # up as part of a Condition.notify() does not result in a lost wakeup. - # This test simulates a cancel while the target task is awaiting initial - # wakeup on the wakeup queue. - condition = _ACondition(threading.Condition(threading.Lock())) - state = 0 - - async def consumer(): - nonlocal state - async with condition: - while True: - await condition.wait_for(lambda: state != 0) - if state < 0: - return - state -= 1 - - # create two consumers - c = [asyncio.create_task(consumer()) for _ in range(2)] - # wait for them to settle - await asyncio.sleep(0.1) - async with condition: - # produce one item and wake up one - state += 1 - condition.notify(1) - - # Cancel it while it is awaiting to be run. - # This cancellation could come from the outside - c[0].cancel() - - # now wait for the item to be consumed - # if it doesn't means that our "notify" didn"t take hold. - # because it raced with a cancel() - try: - async with asyncio.timeout(1): - await condition.wait_for(lambda: state == 0) - except TimeoutError: - pass - self.assertEqual(state, 0) - - # clean up - state = -1 - condition.notify_all() - await c[1] - - @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") - async def test_cancelled_wakeup_relock(self): - # Test that a task cancelled at the "same" time as it is woken - # up as part of a Condition.notify() does not result in a lost wakeup. - # This test simulates a cancel while the target task is acquiring the lock - # again. - condition = _ACondition(threading.Condition(threading.Lock())) - state = 0 - - async def consumer(): - nonlocal state - async with condition: - while True: - await condition.wait_for(lambda: state != 0) - if state < 0: - return - state -= 1 - - # create two consumers - c = [asyncio.create_task(consumer()) for _ in range(2)] - # wait for them to settle - await asyncio.sleep(0.1) - async with condition: - # produce one item and wake up one - state += 1 - condition.notify(1) - - # now we sleep for a bit. This allows the target task to wake up and - # settle on re-aquiring the lock await asyncio.sleep(0) + self.assertEqual([], result) - # Cancel it while awaiting the lock - # This cancel could come the outside. - c[0].cancel() + async with cond: + cond.notify(1) + await asyncio.sleep(1) + self.assertEqual([1], result) - # now wait for the item to be consumed - # if it doesn't means that our "notify" didn"t take hold. - # because it raced with a cancel() - try: - async with asyncio.timeout(1): - await condition.wait_for(lambda: state == 0) - except TimeoutError: - pass - self.assertEqual(state, 0) + async with cond: + cond.notify(1) + cond.notify(2048) + await asyncio.sleep(1) + self.assertEqual([1, 2, 3], result) - # clean up - state = -1 - condition.notify_all() - await c[1] + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + self.assertTrue(t3.done()) + self.assertTrue(t3.result()) + async def test_notify_all(self): + cond = _async_create_condition(_async_create_lock()) -class TestCondition(unittest.IsolatedAsyncioTestCase): - async def test_multiple_loops_notify(self): - cond = _ACondition(threading.Condition(threading.Lock())) + result = [] - def tmain(cond): - async def atmain(cond): - await asyncio.sleep(1) + async def c1(result): async with cond: - cond.notify(1) - - asyncio.run(atmain(cond)) - - t = threading.Thread(target=tmain, args=(cond,)) - t.start() + if await cond.wait(): + result.append(1) + return True - async with cond: - self.assertTrue(await cond.wait(30)) - t.join() - - async def test_multiple_loops_notify_all(self): - cond = _ACondition(threading.Condition(threading.Lock())) - results = [] - - def tmain(cond, results): - async def atmain(cond, results): - await asyncio.sleep(1) + async def c2(result): async with cond: - res = await cond.wait(30) - results.append(res) - - asyncio.run(atmain(cond, results)) + if await cond.wait(): + result.append(2) + return True - nthreads = 5 - threads = [] - for _ in range(nthreads): - threads.append(threading.Thread(target=tmain, args=(cond, results))) - for t in threads: - t.start() + t1 = asyncio.create_task(c1(result)) + t2 = asyncio.create_task(c2(result)) - await asyncio.sleep(2) - async with cond: - cond.notify_all() + await asyncio.sleep(0) + self.assertEqual([], result) - for t in threads: - t.join() + async with cond: + cond.notify_all() + await asyncio.sleep(1) + self.assertEqual([1, 2], result) + + self.assertTrue(t1.done()) + self.assertTrue(t1.result()) + self.assertTrue(t2.done()) + self.assertTrue(t2.result()) + + async def test_context_manager(self): + cond = _async_create_condition(_async_create_lock()) + self.assertFalse(cond.locked()) + async with cond: + self.assertTrue(cond.locked()) + self.assertFalse(cond.locked()) - self.assertEqual(results, [True] * nthreads) + async def test_timeout_in_block(self): + condition = _async_create_condition(_async_create_lock()) + async with condition: + with self.assertRaises(asyncio.TimeoutError): + await asyncio.wait_for(condition.wait(), timeout=0.5) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_wakeup(self): + # Test that a cancelled error, received when awaiting wakeup, + # will be re-raised un-modified. + wake = False + raised = None + cond = _async_create_condition(_async_create_lock()) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf( + sys.version_info < (3, 11), "raising the same cancelled error requires Python>=3.11" + ) + async def test_cancelled_error_re_aquire(self): + # Test that a cancelled error, received when re-aquiring lock, + # will be re-raised un-modified. + wake = False + raised = None + cond = _async_create_condition(_async_create_lock()) + + async def func(): + nonlocal raised + async with cond: + with self.assertRaises(asyncio.CancelledError) as err: + await cond.wait_for(lambda: wake) + raised = err.exception + raise raised -if __name__ == "__main__": - unittest.main() + task = asyncio.create_task(func()) + await asyncio.sleep(0) + # Task is waiting on the condition + await cond.acquire() + wake = True + cond.notify() + await asyncio.sleep(0) + # Task is now trying to re-acquire the lock, cancel it there. + task.cancel(msg="foo") # type: ignore[call-arg] + cond.release() + with self.assertRaises(asyncio.CancelledError) as err: + await task + self.assertEqual(err.exception.args, ("foo",)) + # We should have got the _same_ exception instance as the one + # originally raised. + self.assertIs(err.exception, raised) + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is awaiting initial + # wakeup on the wakeup queue. + condition = _async_create_condition(_async_create_lock()) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0.1) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # Cancel it while it is awaiting to be run. + # This cancellation could come from the outside + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + @unittest.skipIf(sys.version_info < (3, 11), "asyncio.timeout requires Python>=3.11") + async def test_cancelled_wakeup_relock(self): + # Test that a task cancelled at the "same" time as it is woken + # up as part of a Condition.notify() does not result in a lost wakeup. + # This test simulates a cancel while the target task is acquiring the lock + # again. + condition = _async_create_condition(_async_create_lock()) + state = 0 + + async def consumer(): + nonlocal state + async with condition: + while True: + await condition.wait_for(lambda: state != 0) + if state < 0: + return + state -= 1 + + # create two consumers + c = [asyncio.create_task(consumer()) for _ in range(2)] + # wait for them to settle + await asyncio.sleep(0.1) + async with condition: + # produce one item and wake up one + state += 1 + condition.notify(1) + + # now we sleep for a bit. This allows the target task to wake up and + # settle on re-aquiring the lock + await asyncio.sleep(0) + + # Cancel it while awaiting the lock + # This cancel could come the outside. + c[0].cancel() + + # now wait for the item to be consumed + # if it doesn't means that our "notify" didn"t take hold. + # because it raced with a cancel() + try: + async with asyncio.timeout(1): + await condition.wait_for(lambda: state == 0) + except TimeoutError: + pass + self.assertEqual(state, 0) + + # clean up + state = -1 + condition.notify_all() + await c[1] + + if __name__ == "__main__": + unittest.main() diff --git a/tools/synchro.py b/tools/synchro.py index f460b348c4..17841d3025 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -114,6 +114,9 @@ "async_wait_for_event": "wait_for_event", "pymongo_server_monitor_task": "pymongo_server_monitor_thread", "pymongo_server_rtt_task": "pymongo_server_rtt_thread", + "_async_create_lock": "_create_lock", + "_async_create_condition": "_create_condition", + "_async_cond_wait": "_cond_wait", } docstring_replacements: dict[tuple[str, str], str] = { @@ -134,8 +137,6 @@ ".. warning:: This API is currently in beta, meaning the classes, methods, and behaviors described within may change before the full release." } -type_replacements = {"_Condition": "threading.Condition"} - import_replacements = {"test.synchronous": "test"} _pymongo_base = "./pymongo/asynchronous/" @@ -236,8 +237,6 @@ def process_files(files: list[str]) -> None: lines = translate_async_sleeps(lines) if file in docstring_translate_files: lines = translate_docstrings(lines) - translate_locks(lines) - translate_types(lines) if file in sync_test_files: translate_imports(lines) f.seek(0) @@ -271,34 +270,6 @@ def translate_coroutine_types(lines: list[str]) -> list[str]: return lines -def translate_locks(lines: list[str]) -> list[str]: - lock_lines = [line for line in lines if "_Lock(" in line] - cond_lines = [line for line in lines if "_Condition(" in line] - for line in lock_lines: - res = re.search(r"_Lock\(([^()]*\([^()]*\))\)", line) - if res: - old = res[0] - index = lines.index(line) - lines[index] = line.replace(old, res[1]) - for line in cond_lines: - res = re.search(r"_Condition\(([^()]*\([^()]*\))\)", line) - if res: - old = res[0] - index = lines.index(line) - lines[index] = line.replace(old, res[1]) - - return lines - - -def translate_types(lines: list[str]) -> list[str]: - for k, v in type_replacements.items(): - matches = [line for line in lines if k in line and "import" not in line] - for line in matches: - index = lines.index(line) - lines[index] = line.replace(k, v) - return lines - - def translate_imports(lines: list[str]) -> list[str]: for k, v in import_replacements.items(): matches = [line for line in lines if k in line and "import" in line]