Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/reference/types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ Types

.. autodata:: Data

.. autodata:: BytesLike

.. autodata:: DataLike

.. autodata:: LoggerLike

.. autodata:: StatusLike
Expand Down
12 changes: 6 additions & 6 deletions src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
ConnectionClosedOK,
ProtocolError,
)
from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode
from ..frames import DATA_OPCODES, CloseCode, Frame, Opcode
from ..http11 import Request, Response
from ..protocol import CLOSED, OPEN, Event, Protocol, State
from ..typing import Data, LoggerLike, Subprotocol
from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol
from .compatibility import (
TimeoutError,
aiter,
Expand Down Expand Up @@ -402,7 +402,7 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data

async def send(
self,
message: Data | Iterable[Data] | AsyncIterable[Data],
message: DataLike | Iterable[DataLike] | AsyncIterable[DataLike],
text: bool | None = None,
) -> None:
"""
Expand Down Expand Up @@ -657,7 +657,7 @@ async def wait_closed(self) -> None:
"""
await asyncio.shield(self.connection_lost_waiter)

async def ping(self, data: Data | None = None) -> Awaitable[float]:
async def ping(self, data: DataLike | None = None) -> Awaitable[float]:
"""
Send a Ping_.

Expand Down Expand Up @@ -710,7 +710,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]:
self.protocol.send_ping(data)
return pong_waiter

async def pong(self, data: Data = b"") -> None:
async def pong(self, data: DataLike = b"") -> None:
"""
Send a Pong_.

Expand Down Expand Up @@ -1134,7 +1134,7 @@ def eof_received(self) -> None:

def broadcast(
connections: Iterable[Connection],
message: Data,
message: DataLike,
raise_exceptions: bool = False,
) -> None:
"""
Expand Down
7 changes: 5 additions & 2 deletions src/websockets/asyncio/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ async def get(self, decode: bool | None = None) -> Data:
finally:
self.get_in_progress = False

# This converts frame.data to bytes when it's a bytearray.
data = b"".join(frame.data for frame in frames)
if decode:
return data.decode()
Expand Down Expand Up @@ -243,7 +244,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
decoder = UTF8Decoder()
yield decoder.decode(frame.data, frame.fin)
else:
yield frame.data
# Convert to bytes when frame.data is a bytearray.
yield bytes(frame.data)

# Following frames, for fragmented messages
while not frame.fin:
Expand All @@ -257,7 +259,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]:
if decode:
yield decoder.decode(frame.data, frame.fin)
else:
yield frame.data
# Convert to bytes when frame.data is a bytearray.
yield bytes(frame.data)

self.get_in_progress = False

Expand Down
4 changes: 3 additions & 1 deletion src/websockets/extensions/permessage_deflate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
PayloadTooBig,
ProtocolError,
)
from ..typing import ExtensionName, ExtensionParameter
from ..typing import BytesLike, ExtensionName, ExtensionParameter
from .base import ClientExtensionFactory, Extension, ServerExtensionFactory


Expand Down Expand Up @@ -129,6 +129,7 @@ def decode(
# Uncompress data. Protect against zip bombs by preventing zlib from
# decompressing more than max_length bytes (except when the limit is
# disabled with max_size = None).
data: BytesLike
if frame.fin and len(frame.data) < 2044:
# Profiling shows that appending four bytes, which makes a copy, is
# faster than calling decompress() again when data is less than 2kB.
Expand Down Expand Up @@ -182,6 +183,7 @@ def encode(self, frame: frames.Frame) -> frames.Frame:
)

# Compress data.
data: BytesLike
data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH)
if frame.fin:
# Sync flush generates between 5 or 6 bytes, ending with the bytes
Expand Down
13 changes: 7 additions & 6 deletions src/websockets/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Callable

from .exceptions import PayloadTooBig, ProtocolError
from .typing import BytesLike


try:
Expand Down Expand Up @@ -118,9 +119,6 @@ class CloseCode(enum.IntEnum):
}


BytesLike = bytes, bytearray, memoryview


@dataclasses.dataclass
class Frame:
"""
Expand All @@ -140,7 +138,7 @@ class Frame:
"""

opcode: Opcode
data: bytes | bytearray | memoryview
data: BytesLike
fin: bool = True
rsv1: bool = False
rsv2: bool = False
Expand Down Expand Up @@ -202,7 +200,7 @@ def __str__(self) -> str:
@classmethod
def parse(
cls,
read_exact: Callable[[int], Generator[None, None, bytes]],
read_exact: Callable[[int], Generator[None, None, bytes | bytearray]],
*,
mask: bool,
max_size: int | None = None,
Expand Down Expand Up @@ -324,6 +322,7 @@ def serialize(
output.write(mask_bytes)

# Prepare the data.
data: BytesLike
if mask:
data = apply_mask(self.data, mask_bytes)
else:
Expand Down Expand Up @@ -383,7 +382,7 @@ def __str__(self) -> str:
return result

@classmethod
def parse(cls, data: bytes) -> Close:
def parse(cls, data: BytesLike) -> Close:
"""
Parse the payload of a close frame.

Expand All @@ -395,6 +394,8 @@ def parse(cls, data: bytes) -> Close:
UnicodeDecodeError: If the reason isn't valid UTF-8.

"""
if isinstance(data, memoryview):
raise AssertionError("only compressed outgoing frames use memoryview")
if len(data) >= 2:
(code,) = struct.unpack("!H", data[:2])
reason = data[2:].decode()
Expand Down
27 changes: 14 additions & 13 deletions src/websockets/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
MAX_BODY_SIZE = int(os.environ.get("WEBSOCKETS_MAX_BODY_SIZE", "1_048_576")) # 1 MiB


def d(value: bytes) -> str:
def d(value: bytes | bytearray) -> str:
"""
Decode a bytestring for interpolating into an error message.

Expand Down Expand Up @@ -102,7 +102,7 @@ def exception(self) -> Exception | None: # pragma: no cover
@classmethod
def parse(
cls,
read_line: Callable[[int], Generator[None, None, bytes]],
read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
) -> Generator[None, None, Request]:
"""
Parse a WebSocket handshake request.
Expand Down Expand Up @@ -194,7 +194,7 @@ class Response:
status_code: int
reason_phrase: str
headers: Headers
body: bytes = b""
body: bytes | bytearray = b""

_exception: Exception | None = None

Expand All @@ -210,9 +210,9 @@ def exception(self) -> Exception | None: # pragma: no cover
@classmethod
def parse(
cls,
read_line: Callable[[int], Generator[None, None, bytes]],
read_exact: Callable[[int], Generator[None, None, bytes]],
read_to_eof: Callable[[int], Generator[None, None, bytes]],
read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
read_exact: Callable[[int], Generator[None, None, bytes | bytearray]],
read_to_eof: Callable[[int], Generator[None, None, bytes | bytearray]],
proxy: bool = False,
) -> Generator[None, None, Response]:
"""
Expand Down Expand Up @@ -276,6 +276,7 @@ def parse(

headers = yield from parse_headers(read_line)

body: bytes | bytearray
if proxy:
body = b""
else:
Expand All @@ -299,8 +300,8 @@ def serialize(self) -> bytes:


def parse_line(
read_line: Callable[[int], Generator[None, None, bytes]],
) -> Generator[None, None, bytes]:
read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
) -> Generator[None, None, bytes | bytearray]:
"""
Parse a single line.

Expand All @@ -326,7 +327,7 @@ def parse_line(


def parse_headers(
read_line: Callable[[int], Generator[None, None, bytes]],
read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
) -> Generator[None, None, Headers]:
"""
Parse HTTP headers.
Expand Down Expand Up @@ -379,10 +380,10 @@ def parse_headers(
def read_body(
status_code: int,
headers: Headers,
read_line: Callable[[int], Generator[None, None, bytes]],
read_exact: Callable[[int], Generator[None, None, bytes]],
read_to_eof: Callable[[int], Generator[None, None, bytes]],
) -> Generator[None, None, bytes]:
read_line: Callable[[int], Generator[None, None, bytes | bytearray]],
read_exact: Callable[[int], Generator[None, None, bytes | bytearray]],
read_to_eof: Callable[[int], Generator[None, None, bytes | bytearray]],
) -> Generator[None, None, bytes | bytearray]:
# https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3

# Since websockets only does GET requests (no HEAD, no CONNECT), all
Expand Down
9 changes: 4 additions & 5 deletions src/websockets/legacy/framing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from .. import extensions, frames
from ..exceptions import PayloadTooBig, ProtocolError
from ..frames import BytesLike
from ..typing import Data
from ..typing import BytesLike, DataLike


try:
Expand All @@ -19,7 +18,7 @@
class Frame(NamedTuple):
fin: bool
opcode: frames.Opcode
data: bytes
data: BytesLike
rsv1: bool = False
rsv2: bool = False
rsv3: bool = False
Expand Down Expand Up @@ -147,7 +146,7 @@ def write(
write(self.new_frame.serialize(mask=mask, extensions=extensions))


def prepare_data(data: Data) -> tuple[int, bytes]:
def prepare_data(data: DataLike) -> tuple[int, BytesLike]:
"""
Convert a string or byte-like object to an opcode and a bytes-like object.

Expand All @@ -171,7 +170,7 @@ def prepare_data(data: Data) -> tuple[int, bytes]:
raise TypeError("data must be str or bytes-like")


def prepare_ctrl(data: Data) -> bytes:
def prepare_ctrl(data: DataLike) -> bytes:
"""
Convert a string or byte-like object to bytes.

Expand Down
Loading