diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index d8dbf140e..0ccbaef3e 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -9,7 +9,7 @@ import uuid from collections.abc import Iterable, Iterator, Mapping from types import TracebackType -from typing import Any +from typing import Any, Optional from ..exceptions import ( ConcurrencyError, @@ -216,11 +216,14 @@ def __iter__(self) -> Iterator[Data]: """ try: while True: - yield self.recv() + # can't receive None if peek=False (default) + yield self.recv() # type: ignore[misc] except ConnectionClosedOK: return - def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data: + def recv( + self, timeout: float | None = None, decode: bool | None = None + ) -> Optional[Data]: """ Receive the next message. @@ -267,8 +270,16 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data :meth:`recv_streaming` concurrently. """ + # Setup to handle the special case of a zero timeout + peek = timeout == 0.0 + if peek: + # can't provide specific timeout if using timeout param to express + # peeking, default to no timeout in such case + timeout = None + + # Attempt to receive a message try: - return self.recv_messages.get(timeout, decode) + return self.recv_messages.get(timeout, decode, peek) except EOFError: pass # fallthrough diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index 98490797f..a80003299 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -3,7 +3,7 @@ import codecs import queue import threading -from typing import Any, Callable, Iterable, Iterator +from typing import Any, Callable, Iterable, Iterator, Optional from ..exceptions import ConcurrencyError from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame @@ -105,11 +105,18 @@ def reset_queue(self, frames: Iterable[Frame]) -> None: for frame in queued: # pragma: no cover self.frames.put(frame) - def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: + def get( + self, + timeout: float | None = None, + decode: bool | None = None, + peek: bool = False, + ) -> Optional[Data]: """ Read the next message. - :meth:`get` returns a single :class:`str` or :class:`bytes`. + :meth:`get` returns a single :class:`str` or :class:`bytes`, or + :obj:`None` if the parameter ``peek`` is :obj:`True` and no message + is available. If the message is fragmented, :meth:`get` waits until the last frame is received, then it reassembles the message and returns it. To receive @@ -121,6 +128,9 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: decode: :obj:`False` disables UTF-8 decoding of text frames and returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of binary frames and returns :class:`str`. + peek: If :obj:`True`, :meth:`get` returns :obj:`None` immediately + if no message is available, or will finish receiving a + message and return it, respecting ``timeout``. Raises: EOFError: If the stream of frames has ended. @@ -140,6 +150,9 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: # until get() fetches a complete message or times out. try: + if peek and self.frames.empty(): + return None + deadline = Deadline(timeout) # First frame diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index aa445498c..b2427ec8e 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -173,6 +173,15 @@ def test_recv_fragmented_binary(self): self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff") + def test_recv_none_timeout_zero(self): + """recv with timeout=0 returns None when there is no message.""" + self.assertIsNone(self.connection.recv(timeout=0)) + + def test_recv_msg_timeout_zero(self): + """recv with timeout=0 returns message when there is one.""" + self.remote_connection.send("😀") + self.assertEqual(self.connection.recv(timeout=0), "😀") + def test_recv_connection_closed_ok(self): """recv raises ConnectionClosedOK after a normal closure.""" self.remote_connection.close()