Skip to content

Commit f1da4f4

Browse files
committed
Start adding N:M concurrency building blocks
1 parent a5eeb67 commit f1da4f4

File tree

1 file changed

+157
-0
lines changed

1 file changed

+157
-0
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Copyright 2020-present Michael Hall
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# This used to include CPython code, some minor performance losses have been
16+
# taken to not tightly include upstream code
17+
"""Building blocks used to create thread-safe multi-event loop async things."""
18+
19+
from __future__ import annotations
20+
21+
import asyncio
22+
import heapq
23+
import threading
24+
from collections.abc import Callable, Generator
25+
26+
from . import _typings as t
27+
28+
29+
def _wake_nm_waiter(fut: asyncio.Future[None]) -> None:
30+
if not fut.done():
31+
fut.set_result(None)
32+
33+
34+
class _NMWaiter:
35+
__slots__ = ("_future", "_loop", "_thread_id", "_val")
36+
37+
def __init__(self, val: int = 1, /) -> None:
38+
self._loop = asyncio.get_running_loop()
39+
self._future: asyncio.Future[None] = self._loop.create_future()
40+
self._thread_id: int = threading.get_ident()
41+
self._val = val
42+
43+
def __lt__(self, other: object) -> bool:
44+
if isinstance(other, _NMWaiter):
45+
return self._val < other._val
46+
return NotImplemented
47+
48+
def __init_subclass__(cls) -> t.Never:
49+
msg = "Don't subclass this"
50+
raise RuntimeError(msg)
51+
52+
__final__ = True
53+
54+
def __await__(self) -> Generator[t.Any, None, None]:
55+
if threading.get_ident() == self._thread_id:
56+
return self._future.__await__()
57+
58+
msg = "Attempted to await future from the wrong thread"
59+
raise RuntimeError(msg)
60+
61+
def wake(self) -> None:
62+
self._loop.call_soon_threadsafe(_wake_nm_waiter, self._future)
63+
64+
def cancelled(self):
65+
return self._future.cancelled()
66+
67+
def done(self):
68+
return self._future.done()
69+
70+
def add_done_callback(self, cb: Callable[[asyncio.Future[None]], None]) -> None:
71+
self._loop.call_soon_threadsafe(self._future.add_done_callback, cb)
72+
73+
74+
class _WrappedRLock:
75+
__slots__ = ("_lock",)
76+
77+
def __init__(self, lock: threading.RLock) -> None:
78+
self._lock = lock
79+
80+
async def __aenter__(self) -> None:
81+
acquired = False
82+
while not acquired:
83+
acquired = self._lock.acquire(blocking=False)
84+
await asyncio.sleep(0)
85+
86+
async def __aexit__(self, *dont_care: object) -> None:
87+
self._lock.release()
88+
89+
90+
class NMSemaphoreBase:
91+
def __init__(self, value: int = 1, /) -> None:
92+
self._waiters: list[_NMWaiter] = []
93+
self._lock = _WrappedRLock(threading.RLock())
94+
self._value: int = value
95+
96+
def _get_priority(self) -> int:
97+
# This and value are the only differences between a semaphore,
98+
# lock, and priority semaphore, and priority lock
99+
# TODO:
100+
# - split this before it is in a public importable namespace.
101+
# - simplify aquire for locks
102+
return 0 - self._value
103+
104+
def __repr__(self) -> str:
105+
res = super().__repr__()
106+
extra = "locked" if self.__locked() else f"unlocked, value:{self._value}"
107+
if self._waiters:
108+
extra = f"{extra}, waiters:{len(self._waiters)}"
109+
return f"<{res[1:-1]} [{extra}]>"
110+
111+
def __locked(self) -> bool:
112+
return self._value == 0 or (any(not w.cancelled() for w in (self._waiters or ())))
113+
114+
async def __aexit__(self, *dont_care: object) -> None:
115+
async with self._lock:
116+
self._value += 1
117+
await self._maybe_wake()
118+
119+
async def __aenter__(self) -> None:
120+
async with self._lock:
121+
if not self.__locked():
122+
self._value -= 1
123+
return
124+
125+
waiter = _NMWaiter(self._get_priority())
126+
heapq.heappush(self._waiters, waiter)
127+
128+
try:
129+
await waiter
130+
# we don't remove ourselves
131+
# we need to maintain the heap invariants
132+
except asyncio.CancelledError:
133+
if waiter.done() and not waiter.cancelled():
134+
self._value += 1
135+
raise
136+
137+
finally:
138+
await self._maybe_wake()
139+
140+
async def _maybe_wake(self) -> None:
141+
async with self._lock:
142+
while self._value > 0 and self._waiters:
143+
next_waiter = heapq.heappop(self._waiters)
144+
145+
if not (next_waiter.done() or next_waiter.cancelled()):
146+
self._value -= 1
147+
next_waiter.wake()
148+
149+
while self._waiters:
150+
# cleanup maintaining heap invariant
151+
# This will only fully empty the heap when
152+
# all things remaining in the heap after waking tasks in
153+
# above section are all done.
154+
next_waiter = heapq.heappop(self._waiters)
155+
if not (next_waiter.done() or next_waiter.cancelled()):
156+
heapq.heappush(self._waiters, next_waiter)
157+
break

0 commit comments

Comments
 (0)