Skip to content

Commit 0062b6b

Browse files
authored
Handle cleanup of ws tasks on shutdown (#803)
1 parent 11df933 commit 0062b6b

File tree

9 files changed

+89
-19
lines changed

9 files changed

+89
-19
lines changed

src/asgi/callbacks.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ pub(crate) fn call_http(
171171
pub(crate) fn call_ws(
172172
cb: ArcCBScheduler,
173173
rt: RuntimeRef,
174+
disconnect_guard: Arc<Notify>,
174175
server_addr: SockAddr,
175176
client_addr: SockAddr,
176177
scheme: HTTPProto,
@@ -179,7 +180,7 @@ pub(crate) fn call_ws(
179180
upgrade: UpgradeData,
180181
) -> oneshot::Receiver<WebsocketDetachedTransport> {
181182
let (tx, rx) = oneshot::channel();
182-
let protocol = WebsocketProtocol::new(rt.clone(), tx, ws, upgrade);
183+
let protocol = WebsocketProtocol::new(rt.clone(), tx, ws, upgrade, disconnect_guard);
183184

184185
rt.spawn_blocking(move |py| {
185186
if let Ok(scope) = build_scope_ws(py, req, server_addr, client_addr, scheme)

src/asgi/http.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,15 @@ macro_rules! handle_request_with_ws {
8181
let (restx, mut resrx) = mpsc::channel(1);
8282
let (parts, _) = req.into_parts();
8383
let rth = rt.clone();
84+
let cancel_sig = Arc::new(Notify::new());
8485

85-
rt.spawn(async move {
86+
rt.spawn_cancellable(cancel_sig.clone(), async move {
8687
let tx_ref = restx.clone();
8788

8889
match $handler_ws(
8990
callback,
9091
rth,
92+
cancel_sig,
9193
server_addr,
9294
client_addr,
9395
scheme,

src/asgi/io.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ impl WebsocketDetachedTransport {
345345
pub(crate) struct ASGIWebsocketProtocol {
346346
rt: RuntimeRef,
347347
tx: Mutex<Option<oneshot::Sender<WebsocketDetachedTransport>>>,
348+
disconnect_guard: Arc<Notify>,
348349
websocket: Mutex<Option<HyperWebsocket>>,
349350
upgrade: Mutex<Option<UpgradeData>>,
350351
response_intent: Mutex<Option<(u16, HeaderMap)>>,
@@ -362,10 +363,12 @@ impl ASGIWebsocketProtocol {
362363
tx: oneshot::Sender<WebsocketDetachedTransport>,
363364
websocket: HyperWebsocket,
364365
upgrade: UpgradeData,
366+
disconnect_guard: Arc<Notify>,
365367
) -> Self {
366368
Self {
367369
rt,
368370
tx: Mutex::new(Some(tx)),
371+
disconnect_guard,
369372
websocket: Mutex::new(Some(websocket)),
370373
upgrade: Mutex::new(Some(upgrade)),
371374
response_intent: Mutex::new(None),
@@ -526,6 +529,7 @@ impl ASGIWebsocketProtocol {
526529
let accepted_ev = self.init_event.clone();
527530
let closed = self.closed.clone();
528531
let transport = self.ws_rx.clone();
532+
let guard_disconnect = self.disconnect_guard.clone();
529533

530534
future_into_py_futlike(self.rt.clone(), py, async move {
531535
if !accepted.load(atomic::Ordering::Acquire) {
@@ -534,7 +538,11 @@ impl ASGIWebsocketProtocol {
534538
}
535539

536540
if let Some(ws) = &mut *(transport.lock().await) {
537-
while let Some(recv) = ws.next().await {
541+
while let Some(recv) = tokio::select! {
542+
biased;
543+
recv = ws.next() => recv,
544+
() = guard_disconnect.notified() => Some(Err(tokio_tungstenite::tungstenite::Error::ConnectionClosed)),
545+
} {
538546
match recv {
539547
Ok(Message::Ping(_) | Message::Pong(_)) => {}
540548
Ok(message @ Message::Close(_)) => {

src/rsgi/callbacks.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,13 @@ pub(crate) fn call_http(
138138
pub(crate) fn call_ws(
139139
cb: ArcCBScheduler,
140140
rt: RuntimeRef,
141+
disconnect_guard: Arc<Notify>,
141142
ws: HyperWebsocket,
142143
upgrade: UpgradeData,
143144
scope: WebsocketScope,
144145
) -> oneshot::Receiver<WebsocketDetachedTransport> {
145146
let (tx, rx) = oneshot::channel();
146-
let protocol = WebsocketProtocol::new(rt.clone(), tx, ws, upgrade);
147+
let protocol = WebsocketProtocol::new(rt.clone(), tx, ws, upgrade, disconnect_guard);
147148

148149
rt.spawn_blocking(move |py| {
149150
if let Ok(watcher) = CallbackWatcherWebsocket::new(py, protocol, scope) {

src/rsgi/http.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,13 @@ macro_rules! handle_request_with_ws {
8282
let scope = build_scope!(WebsocketScope, server_addr, client_addr, parts, scheme);
8383
let (restx, mut resrx) = mpsc::channel(1);
8484
let rth = rt.clone();
85+
let cancel_sig = Arc::new(Notify::new());
8586

86-
rt.spawn(async move {
87+
rt.spawn_cancellable(cancel_sig.clone(), async move {
8788
let tx_ref = restx.clone();
8889

89-
match $handler_ws(callback, rth, ws, UpgradeData::new(res, restx), scope).await {
90+
match $handler_ws(callback, rth, cancel_sig, ws, UpgradeData::new(res, restx), scope).await
91+
{
9092
Ok((status, consumed, stream)) => match (consumed, stream) {
9193
(false, _) => {
9294
let _ = tx_ref

src/rsgi/io.rs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -226,14 +226,16 @@ impl RSGIHTTPProtocol {
226226
#[pyclass(frozen, module = "granian._granian")]
227227
pub(crate) struct RSGIWebsocketTransport {
228228
rt: RuntimeRef,
229+
dg: Arc<Notify>,
229230
tx: Arc<AsyncMutex<Option<WSTxStream>>>,
230231
rx: Arc<AsyncMutex<WSRxStream>>,
231232
}
232233

233234
impl RSGIWebsocketTransport {
234-
pub fn new(rt: RuntimeRef, tx: Arc<AsyncMutex<Option<WSTxStream>>>, rx: WSRxStream) -> Self {
235+
pub fn new(rt: RuntimeRef, dg: Arc<Notify>, tx: Arc<AsyncMutex<Option<WSTxStream>>>, rx: WSRxStream) -> Self {
235236
Self {
236237
rt,
238+
dg,
237239
tx,
238240
rx: Arc::new(AsyncMutex::new(rx)),
239241
}
@@ -244,9 +246,15 @@ impl RSGIWebsocketTransport {
244246
impl RSGIWebsocketTransport {
245247
fn receive<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
246248
let transport = self.rx.clone();
249+
let dg = self.dg.clone();
250+
247251
future_into_py_futlike(self.rt.clone(), py, async move {
248252
if let Ok(mut stream) = transport.try_lock() {
249-
while let Some(recv) = stream.next().await {
253+
while let Some(recv) = tokio::select! {
254+
biased;
255+
recv = stream.next() => recv,
256+
() = dg.notified() => Some(Err(tokio_tungstenite::tungstenite::Error::ConnectionClosed)),
257+
} {
250258
match recv {
251259
Ok(Message::Ping(_) | Message::Pong(_)) => {}
252260
Ok(message) => return FutureResultToPy::RSGIWSMessage(message),
@@ -297,6 +305,7 @@ impl RSGIWebsocketTransport {
297305
pub(crate) struct RSGIWebsocketProtocol {
298306
rt: RuntimeRef,
299307
tx: Mutex<Option<oneshot::Sender<WebsocketDetachedTransport>>>,
308+
disconnect_guard: Arc<Notify>,
300309
websocket: Arc<AsyncMutex<HyperWebsocket>>,
301310
upgrade: RwLock<Option<UpgradeData>>,
302311
transport: Arc<AsyncMutex<Option<WSTxStream>>>,
@@ -308,10 +317,12 @@ impl RSGIWebsocketProtocol {
308317
tx: oneshot::Sender<WebsocketDetachedTransport>,
309318
websocket: HyperWebsocket,
310319
upgrade: UpgradeData,
320+
disconnect_guard: Arc<Notify>,
311321
) -> Self {
312322
Self {
313323
rt,
314324
tx: Mutex::new(Some(tx)),
325+
disconnect_guard,
315326
websocket: Arc::new(AsyncMutex::new(websocket)),
316327
upgrade: RwLock::new(Some(upgrade)),
317328
transport: Arc::new(AsyncMutex::new(None)),
@@ -341,9 +352,11 @@ impl RSGIWebsocketProtocol {
341352

342353
fn accept<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
343354
let rth = self.rt.clone();
355+
let dg = self.disconnect_guard.clone();
344356
let mut upgrade = self.upgrade.write().unwrap().take().unwrap();
345357
let transport = self.websocket.clone();
346358
let itransport = self.transport.clone();
359+
347360
future_into_py_futlike(self.rt.clone(), py, async move {
348361
let mut ws = transport.lock().await;
349362
match upgrade.send(None, None, None).await {
@@ -354,11 +367,7 @@ impl RSGIWebsocketProtocol {
354367
let mut guard = itransport.lock().await;
355368
*guard = Some(stx);
356369
}
357-
FutureResultToPy::RSGIWSAccept(RSGIWebsocketTransport::new(
358-
rth.clone(),
359-
itransport.clone(),
360-
srx,
361-
))
370+
FutureResultToPy::RSGIWSAccept(RSGIWebsocketTransport::new(rth, dg, itransport, srx))
362371
}
363372
_ => FutureResultToPy::Err(error_proto!()),
364373
},

src/runtime.rs

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ pub trait Runtime: Send + 'static {
3434
fn spawn_blocking<F>(&self, task: F)
3535
where
3636
F: FnOnce(Python) + Send + 'static;
37+
38+
fn spawn_cancellable<F>(&self, on_cancel: Arc<tokio::sync::Notify>, fut: F) -> Self::JoinHandle
39+
where
40+
F: Future<Output = ()> + Send + 'static;
3741
}
3842

3943
pub trait ContextExt: Runtime {
@@ -44,6 +48,7 @@ pub(crate) struct RuntimeWrapper {
4448
pub inner: tokio::runtime::Runtime,
4549
br: Arc<blocking::BlockingRunner>,
4650
pr: Arc<Py<PyAny>>,
51+
sig: Arc<tokio::sync::Notify>,
4752
}
4853

4954
impl RuntimeWrapper {
@@ -62,6 +67,7 @@ impl RuntimeWrapper {
6267
inner: default_runtime(blocking_threads),
6368
br: br.into(),
6469
pr: py_loop,
70+
sig: tokio::sync::Notify::new().into(),
6571
}
6672
}
6773

@@ -80,11 +86,17 @@ impl RuntimeWrapper {
8086
inner: rt,
8187
br: br.into(),
8288
pr: py_loop,
89+
sig: tokio::sync::Notify::new().into(),
8390
}
8491
}
8592

8693
pub fn handler(&self) -> RuntimeRef {
87-
RuntimeRef::new(self.inner.handle().clone(), self.br.clone(), self.pr.clone())
94+
RuntimeRef::new(
95+
self.inner.handle().clone(),
96+
self.br.clone(),
97+
self.pr.clone(),
98+
self.sig.clone(),
99+
)
88100
}
89101
}
90102

@@ -93,16 +105,27 @@ pub struct RuntimeRef {
93105
pub inner: tokio::runtime::Handle,
94106
innerb: Arc<blocking::BlockingRunner>,
95107
innerp: Arc<Py<PyAny>>,
108+
sig: Arc<tokio::sync::Notify>,
96109
}
97110

98111
impl RuntimeRef {
99-
pub fn new(rt: tokio::runtime::Handle, br: Arc<blocking::BlockingRunner>, pyloop: Arc<Py<PyAny>>) -> Self {
112+
pub fn new(
113+
rt: tokio::runtime::Handle,
114+
br: Arc<blocking::BlockingRunner>,
115+
pyloop: Arc<Py<PyAny>>,
116+
sig: Arc<tokio::sync::Notify>,
117+
) -> Self {
100118
Self {
101119
inner: rt,
102120
innerb: br,
103121
innerp: pyloop,
122+
sig,
104123
}
105124
}
125+
126+
pub fn close(&self) {
127+
self.sig.notify_waiters();
128+
}
106129
}
107130

108131
impl JoinError for tokio::task::JoinError {
@@ -129,6 +152,23 @@ impl Runtime for RuntimeRef {
129152
{
130153
_ = self.innerb.run(task);
131154
}
155+
156+
fn spawn_cancellable<F>(&self, on_cancel: Arc<tokio::sync::Notify>, fut: F) -> Self::JoinHandle
157+
where
158+
F: Future<Output = ()> + Send + 'static,
159+
{
160+
let sig = self.sig.clone();
161+
162+
self.inner.spawn(async move {
163+
tokio::select! {
164+
biased;
165+
() = fut => {},
166+
() = sig.notified() => {
167+
on_cancel.notify_one();
168+
}
169+
};
170+
})
171+
}
132172
}
133173

134174
impl ContextExt for RuntimeRef {
@@ -293,7 +333,7 @@ where
293333
}
294334

295335
#[allow(unused_must_use)]
296-
pub(crate) fn run_until_complete<F>(rt: RuntimeWrapper, event_loop: Bound<PyAny>, fut: F) -> PyResult<()>
336+
pub(crate) fn run_until_complete<F>(rt: &RuntimeWrapper, event_loop: Bound<PyAny>, fut: F) -> PyResult<()>
297337
where
298338
F: Future<Output = PyResult<()>> + Send + 'static,
299339
{

src/serve.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,12 @@ macro_rules! serve_fn {
8080
let wrk = crate::workers::Worker::new(ctx, acceptor, handler, rth, target, metrics.0);
8181
let tasks = wrk.tasks.clone();
8282

83-
let main_loop = crate::runtime::run_until_complete(rt, event_loop.clone(), async move {
83+
let main_loop = crate::runtime::run_until_complete(&rt, event_loop.clone(), async move {
8484
wrk.listen(srx, listener, backpressure).await;
8585

8686
log::info!("Stopping worker-{worker_id}");
8787

88+
wrk.rt.close();
8889
tasks.close();
8990
tasks.wait().await;
9091
mc_notify.notified().await;
@@ -93,6 +94,8 @@ macro_rules! serve_fn {
9394
Ok(())
9495
});
9596

97+
drop(rt);
98+
9699
if let Err(err) = main_loop {
97100
log::error!("{err}");
98101
std::process::exit(1);
@@ -174,6 +177,7 @@ macro_rules! serve_fn {
174177

175178
log::info!("Stopping worker-{} runtime-{}", worker_id, thread_id + 1);
176179

180+
wrk.rt.close();
177181
tasks.close();
178182
tasks.wait().await;
179183

@@ -212,7 +216,7 @@ macro_rules! serve_fn {
212216
mc_notify.notify_one();
213217
}
214218

215-
let main_loop = crate::runtime::run_until_complete(rtm, event_loop.clone(), async move {
219+
let main_loop = crate::runtime::run_until_complete(&rtm, event_loop.clone(), async move {
216220
let _ = pyrx.changed().await;
217221
stx.send(true).unwrap();
218222
log::info!("Stopping worker-{worker_id}");
@@ -223,6 +227,8 @@ macro_rules! serve_fn {
223227
Ok(())
224228
});
225229

230+
drop(rtm);
231+
226232
if let Err(err) = main_loop {
227233
log::error!("{err}");
228234
std::process::exit(1);
@@ -293,6 +299,7 @@ macro_rules! serve_fn {
293299

294300
log::info!("Stopping worker-{worker_id}");
295301

302+
wrk.rt.close();
296303
tasks.close();
297304
tasks.wait().await;
298305

src/workers.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ pub(crate) struct Worker<C, A, H, F, M> {
291291
ctx: C,
292292
acceptor: A,
293293
handler: H,
294-
rt: crate::runtime::RuntimeRef,
294+
pub rt: crate::runtime::RuntimeRef,
295295
pub tasks: tokio_util::task::TaskTracker,
296296
target: F,
297297
metrics: M,

0 commit comments

Comments
 (0)