Skip to content

Commit d10a246

Browse files
JaeHyuckSagi0baro
andauthored
Ensure to send ASGI websocket.disconnect after server-initiated close (#801)
* Fix websocket.disconnect not delivered after server-initiated close Signed-off-by: JaeHyuck Sa <jaehyuck.sa.dev@gmail.com> * Refactor code * Reduce tests flakyness --------- Signed-off-by: JaeHyuck Sa <jaehyuck.sa.dev@gmail.com> Co-authored-by: Giovanni Barillari <giovanni.barillari@sentry.io>
1 parent 6736727 commit d10a246

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

src/asgi/io.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,11 @@ impl ASGIWebsocketProtocol {
558558
}
559559
}
560560
}
561+
562+
if closed.load(atomic::Ordering::Acquire) {
563+
return FutureResultToPy::ASGIWSMessage(Message::Close(None));
564+
}
565+
561566
FutureResultToPy::Err(error_flow!("Transport not initialized or closed"))
562567
})
563568
}

tests/apps/asgi.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,17 @@ async def ws_push(scope, receive, send):
138138
pass
139139

140140

141+
async def ws_close(scope, receive, send):
142+
await receive()
143+
await send({'type': 'websocket.accept'})
144+
msg = await receive()
145+
fpath = msg.get('text') or msg['bytes'].decode('utf8')
146+
await send({'type': 'websocket.close', 'code': 1000})
147+
msg = await receive()
148+
if msg['type'] == 'websocket.disconnect':
149+
pathlib.Path(fpath).touch()
150+
151+
141152
async def err_app(scope, receive, send):
142153
1 / 0
143154

@@ -208,6 +219,7 @@ def app(scope, receive, send):
208219
'/ws_rejectc': ws_reject_custom,
209220
'/ws_info': ws_info,
210221
'/ws_echo': ws_echo,
222+
'/ws_close': ws_close,
211223
'/ws_push': ws_push,
212224
'/err_app': err_app,
213225
'/err_proto/type': err_proto_msg,

tests/test_ws.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import json
23
import os
34

@@ -32,6 +33,25 @@ async def test_reject(server, runtime_mode):
3233
assert exc.value.response.status_code == 403
3334

3435

36+
@pytest.mark.asyncio
37+
@pytest.mark.parametrize('runtime_mode', ['mt', 'st'])
38+
async def test_asgi_server_close(asgi_server, runtime_mode, tmp_path):
39+
target = tmp_path / 'ws_result'
40+
41+
async with asgi_server(runtime_mode) as port:
42+
async with websockets.connect(f'ws://localhost:{port}/ws_close') as ws:
43+
await ws.send(str(target.resolve()))
44+
try:
45+
await ws.recv()
46+
except Exception:
47+
pass
48+
49+
# reduce flakyness
50+
await asyncio.sleep(0.1)
51+
52+
assert target.exists()
53+
54+
3555
@pytest.mark.asyncio
3656
@pytest.mark.parametrize('runtime_mode', ['mt', 'st'])
3757
async def test_asgi_reject_custom(asgi_server, runtime_mode):

0 commit comments

Comments
 (0)