Skip to content

Commit da577d0

Browse files
committed
wip lifecycle stuff
1 parent 2f9a530 commit da577d0

File tree

1 file changed

+136
-3
lines changed

1 file changed

+136
-3
lines changed

async_utils/sig_service.py

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,23 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
18+
import logging
1719
import select
1820
import signal
1921
import socket
2022
import sys
21-
from collections.abc import Callable
23+
import threading
24+
from collections.abc import Callable, Coroutine
2225
from types import FrameType
23-
from typing import Any
26+
from typing import Any, Literal
2427

2528
type SignalCallback = Callable[[signal.Signals], Any]
2629
type StartStopCall = Callable[[], Any]
2730
type _HANDLER = Callable[[int, FrameType | None], Any] | int | signal.Handlers | None
2831

32+
log = logging.getLogger(__name__)
33+
2934
__all__ = ["SignalService"]
3035

3136
possible = "SIGINT", "SIGTERM", "SIGBREAK", "SIGHUP"
@@ -41,7 +46,13 @@ def __init__(self, *, startup: list[StartStopCall], signal_cbs: list[SignalCallb
4146
self._cbs: list[SignalCallback] = signal_cbs
4247
self._joins: list[StartStopCall] = joins
4348

44-
def run(self):
49+
def add_async_lifecycle(self, lifecycle: AsyncLifecycle[Any], /) -> None:
50+
st, cb, j = lifecycle.get_service_args()
51+
self._startup.append(st)
52+
self._cbs.append(cb)
53+
self._joins.append(j)
54+
55+
def run(self) -> None:
4556
ss, cs = socket.socketpair()
4657
ss.setblocking(False)
4758
cs.setblocking(False)
@@ -69,3 +80,125 @@ def run(self):
6980

7081
for sig, original in zip(actual, original_handlers):
7182
signal.signal(sig, original)
83+
84+
85+
type CtxSync[Context] = Callable[[Context], Any]
86+
type CtxAsync[Context] = Callable[[Context], Coroutine[Any, None, None]]
87+
88+
89+
class AsyncLifecycle[Context]:
90+
"""Intended to be used with the above."""
91+
92+
def __init__(
93+
self,
94+
context: Context,
95+
loop: asyncio.AbstractEventLoop,
96+
signal_queue: asyncio.Queue[signal.Signals],
97+
sync_setup: CtxSync[Context],
98+
async_main: CtxAsync[Context],
99+
async_cleanup: CtxAsync[Context],
100+
sync_cleanup: CtxSync[Context],
101+
timeout: float = 0.1,
102+
) -> None:
103+
self.context = context
104+
self.loop: asyncio.AbstractEventLoop = loop
105+
self.signal_queue: asyncio.Queue[signal.Signals] = signal_queue
106+
self.sync_setup: CtxSync[Context] = sync_setup
107+
self.async_main: CtxAsync[Context] = async_main
108+
self.async_cleanup: CtxAsync[Context] = async_cleanup
109+
self.sync_cleanup: CtxSync[Context] = sync_cleanup
110+
self.timeout: float = timeout
111+
self.thread: threading.Thread | None | Literal[False] = None
112+
113+
def get_service_args(self) -> tuple[StartStopCall, SignalCallback, StartStopCall]:
114+
def runner() -> None:
115+
loop = self.loop
116+
loop.set_task_factory(asyncio.eager_task_factory)
117+
asyncio.set_event_loop(loop)
118+
119+
self.sync_setup(self.context)
120+
121+
async def sig_h() -> None:
122+
await self.signal_queue.get()
123+
log.info("Recieved shutdown signal, shutting down worker.")
124+
loop.call_soon(self.loop.stop)
125+
126+
async def wrapped_main() -> None:
127+
t1 = asyncio.create_task(self.async_main(self.context))
128+
t2 = asyncio.create_task(sig_h())
129+
await asyncio.gather(t1, t2)
130+
131+
def stop_when_done(fut: asyncio.Future[None]) -> None:
132+
self.loop.stop()
133+
134+
fut = asyncio.ensure_future(wrapped_main(), loop=self.loop)
135+
try:
136+
fut.add_done_callback(stop_when_done)
137+
self.loop.run_forever()
138+
finally:
139+
fut.remove_done_callback(stop_when_done)
140+
141+
self.loop.run_until_complete(self.async_cleanup(self.context))
142+
143+
tasks: set[asyncio.Task[Any]] = {t for t in asyncio.all_tasks(loop) if not t.done()}
144+
145+
async def limited_finalization() -> None:
146+
_done, pending = await asyncio.wait(tasks, timeout=self.timeout)
147+
if not pending:
148+
log.debug("All tasks finished")
149+
return
150+
151+
for task in tasks:
152+
task.cancel()
153+
154+
_done, pending = await asyncio.wait(tasks, timeout=self.timeout)
155+
156+
for task in pending:
157+
name = task.get_name()
158+
coro = task.get_coro()
159+
log.warning("Task %s wrapping coro %r did not exit properly", name, coro)
160+
161+
if tasks:
162+
loop.run_until_complete(limited_finalization())
163+
loop.run_until_complete(loop.shutdown_asyncgens())
164+
loop.run_until_complete(loop.shutdown_default_executor())
165+
166+
for task in tasks:
167+
try:
168+
if (exc := task.exception()) is not None:
169+
loop.call_exception_handler(
170+
{
171+
"message": "Unhandled exception in task during shutdown.",
172+
"exception": exc,
173+
"task": task,
174+
}
175+
)
176+
except (asyncio.InvalidStateError, asyncio.CancelledError):
177+
pass
178+
179+
asyncio.set_event_loop(None)
180+
loop.close()
181+
182+
if not fut.cancelled():
183+
fut.result()
184+
185+
self.sync_cleanup(self.context)
186+
187+
def wrapped_run() -> None:
188+
if self.thread is not None:
189+
msg = "This isn't re-entrant"
190+
raise RuntimeError(msg)
191+
self.thread = threading.Thread(target=runner)
192+
self.thread.start()
193+
194+
def join() -> None:
195+
if not self.thread:
196+
self.thread = False
197+
return
198+
self.thread.join()
199+
200+
def sig(signal: signal.Signals) -> None:
201+
self.loop.call_soon(self.signal_queue.put_nowait, signal)
202+
return
203+
204+
return wrapped_run, sig, join

0 commit comments

Comments
 (0)