diff --git a/docs/reference/types.rst b/docs/reference/types.rst index d249b9294..f39e1c2b7 100644 --- a/docs/reference/types.rst +++ b/docs/reference/types.rst @@ -5,6 +5,10 @@ Types .. autodata:: Data +.. autodata:: BytesLike + +.. autodata:: DataLike + .. autodata:: LoggerLike .. autodata:: StatusLike diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 6acced0c1..a13757075 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -19,10 +19,10 @@ ConnectionClosedOK, ProtocolError, ) -from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode +from ..frames import DATA_OPCODES, CloseCode, Frame, Opcode from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State -from ..typing import Data, LoggerLike, Subprotocol +from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol from .compatibility import ( TimeoutError, aiter, @@ -402,7 +402,7 @@ async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data async def send( self, - message: Data | Iterable[Data] | AsyncIterable[Data], + message: DataLike | Iterable[DataLike] | AsyncIterable[DataLike], text: bool | None = None, ) -> None: """ @@ -657,7 +657,7 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.connection_lost_waiter) - async def ping(self, data: Data | None = None) -> Awaitable[float]: + async def ping(self, data: DataLike | None = None) -> Awaitable[float]: """ Send a Ping_. @@ -710,7 +710,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: self.protocol.send_ping(data) return pong_waiter - async def pong(self, data: Data = b"") -> None: + async def pong(self, data: DataLike = b"") -> None: """ Send a Pong_. @@ -1134,7 +1134,7 @@ def eof_received(self) -> None: def broadcast( connections: Iterable[Connection], - message: Data, + message: DataLike, raise_exceptions: bool = False, ) -> None: """ diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 1fd41811c..af26d5d7a 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -179,6 +179,7 @@ async def get(self, decode: bool | None = None) -> Data: finally: self.get_in_progress = False + # This converts frame.data to bytes when it's a bytearray. data = b"".join(frame.data for frame in frames) if decode: return data.decode() @@ -243,7 +244,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: decoder = UTF8Decoder() yield decoder.decode(frame.data, frame.fin) else: - yield frame.data + # Convert to bytes when frame.data is a bytearray. + yield bytes(frame.data) # Following frames, for fragmented messages while not frame.fin: @@ -257,7 +259,8 @@ async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: if decode: yield decoder.decode(frame.data, frame.fin) else: - yield frame.data + # Convert to bytes when frame.data is a bytearray. + yield bytes(frame.data) self.get_in_progress = False diff --git a/src/websockets/extensions/permessage_deflate.py b/src/websockets/extensions/permessage_deflate.py index 7e9e7a5dd..2bc63d799 100644 --- a/src/websockets/extensions/permessage_deflate.py +++ b/src/websockets/extensions/permessage_deflate.py @@ -13,7 +13,7 @@ PayloadTooBig, ProtocolError, ) -from ..typing import ExtensionName, ExtensionParameter +from ..typing import BytesLike, ExtensionName, ExtensionParameter from .base import ClientExtensionFactory, Extension, ServerExtensionFactory @@ -129,6 +129,7 @@ def decode( # Uncompress data. Protect against zip bombs by preventing zlib from # decompressing more than max_length bytes (except when the limit is # disabled with max_size = None). + data: BytesLike if frame.fin and len(frame.data) < 2044: # Profiling shows that appending four bytes, which makes a copy, is # faster than calling decompress() again when data is less than 2kB. @@ -182,6 +183,7 @@ def encode(self, frame: frames.Frame) -> frames.Frame: ) # Compress data. + data: BytesLike data = self.encoder.compress(frame.data) + self.encoder.flush(zlib.Z_SYNC_FLUSH) if frame.fin: # Sync flush generates between 5 or 6 bytes, ending with the bytes diff --git a/src/websockets/frames.py b/src/websockets/frames.py index 3a7077b66..89c48f390 100644 --- a/src/websockets/frames.py +++ b/src/websockets/frames.py @@ -10,6 +10,7 @@ from typing import Callable from .exceptions import PayloadTooBig, ProtocolError +from .typing import BytesLike try: @@ -118,9 +119,6 @@ class CloseCode(enum.IntEnum): } -BytesLike = bytes, bytearray, memoryview - - @dataclasses.dataclass class Frame: """ @@ -140,7 +138,7 @@ class Frame: """ opcode: Opcode - data: bytes | bytearray | memoryview + data: BytesLike fin: bool = True rsv1: bool = False rsv2: bool = False @@ -202,7 +200,7 @@ def __str__(self) -> str: @classmethod def parse( cls, - read_exact: Callable[[int], Generator[None, None, bytes]], + read_exact: Callable[[int], Generator[None, None, bytes | bytearray]], *, mask: bool, max_size: int | None = None, @@ -324,6 +322,7 @@ def serialize( output.write(mask_bytes) # Prepare the data. + data: BytesLike if mask: data = apply_mask(self.data, mask_bytes) else: @@ -383,7 +382,7 @@ def __str__(self) -> str: return result @classmethod - def parse(cls, data: bytes) -> Close: + def parse(cls, data: BytesLike) -> Close: """ Parse the payload of a close frame. @@ -395,6 +394,8 @@ def parse(cls, data: bytes) -> Close: UnicodeDecodeError: If the reason isn't valid UTF-8. """ + if isinstance(data, memoryview): + raise AssertionError("only compressed outgoing frames use memoryview") if len(data) >= 2: (code,) = struct.unpack("!H", data[:2]) reason = data[2:].decode() diff --git a/src/websockets/http11.py b/src/websockets/http11.py index 290ef087e..5af73eb0c 100644 --- a/src/websockets/http11.py +++ b/src/websockets/http11.py @@ -47,7 +47,7 @@ MAX_BODY_SIZE = int(os.environ.get("WEBSOCKETS_MAX_BODY_SIZE", "1_048_576")) # 1 MiB -def d(value: bytes) -> str: +def d(value: bytes | bytearray) -> str: """ Decode a bytestring for interpolating into an error message. @@ -102,7 +102,7 @@ def exception(self) -> Exception | None: # pragma: no cover @classmethod def parse( cls, - read_line: Callable[[int], Generator[None, None, bytes]], + read_line: Callable[[int], Generator[None, None, bytes | bytearray]], ) -> Generator[None, None, Request]: """ Parse a WebSocket handshake request. @@ -194,7 +194,7 @@ class Response: status_code: int reason_phrase: str headers: Headers - body: bytes = b"" + body: bytes | bytearray = b"" _exception: Exception | None = None @@ -210,9 +210,9 @@ def exception(self) -> Exception | None: # pragma: no cover @classmethod def parse( cls, - read_line: Callable[[int], Generator[None, None, bytes]], - read_exact: Callable[[int], Generator[None, None, bytes]], - read_to_eof: Callable[[int], Generator[None, None, bytes]], + read_line: Callable[[int], Generator[None, None, bytes | bytearray]], + read_exact: Callable[[int], Generator[None, None, bytes | bytearray]], + read_to_eof: Callable[[int], Generator[None, None, bytes | bytearray]], proxy: bool = False, ) -> Generator[None, None, Response]: """ @@ -276,6 +276,7 @@ def parse( headers = yield from parse_headers(read_line) + body: bytes | bytearray if proxy: body = b"" else: @@ -299,8 +300,8 @@ def serialize(self) -> bytes: def parse_line( - read_line: Callable[[int], Generator[None, None, bytes]], -) -> Generator[None, None, bytes]: + read_line: Callable[[int], Generator[None, None, bytes | bytearray]], +) -> Generator[None, None, bytes | bytearray]: """ Parse a single line. @@ -326,7 +327,7 @@ def parse_line( def parse_headers( - read_line: Callable[[int], Generator[None, None, bytes]], + read_line: Callable[[int], Generator[None, None, bytes | bytearray]], ) -> Generator[None, None, Headers]: """ Parse HTTP headers. @@ -379,10 +380,10 @@ def parse_headers( def read_body( status_code: int, headers: Headers, - read_line: Callable[[int], Generator[None, None, bytes]], - read_exact: Callable[[int], Generator[None, None, bytes]], - read_to_eof: Callable[[int], Generator[None, None, bytes]], -) -> Generator[None, None, bytes]: + read_line: Callable[[int], Generator[None, None, bytes | bytearray]], + read_exact: Callable[[int], Generator[None, None, bytes | bytearray]], + read_to_eof: Callable[[int], Generator[None, None, bytes | bytearray]], +) -> Generator[None, None, bytes | bytearray]: # https://datatracker.ietf.org/doc/html/rfc7230#section-3.3.3 # Since websockets only does GET requests (no HEAD, no CONNECT), all diff --git a/src/websockets/legacy/framing.py b/src/websockets/legacy/framing.py index add0c6e0e..452d2fb34 100644 --- a/src/websockets/legacy/framing.py +++ b/src/websockets/legacy/framing.py @@ -6,8 +6,7 @@ from .. import extensions, frames from ..exceptions import PayloadTooBig, ProtocolError -from ..frames import BytesLike -from ..typing import Data +from ..typing import BytesLike, DataLike try: @@ -19,7 +18,7 @@ class Frame(NamedTuple): fin: bool opcode: frames.Opcode - data: bytes + data: BytesLike rsv1: bool = False rsv2: bool = False rsv3: bool = False @@ -147,7 +146,7 @@ def write( write(self.new_frame.serialize(mask=mask, extensions=extensions)) -def prepare_data(data: Data) -> tuple[int, bytes]: +def prepare_data(data: DataLike) -> tuple[int, BytesLike]: """ Convert a string or byte-like object to an opcode and a bytes-like object. @@ -171,7 +170,7 @@ def prepare_data(data: Data) -> tuple[int, bytes]: raise TypeError("data must be str or bytes-like") -def prepare_ctrl(data: Data) -> bytes: +def prepare_ctrl(data: DataLike) -> bytes: """ Convert a string or byte-like object to bytes. diff --git a/src/websockets/legacy/protocol.py b/src/websockets/legacy/protocol.py index 1d604677a..03de30124 100644 --- a/src/websockets/legacy/protocol.py +++ b/src/websockets/legacy/protocol.py @@ -39,7 +39,7 @@ Opcode, ) from ..protocol import State -from ..typing import Data, LoggerLike, Subprotocol +from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol from .framing import Frame, prepare_ctrl, prepare_data @@ -563,7 +563,7 @@ async def recv(self) -> Data: async def send( self, - message: Data | Iterable[Data] | AsyncIterable[Data], + message: DataLike | Iterable[DataLike] | AsyncIterable[DataLike], ) -> None: """ Send a message. @@ -638,7 +638,7 @@ async def send( elif isinstance(message, Iterable): # Work around https://github.com/python/mypy/issues/6227 - message = cast(Iterable[Data], message) + message = cast(Iterable[DataLike], message) iter_message = iter(message) try: @@ -678,14 +678,14 @@ async def send( # Implement aiter_message = aiter(message) without aiter # Work around https://github.com/python/mypy/issues/5738 aiter_message = cast( - Callable[[AsyncIterable[Data]], AsyncIterator[Data]], + Callable[[AsyncIterable[DataLike]], AsyncIterator[DataLike]], type(message).__aiter__, )(message) try: # Implement fragment = anext(aiter_message) without anext # Work around https://github.com/python/mypy/issues/5738 fragment = await cast( - Callable[[AsyncIterator[Data]], Awaitable[Data]], + Callable[[AsyncIterator[DataLike]], Awaitable[DataLike]], type(aiter_message).__anext__, )(aiter_message) except StopAsyncIteration: @@ -788,7 +788,7 @@ async def wait_closed(self) -> None: """ await asyncio.shield(self.connection_lost_waiter) - async def ping(self, data: Data | None = None) -> Awaitable[float]: + async def ping(self, data: DataLike | None = None) -> Awaitable[float]: """ Send a Ping_. @@ -847,7 +847,7 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: return asyncio.shield(pong_waiter) - async def pong(self, data: Data = b"") -> None: + async def pong(self, data: DataLike = b"") -> None: """ Send a Pong_. @@ -1025,10 +1025,12 @@ async def read_message(self) -> Data | None: # Shortcut for the common case - no fragmentation if frame.fin: - return frame.data.decode() if text else frame.data + if isinstance(frame.data, memoryview): + raise AssertionError("only compressed outgoing frames use memoryview") + return frame.data.decode() if text else bytes(frame.data) # 5.4. Fragmentation - fragments: list[Data] = [] + fragments: list[DataLike] = [] max_size = self.max_size if text: decoder_factory = codecs.getincrementaldecoder("utf-8") @@ -1152,7 +1154,7 @@ async def read_frame(self, max_size: int | None) -> Frame: self.logger.debug("< %s", frame) return frame - def write_frame_sync(self, fin: bool, opcode: int, data: bytes) -> None: + def write_frame_sync(self, fin: bool, opcode: int, data: BytesLike) -> None: frame = Frame(fin, Opcode(opcode), data) if self.debug: self.logger.debug("> %s", frame) @@ -1174,7 +1176,7 @@ async def drain(self) -> None: await self.ensure_open() async def write_frame( - self, fin: bool, opcode: int, data: bytes, *, _state: int = State.OPEN + self, fin: bool, opcode: int, data: BytesLike, *, _state: int = State.OPEN ) -> None: # Defensive assertion for protocol compliance. if self.state is not _state: # pragma: no cover @@ -1184,7 +1186,9 @@ async def write_frame( self.write_frame_sync(fin, opcode, data) await self.drain() - async def write_close_frame(self, close: Close, data: bytes | None = None) -> None: + async def write_close_frame( + self, close: Close, data: BytesLike | None = None + ) -> None: """ Write a close frame if and only if the connection state is OPEN. @@ -1538,7 +1542,7 @@ def eof_received(self) -> None: def broadcast( websockets: Iterable[WebSocketCommonProtocol], - message: Data, + message: DataLike, raise_exceptions: bool = False, ) -> None: """ diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 4054941b9..bf05eb546 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -28,7 +28,7 @@ ) from .http11 import Request, Response from .streams import StreamReader -from .typing import LoggerLike, Origin, Subprotocol +from .typing import BytesLike, LoggerLike, Origin, Subprotocol __all__ = [ @@ -291,7 +291,7 @@ def receive_eof(self) -> None: # Public methods for sending events. - def send_continuation(self, data: bytes, fin: bool) -> None: + def send_continuation(self, data: BytesLike, fin: bool) -> None: """ Send a `Continuation frame`_. @@ -315,7 +315,7 @@ def send_continuation(self, data: bytes, fin: bool) -> None: self.expect_continuation_frame = not fin self.send_frame(Frame(OP_CONT, data, fin)) - def send_text(self, data: bytes, fin: bool = True) -> None: + def send_text(self, data: BytesLike, fin: bool = True) -> None: """ Send a `Text frame`_. @@ -338,7 +338,7 @@ def send_text(self, data: bytes, fin: bool = True) -> None: self.expect_continuation_frame = not fin self.send_frame(Frame(OP_TEXT, data, fin)) - def send_binary(self, data: bytes, fin: bool = True) -> None: + def send_binary(self, data: BytesLike, fin: bool = True) -> None: """ Send a `Binary frame`_. @@ -397,7 +397,7 @@ def send_close(self, code: int | None = None, reason: str = "") -> None: self.close_sent = close self.state = CLOSING - def send_ping(self, data: bytes) -> None: + def send_ping(self, data: BytesLike) -> None: """ Send a `Ping frame`_. @@ -413,7 +413,7 @@ def send_ping(self, data: bytes) -> None: raise InvalidState(f"connection is {self.state.name.lower()}") self.send_frame(Frame(OP_PING, data)) - def send_pong(self, data: bytes) -> None: + def send_pong(self, data: BytesLike) -> None: """ Send a `Pong frame`_. diff --git a/src/websockets/speedups.pyi b/src/websockets/speedups.pyi index 821438a06..ffd6c3e07 100644 --- a/src/websockets/speedups.pyi +++ b/src/websockets/speedups.pyi @@ -1 +1,3 @@ -def apply_mask(data: bytes, mask: bytes) -> bytes: ... +from .typing import BytesLike + +def apply_mask(data: BytesLike, mask: bytes | bytearray) -> bytes: ... diff --git a/src/websockets/streams.py b/src/websockets/streams.py index f52e6193a..08ff58e77 100644 --- a/src/websockets/streams.py +++ b/src/websockets/streams.py @@ -17,7 +17,7 @@ def __init__(self) -> None: self.buffer = bytearray() self.eof = False - def read_line(self, m: int) -> Generator[None, None, bytes]: + def read_line(self, m: int) -> Generator[None, None, bytearray]: """ Read a LF-terminated line from the stream. @@ -51,7 +51,7 @@ def read_line(self, m: int) -> Generator[None, None, bytes]: del self.buffer[:n] return r - def read_exact(self, n: int) -> Generator[None, None, bytes]: + def read_exact(self, n: int) -> Generator[None, None, bytearray]: """ Read a given number of bytes from the stream. @@ -74,7 +74,7 @@ def read_exact(self, n: int) -> Generator[None, None, bytes]: del self.buffer[:n] return r - def read_to_eof(self, m: int) -> Generator[None, None, bytes]: + def read_to_eof(self, m: int) -> Generator[None, None, bytearray]: """ Read all bytes from the stream. diff --git a/src/websockets/sync/client.py b/src/websockets/sync/client.py index bfaef61eb..98860deee 100644 --- a/src/websockets/sync/client.py +++ b/src/websockets/sync/client.py @@ -16,7 +16,7 @@ from ..http11 import USER_AGENT, Response from ..protocol import CONNECTING, Event from ..streams import StreamReader -from ..typing import LoggerLike, Origin, Subprotocol +from ..typing import BytesLike, LoggerLike, Origin, Subprotocol from ..uri import Proxy, WebSocketURI, get_proxy, parse_proxy, parse_uri from .connection import Connection from .utils import Deadline @@ -633,10 +633,10 @@ def recv(self, buflen: int) -> bytes: except ssl_module.SSLEOFError: return b"" # always ignore ragged EOFs - def send(self, data: bytes) -> int: + def send(self, data: BytesLike) -> int: return self.run_io(self.ssl_object.write, data) - def sendall(self, data: bytes) -> None: + def sendall(self, data: BytesLike) -> None: # adapted from ssl_module.SSLSocket.sendall() count = 0 with memoryview(data) as view, view.cast("B") as byte_view: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index bedbf4def..d8b23a8a0 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -18,10 +18,10 @@ ConnectionClosedOK, ProtocolError, ) -from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode +from ..frames import DATA_OPCODES, CloseCode, Frame, Opcode from ..http11 import Request, Response from ..protocol import CLOSED, OPEN, Event, Protocol, State -from ..typing import Data, LoggerLike, Subprotocol +from ..typing import BytesLike, Data, DataLike, LoggerLike, Subprotocol from .messages import Assembler from .utils import Deadline @@ -410,7 +410,7 @@ def recv_streaming(self, decode: bool | None = None) -> Iterator[Data]: def send( self, - message: Data | Iterable[Data], + message: DataLike | Iterable[DataLike], text: bool | None = None, ) -> None: """ @@ -602,7 +602,7 @@ def close(self, code: int = CloseCode.NORMAL_CLOSURE, reason: str = "") -> None: def ping( self, - data: Data | None = None, + data: DataLike | None = None, ack_on_close: bool = False, ) -> threading.Event: """ @@ -659,7 +659,7 @@ def ping( self.protocol.send_ping(data) return pong_waiter - def pong(self, data: Data = b"") -> None: + def pong(self, data: DataLike = b"") -> None: """ Send a Pong_. diff --git a/src/websockets/sync/messages.py b/src/websockets/sync/messages.py index c619e78a1..c4d04bc83 100644 --- a/src/websockets/sync/messages.py +++ b/src/websockets/sync/messages.py @@ -193,6 +193,7 @@ def get(self, timeout: float | None = None, decode: bool | None = None) -> Data: finally: self.get_in_progress = False + # This converts frame.data to bytes when it's a bytearray. data = b"".join(frame.data for frame in frames) if decode: return data.decode() @@ -255,7 +256,8 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: decoder = UTF8Decoder() yield decoder.decode(frame.data, frame.fin) else: - yield frame.data + # Convert to bytes when frame.data is a bytearray. + yield bytes(frame.data) # Following frames, for fragmented messages while not frame.fin: @@ -266,7 +268,8 @@ def get_iter(self, decode: bool | None = None) -> Iterator[Data]: if decode: yield decoder.decode(frame.data, frame.fin) else: - yield frame.data + # Convert to bytes when frame.data is a bytearray. + yield bytes(frame.data) self.get_in_progress = False diff --git a/src/websockets/typing.py b/src/websockets/typing.py index 93636e1c9..69b1a8d37 100644 --- a/src/websockets/typing.py +++ b/src/websockets/typing.py @@ -20,13 +20,18 @@ Data = str | bytes """Types supported in a WebSocket message: -:class:`str` for a Text_ frame, :class:`bytes` for a Binary_. +:class:`str` for a Text_ frame, :class:`bytes` for a Binary_ frame. .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary : https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 """ +BytesLike = bytes | bytearray | memoryview +"""Types accepted where :class:`bytes` is expected.""" + +DataLike = str | bytes | bytearray | memoryview +"""Types accepted where :class:`Data` is expected.""" if TYPE_CHECKING: LoggerLike = logging.Logger | logging.LoggerAdapter[Any] diff --git a/src/websockets/utils.py b/src/websockets/utils.py index 62d2dc177..b2a90e52b 100644 --- a/src/websockets/utils.py +++ b/src/websockets/utils.py @@ -5,6 +5,8 @@ import secrets import sys +from .typing import BytesLike + __all__ = ["accept_key", "apply_mask"] @@ -33,7 +35,7 @@ def accept_key(key: str) -> str: return base64.b64encode(sha1).decode() -def apply_mask(data: bytes, mask: bytes) -> bytes: +def apply_mask(data: BytesLike, mask: bytes | bytearray) -> bytes: """ Apply masking to the data of a WebSocket message. diff --git a/tox.ini b/tox.ini index aa88b0591..ce4572e59 100644 --- a/tox.ini +++ b/tox.ini @@ -46,6 +46,6 @@ deps = commands = mypy --strict src deps = - mypy<1.16.0 + mypy python-socks werkzeug