Skip to content

Commit 51ffa4a

Browse files
committed
Move vendored lock and condition classes into separate file
1 parent 67d1a53 commit 51ffa4a

File tree

6 files changed

+792
-304
lines changed

6 files changed

+792
-304
lines changed

THIRD-PARTY-NOTICES

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
7373
ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
7474

7575

76-
3) License Notice for lock.py
76+
3) License Notice for async_lock.py
7777
-----------------------------------------
7878

7979
1. This LICENSE AGREEMENT is between the Python Software Foundation

pymongo/_asyncio_lock.py

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
# Copyright (c) 2001-2024 Python Software Foundation; All Rights Reserved
2+
3+
"""Lock and Condition classes vendored from https://github.com/python/cpython/blob/main/Lib/asyncio/locks.py
4+
to port 3.13 fixes to older versions of Python.
5+
Can be removed once we drop Python 3.12 support."""
6+
7+
from __future__ import annotations
8+
9+
import collections
10+
import threading
11+
from asyncio import events, exceptions
12+
from typing import Any, Coroutine, Optional
13+
14+
_global_lock = threading.Lock()
15+
16+
17+
class _LoopBoundMixin:
18+
_loop = None
19+
20+
def _get_loop(self) -> Any:
21+
loop = events._get_running_loop()
22+
23+
if self._loop is None:
24+
with _global_lock:
25+
if self._loop is None:
26+
self._loop = loop
27+
if loop is not self._loop:
28+
raise RuntimeError(f"{self!r} is bound to a different event loop")
29+
return loop
30+
31+
32+
class _ContextManagerMixin:
33+
async def __aenter__(self) -> None:
34+
await self.acquire() # type: ignore[attr-defined]
35+
# We have no use for the "as ..." clause in the with
36+
# statement for locks.
37+
return
38+
39+
async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
40+
self.release() # type: ignore[attr-defined]
41+
42+
43+
class Lock(_ContextManagerMixin, _LoopBoundMixin):
44+
"""Primitive lock objects.
45+
46+
A primitive lock is a synchronization primitive that is not owned
47+
by a particular task when locked. A primitive lock is in one
48+
of two states, 'locked' or 'unlocked'.
49+
50+
It is created in the unlocked state. It has two basic methods,
51+
acquire() and release(). When the state is unlocked, acquire()
52+
changes the state to locked and returns immediately. When the
53+
state is locked, acquire() blocks until a call to release() in
54+
another task changes it to unlocked, then the acquire() call
55+
resets it to locked and returns. The release() method should only
56+
be called in the locked state; it changes the state to unlocked
57+
and returns immediately. If an attempt is made to release an
58+
unlocked lock, a RuntimeError will be raised.
59+
60+
When more than one task is blocked in acquire() waiting for
61+
the state to turn to unlocked, only one task proceeds when a
62+
release() call resets the state to unlocked; successive release()
63+
calls will unblock tasks in FIFO order.
64+
65+
Locks also support the asynchronous context management protocol.
66+
'async with lock' statement should be used.
67+
68+
Usage:
69+
70+
lock = Lock()
71+
...
72+
await lock.acquire()
73+
try:
74+
...
75+
finally:
76+
lock.release()
77+
78+
Context manager usage:
79+
80+
lock = Lock()
81+
...
82+
async with lock:
83+
...
84+
85+
Lock objects can be tested for locking state:
86+
87+
if not lock.locked():
88+
await lock.acquire()
89+
else:
90+
# lock is acquired
91+
...
92+
93+
"""
94+
95+
def __init__(self) -> None:
96+
self._waiters: Optional[collections.deque] = None
97+
self._locked = False
98+
99+
def __repr__(self) -> str:
100+
res = super().__repr__()
101+
extra = "locked" if self._locked else "unlocked"
102+
if self._waiters:
103+
extra = f"{extra}, waiters:{len(self._waiters)}"
104+
return f"<{res[1:-1]} [{extra}]>"
105+
106+
def locked(self) -> bool:
107+
"""Return True if lock is acquired."""
108+
return self._locked
109+
110+
async def acquire(self) -> bool:
111+
"""Acquire a lock.
112+
113+
This method blocks until the lock is unlocked, then sets it to
114+
locked and returns True.
115+
"""
116+
# Implement fair scheduling, where thread always waits
117+
# its turn. Jumping the queue if all are cancelled is an optimization.
118+
if not self._locked and (
119+
self._waiters is None or all(w.cancelled() for w in self._waiters)
120+
):
121+
self._locked = True
122+
return True
123+
124+
if self._waiters is None:
125+
self._waiters = collections.deque()
126+
fut = self._get_loop().create_future()
127+
self._waiters.append(fut)
128+
129+
try:
130+
try:
131+
await fut
132+
finally:
133+
self._waiters.remove(fut)
134+
except exceptions.CancelledError:
135+
# Currently the only exception designed be able to occur here.
136+
137+
# Ensure the lock invariant: If lock is not claimed (or about
138+
# to be claimed by us) and there is a Task in waiters,
139+
# ensure that the Task at the head will run.
140+
if not self._locked:
141+
self._wake_up_first()
142+
raise
143+
144+
# assert self._locked is False
145+
self._locked = True
146+
return True
147+
148+
def release(self) -> None:
149+
"""Release a lock.
150+
151+
When the lock is locked, reset it to unlocked, and return.
152+
If any other tasks are blocked waiting for the lock to become
153+
unlocked, allow exactly one of them to proceed.
154+
155+
When invoked on an unlocked lock, a RuntimeError is raised.
156+
157+
There is no return value.
158+
"""
159+
if self._locked:
160+
self._locked = False
161+
self._wake_up_first()
162+
else:
163+
raise RuntimeError("Lock is not acquired.")
164+
165+
def _wake_up_first(self) -> None:
166+
"""Ensure that the first waiter will wake up."""
167+
if not self._waiters:
168+
return
169+
try:
170+
fut = next(iter(self._waiters))
171+
except StopIteration:
172+
return
173+
174+
# .done() means that the waiter is already set to wake up.
175+
if not fut.done():
176+
fut.set_result(True)
177+
178+
179+
class Condition(_ContextManagerMixin, _LoopBoundMixin):
180+
"""Asynchronous equivalent to threading.Condition.
181+
182+
This class implements condition variable objects. A condition variable
183+
allows one or more tasks to wait until they are notified by another
184+
task.
185+
186+
A new Lock object is created and used as the underlying lock.
187+
"""
188+
189+
def __init__(self, lock: Optional[Lock] = None) -> None:
190+
if lock is None:
191+
lock = Lock()
192+
193+
self._lock = lock
194+
# Export the lock's locked(), acquire() and release() methods.
195+
self.locked = lock.locked
196+
self.acquire = lock.acquire
197+
self.release = lock.release
198+
199+
self._waiters: collections.deque = collections.deque()
200+
201+
def __repr__(self) -> str:
202+
res = super().__repr__()
203+
extra = "locked" if self.locked() else "unlocked"
204+
if self._waiters:
205+
extra = f"{extra}, waiters:{len(self._waiters)}"
206+
return f"<{res[1:-1]} [{extra}]>"
207+
208+
async def wait(self) -> bool:
209+
"""Wait until notified.
210+
211+
If the calling task has not acquired the lock when this
212+
method is called, a RuntimeError is raised.
213+
214+
This method releases the underlying lock, and then blocks
215+
until it is awakened by a notify() or notify_all() call for
216+
the same condition variable in another task. Once
217+
awakened, it re-acquires the lock and returns True.
218+
219+
This method may return spuriously,
220+
which is why the caller should always
221+
re-check the state and be prepared to wait() again.
222+
"""
223+
if not self.locked():
224+
raise RuntimeError("cannot wait on un-acquired lock")
225+
226+
fut = self._get_loop().create_future()
227+
self.release()
228+
try:
229+
try:
230+
self._waiters.append(fut)
231+
try:
232+
await fut
233+
return True
234+
finally:
235+
self._waiters.remove(fut)
236+
237+
finally:
238+
# Must re-acquire lock even if wait is cancelled.
239+
# We only catch CancelledError here, since we don't want any
240+
# other (fatal) errors with the future to cause us to spin.
241+
err = None
242+
while True:
243+
try:
244+
await self.acquire()
245+
break
246+
except exceptions.CancelledError as e:
247+
err = e
248+
249+
if err is not None:
250+
try:
251+
raise err # Re-raise most recent exception instance.
252+
finally:
253+
err = None # Break reference cycles.
254+
except BaseException:
255+
# Any error raised out of here _may_ have occurred after this Task
256+
# believed to have been successfully notified.
257+
# Make sure to notify another Task instead. This may result
258+
# in a "spurious wakeup", which is allowed as part of the
259+
# Condition Variable protocol.
260+
self._notify(1)
261+
raise
262+
263+
async def wait_for(self, predicate: Any) -> Coroutine:
264+
"""Wait until a predicate becomes true.
265+
266+
The predicate should be a callable whose result will be
267+
interpreted as a boolean value. The method will repeatedly
268+
wait() until it evaluates to true. The final predicate value is
269+
the return value.
270+
"""
271+
result = predicate()
272+
while not result:
273+
await self.wait()
274+
result = predicate()
275+
return result
276+
277+
def notify(self, n: int = 1) -> None:
278+
"""By default, wake up one task waiting on this condition, if any.
279+
If the calling task has not acquired the lock when this method
280+
is called, a RuntimeError is raised.
281+
282+
This method wakes up n of the tasks waiting for the condition
283+
variable; if fewer than n are waiting, they are all awoken.
284+
285+
Note: an awakened task does not actually return from its
286+
wait() call until it can reacquire the lock. Since notify() does
287+
not release the lock, its caller should.
288+
"""
289+
if not self.locked():
290+
raise RuntimeError("cannot notify on un-acquired lock")
291+
self._notify(n)
292+
293+
def _notify(self, n: int) -> None:
294+
idx = 0
295+
for fut in self._waiters:
296+
if idx >= n:
297+
break
298+
299+
if not fut.done():
300+
idx += 1
301+
fut.set_result(False)
302+
303+
def notify_all(self) -> None:
304+
"""Wake up all tasks waiting on this condition. This method acts
305+
like notify(), but wakes up all waiting tasks instead of one. If the
306+
calling task has not acquired the lock when this method is called,
307+
a RuntimeError is raised.
308+
"""
309+
self.notify(len(self._waiters))

pymongo/asynchronous/pool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,6 @@ def __init__(
991991
self.conns: collections.deque = collections.deque()
992992
self.active_contexts: set[_CancellationContext] = set()
993993
self.lock = _async_create_lock()
994-
self.size_cond = _async_create_condition(self.lock)
995994
self._max_connecting_cond = _async_create_condition(self.lock)
996995
self.active_sockets = 0
997996
# Monotonically increasing connection ID required for CMAP Events.
@@ -1018,13 +1017,15 @@ def __init__(
10181017
# The first portion of the wait queue.
10191018
# Enforces: maxPoolSize
10201019
# Also used for: clearing the wait queue
1020+
self.size_cond = _async_create_condition(self.lock)
10211021
self.requests = 0
10221022
self.max_pool_size = self.opts.max_pool_size
10231023
if not self.max_pool_size:
10241024
self.max_pool_size = float("inf")
10251025
# The second portion of the wait queue.
10261026
# Enforces: maxConnecting
10271027
# Also used for: clearing the wait queue
1028+
self._max_connecting_cond = _async_create_condition(self.lock)
10281029
self._max_connecting = self.opts.max_connecting
10291030
self._pending = 0
10301031
self._client_id = client_id

0 commit comments

Comments
 (0)