1414
1515from __future__ import annotations
1616
17+ import asyncio
18+ import logging
1719import select
1820import signal
1921import socket
2022import sys
21- from collections .abc import Callable
23+ import threading
24+ from collections .abc import Callable , Coroutine
2225from types import FrameType
23- from typing import Any
26+ from typing import Any , Literal
2427
2528type SignalCallback = Callable [[signal .Signals ], Any ]
2629type StartStopCall = Callable [[], Any ]
2730type _HANDLER = Callable [[int , FrameType | None ], Any ] | int | signal .Handlers | None
2831
32+ log = logging .getLogger (__name__ )
33+
2934__all__ = ["SignalService" ]
3035
3136possible = "SIGINT" , "SIGTERM" , "SIGBREAK" , "SIGHUP"
@@ -41,7 +46,13 @@ def __init__(self, *, startup: list[StartStopCall], signal_cbs: list[SignalCallb
4146 self ._cbs : list [SignalCallback ] = signal_cbs
4247 self ._joins : list [StartStopCall ] = joins
4348
44- def run (self ):
49+ def add_async_lifecycle (self , lifecycle : AsyncLifecycle [Any ], / ) -> None :
50+ st , cb , j = lifecycle .get_service_args ()
51+ self ._startup .append (st )
52+ self ._cbs .append (cb )
53+ self ._joins .append (j )
54+
55+ def run (self ) -> None :
4556 ss , cs = socket .socketpair ()
4657 ss .setblocking (False )
4758 cs .setblocking (False )
@@ -69,3 +80,125 @@ def run(self):
6980
7081 for sig , original in zip (actual , original_handlers ):
7182 signal .signal (sig , original )
83+
84+
85+ type CtxSync [Context ] = Callable [[Context ], Any ]
86+ type CtxAsync [Context ] = Callable [[Context ], Coroutine [Any , None , None ]]
87+
88+
89+ class AsyncLifecycle [Context ]:
90+ """Intended to be used with the above."""
91+
92+ def __init__ (
93+ self ,
94+ context : Context ,
95+ loop : asyncio .AbstractEventLoop ,
96+ signal_queue : asyncio .Queue [signal .Signals ],
97+ sync_setup : CtxSync [Context ],
98+ async_main : CtxAsync [Context ],
99+ async_cleanup : CtxAsync [Context ],
100+ sync_cleanup : CtxSync [Context ],
101+ timeout : float = 0.1 ,
102+ ) -> None :
103+ self .context = context
104+ self .loop : asyncio .AbstractEventLoop = loop
105+ self .signal_queue : asyncio .Queue [signal .Signals ] = signal_queue
106+ self .sync_setup : CtxSync [Context ] = sync_setup
107+ self .async_main : CtxAsync [Context ] = async_main
108+ self .async_cleanup : CtxAsync [Context ] = async_cleanup
109+ self .sync_cleanup : CtxSync [Context ] = sync_cleanup
110+ self .timeout : float = timeout
111+ self .thread : threading .Thread | None | Literal [False ] = None
112+
113+ def get_service_args (self ) -> tuple [StartStopCall , SignalCallback , StartStopCall ]:
114+ def runner () -> None :
115+ loop = self .loop
116+ loop .set_task_factory (asyncio .eager_task_factory )
117+ asyncio .set_event_loop (loop )
118+
119+ self .sync_setup (self .context )
120+
121+ async def sig_h () -> None :
122+ await self .signal_queue .get ()
123+ log .info ("Recieved shutdown signal, shutting down worker." )
124+ loop .call_soon (self .loop .stop )
125+
126+ async def wrapped_main () -> None :
127+ t1 = asyncio .create_task (self .async_main (self .context ))
128+ t2 = asyncio .create_task (sig_h ())
129+ await asyncio .gather (t1 , t2 )
130+
131+ def stop_when_done (fut : asyncio .Future [None ]) -> None :
132+ self .loop .stop ()
133+
134+ fut = asyncio .ensure_future (wrapped_main (), loop = self .loop )
135+ try :
136+ fut .add_done_callback (stop_when_done )
137+ self .loop .run_forever ()
138+ finally :
139+ fut .remove_done_callback (stop_when_done )
140+
141+ self .loop .run_until_complete (self .async_cleanup (self .context ))
142+
143+ tasks : set [asyncio .Task [Any ]] = {t for t in asyncio .all_tasks (loop ) if not t .done ()}
144+
145+ async def limited_finalization () -> None :
146+ _done , pending = await asyncio .wait (tasks , timeout = self .timeout )
147+ if not pending :
148+ log .debug ("All tasks finished" )
149+ return
150+
151+ for task in tasks :
152+ task .cancel ()
153+
154+ _done , pending = await asyncio .wait (tasks , timeout = self .timeout )
155+
156+ for task in pending :
157+ name = task .get_name ()
158+ coro = task .get_coro ()
159+ log .warning ("Task %s wrapping coro %r did not exit properly" , name , coro )
160+
161+ if tasks :
162+ loop .run_until_complete (limited_finalization ())
163+ loop .run_until_complete (loop .shutdown_asyncgens ())
164+ loop .run_until_complete (loop .shutdown_default_executor ())
165+
166+ for task in tasks :
167+ try :
168+ if (exc := task .exception ()) is not None :
169+ loop .call_exception_handler (
170+ {
171+ "message" : "Unhandled exception in task during shutdown." ,
172+ "exception" : exc ,
173+ "task" : task ,
174+ }
175+ )
176+ except (asyncio .InvalidStateError , asyncio .CancelledError ):
177+ pass
178+
179+ asyncio .set_event_loop (None )
180+ loop .close ()
181+
182+ if not fut .cancelled ():
183+ fut .result ()
184+
185+ self .sync_cleanup (self .context )
186+
187+ def wrapped_run () -> None :
188+ if self .thread is not None :
189+ msg = "This isn't re-entrant"
190+ raise RuntimeError (msg )
191+ self .thread = threading .Thread (target = runner )
192+ self .thread .start ()
193+
194+ def join () -> None :
195+ if not self .thread :
196+ self .thread = False
197+ return
198+ self .thread .join ()
199+
200+ def sig (signal : signal .Signals ) -> None :
201+ self .loop .call_soon (self .signal_queue .put_nowait , signal )
202+ return
203+
204+ return wrapped_run , sig , join
0 commit comments