From 36409237c9377fb8fc2fbb90f9f74192777b518c Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sat, 21 Sep 2024 19:04:08 +0200 Subject: [PATCH] Wait until state is CLOSED to acces close_exc. Fix #1449. --- docs/project/changelog.rst | 4 ++++ src/websockets/asyncio/connection.py | 14 +++++++++++--- src/websockets/sync/connection.py | 14 ++++++++++++-- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/docs/project/changelog.rst b/docs/project/changelog.rst index 456c15dac..615d3ab71 100644 --- a/docs/project/changelog.rst +++ b/docs/project/changelog.rst @@ -75,6 +75,10 @@ Bug fixes start the connection handler anymore when ``process_request`` or ``process_response`` returns an HTTP response. +* Fixed a bug in the :mod:`threading` implementation that could lead to + incorrect error reporting when closing a connection while + :meth:`~sync.connection.Connection.recv` is running. + 13.0.1 ------ diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 1b24f9af0..6af61a4a9 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -274,6 +274,8 @@ async def recv(self, decode: bool | None = None) -> Data: try: return await self.recv_messages.get(decode) except EOFError: + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from self.recv_exc except ConcurrencyError: raise ConcurrencyError( @@ -329,6 +331,8 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data async for frame in self.recv_messages.get_iter(decode): yield frame except EOFError: + # Wait for the protocol state to be CLOSED before accessing close_exc. + await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from self.recv_exc except ConcurrencyError: raise ConcurrencyError( @@ -864,6 +868,7 @@ async def send_context( # raise an exception. if raise_close_exc: self.close_transport() + # Wait for the protocol state to be CLOSED before accessing close_exc. await asyncio.shield(self.connection_lost_waiter) raise self.protocol.close_exc from original_exc @@ -926,11 +931,14 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: self.transport = transport def connection_lost(self, exc: Exception | None) -> None: - self.protocol.receive_eof() # receive_eof is idempotent + # Calling protocol.receive_eof() is safe because it's idempotent. + # This guarantees that the protocol state becomes CLOSED. + self.protocol.receive_eof() + assert self.protocol.state is CLOSED - # Abort recv() and pending pings with a ConnectionClosed exception. - # Set recv_exc first to get proper exception reporting. self.set_recv_exc(exc) + + # Abort recv() and pending pings with a ConnectionClosed exception. self.recv_messages.close() self.abort_pings() diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 65a7b63ed..77b488093 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -206,6 +206,8 @@ def recv(self, timeout: float | None = None) -> Data: try: return self.recv_messages.get(timeout) except EOFError: + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() raise self.protocol.close_exc from self.recv_exc except ConcurrencyError: raise ConcurrencyError( @@ -240,6 +242,8 @@ def recv_streaming(self) -> Iterator[Data]: for frame in self.recv_messages.get_iter(): yield frame except EOFError: + # Wait for the protocol state to be CLOSED before accessing close_exc. + self.recv_events_thread.join() raise self.protocol.close_exc from self.recv_exc except ConcurrencyError: raise ConcurrencyError( @@ -629,8 +633,6 @@ def recv_events(self) -> None: self.logger.error("unexpected internal error", exc_info=True) with self.protocol_mutex: self.set_recv_exc(exc) - # We don't know where we crashed. Force protocol state to CLOSED. - self.protocol.state = CLOSED finally: # This isn't expected to raise an exception. self.close_socket() @@ -738,6 +740,7 @@ def send_context( # raise an exception. if raise_close_exc: self.close_socket() + # Wait for the protocol state to be CLOSED before accessing close_exc. self.recv_events_thread.join() raise self.protocol.close_exc from original_exc @@ -788,4 +791,11 @@ def close_socket(self) -> None: except OSError: pass # socket is already closed self.socket.close() + + # Calling protocol.receive_eof() is safe because it's idempotent. + # This guarantees that the protocol state becomes CLOSED. + self.protocol.receive_eof() + assert self.protocol.state is CLOSED + + # Abort recv() with a ConnectionClosed exception. self.recv_messages.close()