Skip to content

Commit b761d41

Browse files
authored
fix client hang when connection lost just after remote closes (#182)
Bad ordering: 1. Remote close 2. TCP closed 3. Local confirms => no ConnectionClosed raised, client hangs forever
1 parent 89f2749 commit b761d41

File tree

2 files changed

+76
-37
lines changed

2 files changed

+76
-37
lines changed

tests/test_connection.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
import trio
3939
import trustme
4040
import wsproto
41-
from trio.testing import memory_stream_pair
41+
from trio.testing import memory_stream_pair, memory_stream_pump
4242
from wsproto.events import CloseConnection
4343

4444
try:
@@ -1017,3 +1017,41 @@ async def test_finalization_dropped_exception(echo_server, autojump_clock):
10171017
await trio.sleep_forever()
10181018
finally:
10191019
raise ValueError
1020+
1021+
1022+
async def test_remote_close_rude():
1023+
"""
1024+
Bad ordering:
1025+
1. Remote close
1026+
2. TCP closed
1027+
3. Local confirms
1028+
=> no ConnectionClosed raised, client hangs forever
1029+
"""
1030+
client_stream, server_stream = memory_stream_pair()
1031+
1032+
async def client():
1033+
client_conn = await wrap_client_stream(nursery, client_stream, HOST, RESOURCE)
1034+
assert not client_conn.closed
1035+
await client_conn.send_message('Hello from client!')
1036+
with pytest.raises(ConnectionClosed):
1037+
await client_conn.get_message()
1038+
1039+
async def server():
1040+
server_request = await wrap_server_stream(nursery, server_stream)
1041+
server_ws = await server_request.accept()
1042+
assert not server_ws.closed
1043+
msg = await server_ws.get_message()
1044+
assert msg == "Hello from client!"
1045+
1046+
# disable pumping so that the CloseConnection arrives at the same time as the stream closure
1047+
server_stream.send_stream.send_all_hook = None
1048+
await server_ws._send(CloseConnection(code=1000, reason=None))
1049+
await server_stream.aclose()
1050+
1051+
# pump the messages over
1052+
memory_stream_pump(server_stream.send_stream, client_stream.receive_stream)
1053+
1054+
1055+
async with trio.open_nursery() as nursery:
1056+
nursery.start_soon(server)
1057+
nursery.start_soon(client)

trio_websocket/_impl.py

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1210,45 +1210,46 @@ async def _reader_task(self):
12101210
except ConnectionClosed:
12111211
self._reader_running = False
12121212

1213-
while self._reader_running:
1214-
# Process events.
1215-
for event in self._wsproto.events():
1216-
event_type = type(event)
1213+
async with self._send_channel:
1214+
while self._reader_running:
1215+
# Process events.
1216+
for event in self._wsproto.events():
1217+
event_type = type(event)
1218+
try:
1219+
handler = handlers[event_type]
1220+
logger.debug('%s received event: %s', self,
1221+
event_type)
1222+
await handler(event)
1223+
except KeyError:
1224+
logger.warning('%s received unknown event type: "%s"', self,
1225+
event_type)
1226+
except ConnectionClosed:
1227+
self._reader_running = False
1228+
break
1229+
1230+
# Get network data.
12171231
try:
1218-
handler = handlers[event_type]
1219-
logger.debug('%s received event: %s', self,
1220-
event_type)
1221-
await handler(event)
1222-
except KeyError:
1223-
logger.warning('%s received unknown event type: "%s"', self,
1224-
event_type)
1225-
except ConnectionClosed:
1226-
self._reader_running = False
1232+
data = await self._stream.receive_some(RECEIVE_BYTES)
1233+
except (trio.BrokenResourceError, trio.ClosedResourceError):
1234+
await self._abort_web_socket()
12271235
break
1228-
1229-
# Get network data.
1230-
try:
1231-
data = await self._stream.receive_some(RECEIVE_BYTES)
1232-
except (trio.BrokenResourceError, trio.ClosedResourceError):
1233-
await self._abort_web_socket()
1234-
break
1235-
if len(data) == 0:
1236-
logger.debug('%s received zero bytes (connection closed)',
1237-
self)
1238-
# If TCP closed before WebSocket, then record it as an abnormal
1239-
# closure.
1236+
if len(data) == 0:
1237+
logger.debug('%s received zero bytes (connection closed)',
1238+
self)
1239+
# If TCP closed before WebSocket, then record it as an abnormal
1240+
# closure.
1241+
if self._wsproto.state != ConnectionState.CLOSED:
1242+
await self._abort_web_socket()
1243+
break
1244+
logger.debug('%s received %d bytes', self, len(data))
12401245
if self._wsproto.state != ConnectionState.CLOSED:
1241-
await self._abort_web_socket()
1242-
break
1243-
logger.debug('%s received %d bytes', self, len(data))
1244-
if self._wsproto.state != ConnectionState.CLOSED:
1245-
try:
1246-
self._wsproto.receive_data(data)
1247-
except wsproto.utilities.RemoteProtocolError as err:
1248-
logger.debug('%s remote protocol error: %s', self, err)
1249-
if err.event_hint:
1250-
await self._send(err.event_hint)
1251-
await self._close_stream()
1246+
try:
1247+
self._wsproto.receive_data(data)
1248+
except wsproto.utilities.RemoteProtocolError as err:
1249+
logger.debug('%s remote protocol error: %s', self, err)
1250+
if err.event_hint:
1251+
await self._send(err.event_hint)
1252+
await self._close_stream()
12521253

12531254
logger.debug('%s reader task finished', self)
12541255

0 commit comments

Comments
 (0)