Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/websockets/sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import uuid
from collections.abc import Iterable, Iterator, Mapping
from types import TracebackType
from typing import Any
from typing import Any, Optional

from ..exceptions import (
ConcurrencyError,
Expand Down Expand Up @@ -216,11 +216,14 @@ def __iter__(self) -> Iterator[Data]:
"""
try:
while True:
yield self.recv()
# can't receive None if peek=False (default)
yield self.recv() # type: ignore[misc]
except ConnectionClosedOK:
return

def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data:
def recv(
self, timeout: float | None = None, decode: bool | None = None
) -> Optional[Data]:
"""
Receive the next message.

Expand Down Expand Up @@ -267,8 +270,16 @@ def recv(self, timeout: float | None = None, decode: bool | None = None) -> Data
:meth:`recv_streaming` concurrently.

"""
# Setup to handle the special case of a zero timeout
peek = timeout == 0.0
if peek:
# can't provide specific timeout if using timeout param to express
# peeking, default to no timeout in such case
timeout = None

# Attempt to receive a message
try:
return self.recv_messages.get(timeout, decode)
return self.recv_messages.get(timeout, decode, peek)
except EOFError:
pass
# fallthrough
Expand Down
19 changes: 16 additions & 3 deletions src/websockets/sync/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import codecs
import queue
import threading
from typing import Any, Callable, Iterable, Iterator
from typing import Any, Callable, Iterable, Iterator, Optional

from ..exceptions import ConcurrencyError
from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame
Expand Down Expand Up @@ -105,11 +105,18 @@ def reset_queue(self, frames: Iterable[Frame]) -> None:
for frame in queued: # pragma: no cover
self.frames.put(frame)

def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
def get(
self,
timeout: float | None = None,
decode: bool | None = None,
peek: bool = False,
) -> Optional[Data]:
"""
Read the next message.

:meth:`get` returns a single :class:`str` or :class:`bytes`.
:meth:`get` returns a single :class:`str` or :class:`bytes`, or
:obj:`None` if the parameter ``peek`` is :obj:`True` and no message
is available.

If the message is fragmented, :meth:`get` waits until the last frame is
received, then it reassembles the message and returns it. To receive
Expand All @@ -121,6 +128,9 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
decode: :obj:`False` disables UTF-8 decoding of text frames and
returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of
binary frames and returns :class:`str`.
peek: If :obj:`True`, :meth:`get` returns :obj:`None` immediately
if no message is available, or will finish receiving a
message and return it, respecting ``timeout``.

Raises:
EOFError: If the stream of frames has ended.
Expand All @@ -140,6 +150,9 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data:
# until get() fetches a complete message or times out.

try:
if peek and self.frames.empty():
return None

deadline = Deadline(timeout)

# First frame
Expand Down
9 changes: 9 additions & 0 deletions tests/sync/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ def test_recv_fragmented_binary(self):
self.remote_connection.send([b"\x01\x02", b"\xfe\xff"])
self.assertEqual(self.connection.recv(), b"\x01\x02\xfe\xff")

def test_recv_none_timeout_zero(self):
"""recv with timeout=0 returns None when there is no message."""
self.assertIsNone(self.connection.recv(timeout=0))

def test_recv_msg_timeout_zero(self):
"""recv with timeout=0 returns message when there is one."""
self.remote_connection.send("😀")
self.assertEqual(self.connection.recv(timeout=0), "😀")

def test_recv_connection_closed_ok(self):
"""recv raises ConnectionClosedOK after a normal closure."""
self.remote_connection.close()
Expand Down