1616import concurrent .futures as cf
1717import threading
1818from collections import deque
19+ from collections .abc import Generator
1920
2021from . import _typings as t
2122
2223# TODO: pick what public namespace to re-export this from.
24+ # TODO: write a test for the odd behavior noticed with prior version in discord
25+
26+
27+ class _Waiter :
28+ __slots__ = ("future" ,)
29+
30+ def __init__ (self , future : cf .Future [None ], / ) -> None :
31+ self .future : cf .Future [None ] = future
32+
33+ def cancelled (self ) -> bool :
34+ return self .future .cancelled ()
35+
36+ def done (self ) -> bool :
37+ return self .future .done ()
38+
39+ def set_result (self , val : None ) -> None :
40+ self .future .set_result (val )
41+
42+ def __await__ (self ) -> Generator [t .Any , t .Any , None ]:
43+ f = asyncio .wrap_future (self .future )
44+ return (yield from f .__await__ ())
45+
46+ __final__ = True
47+
48+ def __init_subclass__ (cls ) -> t .Never :
49+ msg = "Don't subclass this"
50+ raise RuntimeError (msg )
2351
2452
2553class AsyncLock :
@@ -32,44 +60,65 @@ def __init_subclass__(cls) -> t.Never:
3260 __final__ = True
3361
3462 def __init__ (self ) -> None :
35- self ._waiters : deque [cf .Future [None ]] = deque ()
63+ self ._waiters : deque [_Waiter ] | None = None
64+ self ._lockv : bool = False
3665 self ._internal_lock : threading .RLock = threading .RLock ()
37- self ._locked : bool = False
3866
39- async def __aenter__ (self , / ) -> None :
67+ def __locked (self ) -> bool :
4068 with self ._internal_lock :
41- if not self ._locked and (all (w .cancelled () for w in self ._waiters )):
42- self ._locked = True
43- return
69+ return self ._lockv or (any (not w .cancelled () for w in (self ._waiters or ())))
70+
71+ async def __aenter__ (self ) -> None :
72+ await self .__acquire ()
73+
74+ async def __aexit__ (self , * dont_care : object ) -> t .Literal [False ]:
75+ self .__release ()
76+ return False
77+
78+ async def __acquire (self ) -> bool :
79+ with self ._internal_lock :
80+ if not self .__locked ():
81+ self ._lockv = True
82+ return True
83+
84+ with self ._internal_lock :
85+ if self ._waiters is None :
86+ self ._waiters = deque ()
4487
4588 fut : cf .Future [None ] = cf .Future ()
4689
90+ waiter = _Waiter (fut )
91+
4792 with self ._internal_lock :
48- self ._waiters .append (fut )
93+ self ._waiters .append (waiter )
4994
5095 try :
51- await asyncio .wrap_future (fut )
52- except (asyncio .CancelledError , cf .CancelledError ):
53- with self ._internal_lock :
54- if self ._locked :
55- self ._maybe_wake ()
96+ await waiter
97+ except asyncio .CancelledError :
98+ if fut .done () and not fut .cancelled ():
99+ self ._lockv = False
100+ raise
101+
56102 finally :
57- self ._waiters .remove (fut )
103+ self ._maybe_wake ()
104+ return True
58105
59- async def __aexit__ (self , * _dont_care : object ) -> t . Literal [ False ] :
106+ def _maybe_wake (self ) -> None :
60107 with self ._internal_lock :
61- if self ._locked :
62- self ._locked = False
63- self ._maybe_wake ()
108+ while (not self ._lockv ) and self ._waiters :
109+ next_waiter = self ._waiters .popleft ()
64110
65- return False
111+ if not (next_waiter .done () or next_waiter .cancelled ()):
112+ self ._lockv = True
113+ next_waiter .set_result (None )
66114
67- def _maybe_wake (self ) -> None :
115+ while self ._waiters :
116+ next_waiter = self ._waiters .popleft ()
117+ if not (next_waiter .done () or next_waiter .cancelled ()):
118+ self ._waiters .appendleft (next_waiter )
119+ break
120+
121+ def __release (self ) -> None :
68122 with self ._internal_lock :
69- if self ._waiters :
70- try :
71- fut = next (iter (self ._waiters ))
72- except StopIteration :
73- return
74- if not fut .done ():
75- fut .set_result (None )
123+ self ._lockv = False
124+ self ._maybe_wake ()
0 commit comments