From 0215b10eeaf7ae8d77165a3a0d27f8332f81d645 Mon Sep 17 00:00:00 2001 From: Benjamin Thomas Schwertfeger Date: Wed, 27 Nov 2024 20:46:05 +0100 Subject: [PATCH 1/4] Resolve "Default `max_queue` blocks websocket cancellation with high traffic" --- src/websockets/asyncio/connection.py | 3 +++ src/websockets/asyncio/messages.py | 12 +++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index e5c350fe2..19321274a 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -913,6 +913,9 @@ async def send_context( if wait_for_close: try: async with asyncio_timeout_at(self.close_deadline): + self.recv_messages.cancelling = True + if self.recv_messages.paused: + self.recv_messages.resume() await asyncio.shield(self.connection_lost_waiter) except TimeoutError: # There's no risk to overwrite another error because diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index e6d1d31cc..c68e5d960 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -113,9 +113,13 @@ def __init__( # pragma: no cover # This flag prevents concurrent calls to get() by user code. self.get_in_progress = False + # This flag marks a soon cancellation + self.cancelling = False + # This flag marks the end of the connection. self.closed = False + async def get(self, decode: bool | None = None) -> Data: """ Read the next message. @@ -138,6 +142,8 @@ async def get(self, decode: bool | None = None) -> Data: :meth:`get_iter` concurrently. """ + if self.cancelling: + return if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -201,6 +207,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: :meth:`get_iter` concurrently. 
""" + if self.cancelling: + return if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -251,6 +259,8 @@ def put(self, frame: Frame) -> None: EOFError: If the stream of frames has ended. """ + if self.cancelling: + return if self.closed: raise EOFError("stream of frames ended") @@ -283,7 +293,7 @@ def close(self) -> None: """ End the stream of frames. - Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, or :meth:`put` is safe. They will raise :exc:`EOFError`. """ From 179740cf72b77d0b8f31952c7a98c2c94f194d52 Mon Sep 17 00:00:00 2001 From: Benjamin Thomas Schwertfeger Date: Wed, 27 Nov 2024 20:48:34 +0100 Subject: [PATCH 2/4] add repro.py script for debugging purposes --- repro.py | 84 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 repro.py diff --git a/repro.py b/repro.py new file mode 100644 index 000000000..8fda0e9c9 --- /dev/null +++ b/repro.py @@ -0,0 +1,84 @@ +import asyncio +import logging +import json +from websockets.asyncio.client import connect + +# import debugpy + +# # Allow VS Code to attach +# debugpy.listen(("0.0.0.0", 5678)) # Use the port you've specified +# print("Waiting for debugger to attach...") +# debugpy.wait_for_client() + +logging.getLogger(__name__) +logging.basicConfig( + format="%(asctime)s %(module)s:%(lineno)d %(levelname)8s | %(message)s", + datefmt="%Y/%m/%d %H:%M:%S", + level=logging.DEBUG, +) + +class MyClient: + + def __init__(self): + self.keep_alive = True + + async def run(self): + + async with connect( + f"wss://ws.kraken.com/v2", + ping_interval=30, + # max_queue=None, # having this enabled doesn't cause problems + ) as socket: + await socket.send( + json.dumps( + { + "method": "subscribe", + "params": { + "channel": "book", + "symbol": [ + "BTC/USD", + "DOT/USD", + "ETH/USD", + "MATIC/USD", + "BTC/EUR", 
+ "DOT/EUR", + "ETH/EUR", + "XLM/USD", + "XLM/EUR", + ], + "depth": 100 + }, + } + ) + ) + + while self.keep_alive: + try: + _message = await asyncio.wait_for(socket.recv(), timeout=10) + except TimeoutError: + pass + except asyncio.CancelledError: + self.keep_alive = False + else: + try: + message = json.loads(_message) + except ValueError: + pass + + async def __aenter__(self): + self.task: asyncio.Task = asyncio.create_task(self.run()) + return self + + async def __aexit__(self, *args, **kwargs): + self.keep_alive = False + if hasattr(self, "task") and not self.task.done(): + await self.task + + +async def main(): + async with MyClient(): + await asyncio.sleep(3) + + +if __name__ == "__main__": + asyncio.run(main()) From f4d0b2060bd9503ee67ad834ad032d5c6b2d7a06 Mon Sep 17 00:00:00 2001 From: Benjamin Thomas Schwertfeger Date: Thu, 28 Nov 2024 17:51:50 +0100 Subject: [PATCH 3/4] cleanup --- src/websockets/asyncio/connection.py | 4 +--- src/websockets/asyncio/messages.py | 25 +++++++++++++++---------- src/websockets/sync/messages.py | 2 +- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 19321274a..c947f5a44 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -913,9 +913,7 @@ async def send_context( if wait_for_close: try: async with asyncio_timeout_at(self.close_deadline): - self.recv_messages.cancelling = True - if self.recv_messages.paused: - self.recv_messages.resume() + self.recv_messages.prepare_close() await asyncio.shield(self.connection_lost_waiter) except TimeoutError: # There's no risk to overwrite another error because diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index c68e5d960..5a7bfbb8f 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -113,13 +113,12 @@ def __init__( # pragma: no cover # This flag prevents concurrent calls to 
get() by user code. self.get_in_progress = False - # This flag marks a soon cancellation - self.cancelling = False + # This flag signals that the connection is about to close. + self.closing = False # This flag marks the end of the connection. self.closed = False - async def get(self, decode: bool | None = None) -> Data: """ Read the next message. @@ -142,8 +141,6 @@ async def get(self, decode: bool | None = None) -> Data: :meth:`get_iter` concurrently. """ - if self.cancelling: - return if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -207,8 +204,6 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: :meth:`get_iter` concurrently. """ - if self.cancelling: - return if self.get_in_progress: raise ConcurrencyError("get() or get_iter() is already running") self.get_in_progress = True @@ -259,13 +254,13 @@ def put(self, frame: Frame) -> None: EOFError: If the stream of frames has ended. """ - if self.cancelling: - return if self.closed: raise EOFError("stream of frames ended") self.frames.put(frame) - self.maybe_pause() + + if not self.closing: + self.maybe_pause() def maybe_pause(self) -> None: """Pause the writer if queue is above the high water mark.""" @@ -289,6 +284,16 @@ def maybe_resume(self) -> None: self.paused = False self.resume() + def prepare_close(self) -> None: + """ + Prepare to close: stop pausing the writer so that remaining frames are queued. + """ + self.closing = True + + # Resuming the writer to avoid deadlocks + if self.paused: + self.resume() + def close(self) -> None: """ End the stream of frames. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 98490797f..7d21a4faf 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -291,7 +291,7 @@ def close(self) -> None: """ End the stream of frames. 
- Callling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, or :meth:`put` is safe. They will raise :exc:`EOFError`. """ From c4b3d2bad27ff5f970b3625e49f82695b19ca212 Mon Sep 17 00:00:00 2001 From: Benjamin Thomas Schwertfeger Date: Thu, 28 Nov 2024 18:11:05 +0100 Subject: [PATCH 4/4] add a unit test --- src/websockets/asyncio/messages.py | 1 + tests/asyncio/test_connection.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 5a7bfbb8f..936adf631 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -292,6 +292,7 @@ def prepare_close(self) -> None: # Resuming the writer to avoid deadlocks if self.paused: + self.paused = False self.resume() def close(self) -> None: diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 5a0b61bf7..95f4e56a1 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -806,6 +806,27 @@ async def test_close_preserves_queued_messages(self): self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") self.assertIsNone(exc.__cause__) + async def test_close_preserves_queued_messages_gt_max_queue(self): + """ + close preserves messages buffered in the assembler, even if they + exceed the default buffer size. + """ + + for _ in range(100): + await self.remote_connection.send("😀") + + await self.connection.close() + + for _ in range(100): + self.assertEqual(await self.connection.recv(), "😀") + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + async def test_close_idempotency(self): """close does nothing if the connection is already closed.""" await self.connection.close()