@@ -226,14 +226,16 @@ impl RSGIHTTPProtocol {
226226#[ pyclass( frozen, module = "granian._granian" ) ]
227227pub ( crate ) struct RSGIWebsocketTransport {
228228 rt : RuntimeRef ,
229+ dg : Arc < Notify > ,
229230 tx : Arc < AsyncMutex < Option < WSTxStream > > > ,
230231 rx : Arc < AsyncMutex < WSRxStream > > ,
231232}
232233
233234impl RSGIWebsocketTransport {
234- pub fn new ( rt : RuntimeRef , tx : Arc < AsyncMutex < Option < WSTxStream > > > , rx : WSRxStream ) -> Self {
235+ pub fn new ( rt : RuntimeRef , dg : Arc < Notify > , tx : Arc < AsyncMutex < Option < WSTxStream > > > , rx : WSRxStream ) -> Self {
235236 Self {
236237 rt,
238+ dg,
237239 tx,
238240 rx : Arc :: new ( AsyncMutex :: new ( rx) ) ,
239241 }
@@ -244,9 +246,15 @@ impl RSGIWebsocketTransport {
244246impl RSGIWebsocketTransport {
245247 fn receive < ' p > ( & self , py : Python < ' p > ) -> PyResult < Bound < ' p , PyAny > > {
246248 let transport = self . rx . clone ( ) ;
249+ let dg = self . dg . clone ( ) ;
250+
247251 future_into_py_futlike ( self . rt . clone ( ) , py, async move {
248252 if let Ok ( mut stream) = transport. try_lock ( ) {
249- while let Some ( recv) = stream. next ( ) . await {
253+ while let Some ( recv) = tokio:: select! {
254+ biased;
255+ recv = stream. next( ) => recv,
256+ ( ) = dg. notified( ) => Some ( Err ( tokio_tungstenite:: tungstenite:: Error :: ConnectionClosed ) ) ,
257+ } {
250258 match recv {
251259 Ok ( Message :: Ping ( _) | Message :: Pong ( _) ) => { }
252260 Ok ( message) => return FutureResultToPy :: RSGIWSMessage ( message) ,
@@ -297,6 +305,7 @@ impl RSGIWebsocketTransport {
297305pub ( crate ) struct RSGIWebsocketProtocol {
298306 rt : RuntimeRef ,
299307 tx : Mutex < Option < oneshot:: Sender < WebsocketDetachedTransport > > > ,
308+ disconnect_guard : Arc < Notify > ,
300309 websocket : Arc < AsyncMutex < HyperWebsocket > > ,
301310 upgrade : RwLock < Option < UpgradeData > > ,
302311 transport : Arc < AsyncMutex < Option < WSTxStream > > > ,
@@ -308,10 +317,12 @@ impl RSGIWebsocketProtocol {
308317 tx : oneshot:: Sender < WebsocketDetachedTransport > ,
309318 websocket : HyperWebsocket ,
310319 upgrade : UpgradeData ,
320+ disconnect_guard : Arc < Notify > ,
311321 ) -> Self {
312322 Self {
313323 rt,
314324 tx : Mutex :: new ( Some ( tx) ) ,
325+ disconnect_guard,
315326 websocket : Arc :: new ( AsyncMutex :: new ( websocket) ) ,
316327 upgrade : RwLock :: new ( Some ( upgrade) ) ,
317328 transport : Arc :: new ( AsyncMutex :: new ( None ) ) ,
@@ -341,9 +352,11 @@ impl RSGIWebsocketProtocol {
341352
342353 fn accept < ' p > ( & self , py : Python < ' p > ) -> PyResult < Bound < ' p , PyAny > > {
343354 let rth = self . rt . clone ( ) ;
355+ let dg = self . disconnect_guard . clone ( ) ;
344356 let mut upgrade = self . upgrade . write ( ) . unwrap ( ) . take ( ) . unwrap ( ) ;
345357 let transport = self . websocket . clone ( ) ;
346358 let itransport = self . transport . clone ( ) ;
359+
347360 future_into_py_futlike ( self . rt . clone ( ) , py, async move {
348361 let mut ws = transport. lock ( ) . await ;
349362 match upgrade. send ( None , None , None ) . await {
@@ -354,11 +367,7 @@ impl RSGIWebsocketProtocol {
354367 let mut guard = itransport. lock ( ) . await ;
355368 * guard = Some ( stx) ;
356369 }
357- FutureResultToPy :: RSGIWSAccept ( RSGIWebsocketTransport :: new (
358- rth. clone ( ) ,
359- itransport. clone ( ) ,
360- srx,
361- ) )
370+ FutureResultToPy :: RSGIWSAccept ( RSGIWebsocketTransport :: new ( rth, dg, itransport, srx) )
362371 }
363372 _ => FutureResultToPy :: Err ( error_proto ! ( ) ) ,
364373 } ,
0 commit comments