Skip to content

Commit 2711f5a

Browse files
committed
...
1 parent ec68cb2 commit 2711f5a

File tree

2 files changed

+97
-26
lines changed

2 files changed

+97
-26
lines changed

_misc/_l.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import asyncio
2+
import random
3+
4+
from async_utils._simple_lock import AsyncLock # noqa: PLC2701
5+
from async_utils.bg_loop import threaded_loop
6+
7+
8+
async def check(lock: AsyncLock):
9+
async with lock:
10+
v = random.random()
11+
print(await asyncio.sleep(v, v), flush=True) # noqa: T201
12+
13+
14+
async def amain():
15+
lock = AsyncLock()
16+
with threaded_loop() as tl1, threaded_loop() as tl2:
17+
tsks = {loop.run(check(lock)) for loop in (tl1, tl2) for _ in range(10)}
18+
await asyncio.gather(*tsks)
19+
20+
21+
if __name__ == "__main__":
22+
asyncio.run(amain())

src/async_utils/_simple_lock.py

Lines changed: 75 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,38 @@
1616
import concurrent.futures as cf
1717
import threading
1818
from collections import deque
19+
from collections.abc import Generator
1920

2021
from . import _typings as t
2122

2223
# TODO: pick what public namespace to re-export this from.
24+
# TODO: write a test for the odd behavior noticed with prior version in discord
25+
26+
27+
class _Waiter:
28+
__slots__ = ("future",)
29+
30+
def __init__(self, future: cf.Future[None], /) -> None:
31+
self.future: cf.Future[None] = future
32+
33+
def cancelled(self) -> bool:
34+
return self.future.cancelled()
35+
36+
def done(self) -> bool:
37+
return self.future.done()
38+
39+
def set_result(self, val: None) -> None:
40+
self.future.set_result(val)
41+
42+
def __await__(self) -> Generator[t.Any, t.Any, None]:
43+
f = asyncio.wrap_future(self.future)
44+
return (yield from f.__await__())
45+
46+
__final__ = True
47+
48+
def __init_subclass__(cls) -> t.Never:
49+
msg = "Don't subclass this"
50+
raise RuntimeError(msg)
2351

2452

2553
class AsyncLock:
@@ -32,44 +60,65 @@ def __init_subclass__(cls) -> t.Never:
3260
__final__ = True
3361

3462
def __init__(self) -> None:
35-
self._waiters: deque[cf.Future[None]] = deque()
63+
self._waiters: deque[_Waiter] | None = None
64+
self._lockv: bool = False
3665
self._internal_lock: threading.RLock = threading.RLock()
37-
self._locked: bool = False
3866

39-
async def __aenter__(self, /) -> None:
67+
def __locked(self) -> bool:
4068
with self._internal_lock:
41-
if not self._locked and (all(w.cancelled() for w in self._waiters)):
42-
self._locked = True
43-
return
69+
return self._lockv or (any(not w.cancelled() for w in (self._waiters or ())))
70+
71+
async def __aenter__(self) -> None:
72+
await self.__acquire()
73+
74+
async def __aexit__(self, *dont_care: object) -> t.Literal[False]:
75+
self.__release()
76+
return False
77+
78+
async def __acquire(self) -> bool:
79+
with self._internal_lock:
80+
if not self.__locked():
81+
self._lockv = True
82+
return True
83+
84+
with self._internal_lock:
85+
if self._waiters is None:
86+
self._waiters = deque()
4487

4588
fut: cf.Future[None] = cf.Future()
4689

90+
waiter = _Waiter(fut)
91+
4792
with self._internal_lock:
48-
self._waiters.append(fut)
93+
self._waiters.append(waiter)
4994

5095
try:
51-
await asyncio.wrap_future(fut)
52-
except (asyncio.CancelledError, cf.CancelledError):
53-
with self._internal_lock:
54-
if self._locked:
55-
self._maybe_wake()
96+
await waiter
97+
except asyncio.CancelledError:
98+
if fut.done() and not fut.cancelled():
99+
self._lockv = False
100+
raise
101+
56102
finally:
57-
self._waiters.remove(fut)
103+
self._maybe_wake()
104+
return True
58105

59-
async def __aexit__(self, *_dont_care: object) -> t.Literal[False]:
106+
def _maybe_wake(self) -> None:
60107
with self._internal_lock:
61-
if self._locked:
62-
self._locked = False
63-
self._maybe_wake()
108+
while (not self._lockv) and self._waiters:
109+
next_waiter = self._waiters.popleft()
64110

65-
return False
111+
if not (next_waiter.done() or next_waiter.cancelled()):
112+
self._lockv = True
113+
next_waiter.set_result(None)
66114

67-
def _maybe_wake(self) -> None:
115+
while self._waiters:
116+
next_waiter = self._waiters.popleft()
117+
if not (next_waiter.done() or next_waiter.cancelled()):
118+
self._waiters.appendleft(next_waiter)
119+
break
120+
121+
def __release(self) -> None:
68122
with self._internal_lock:
69-
if self._waiters:
70-
try:
71-
fut = next(iter(self._waiters))
72-
except StopIteration:
73-
return
74-
if not fut.done():
75-
fut.set_result(None)
123+
self._lockv = False
124+
self._maybe_wake()

0 commit comments

Comments
 (0)