diff --git a/aiohttp/_websocket/models.py b/aiohttp/_websocket/models.py index 085fb460cb5..b42f88fd5f3 100644 --- a/aiohttp/_websocket/models.py +++ b/aiohttp/_websocket/models.py @@ -59,6 +59,21 @@ def json( return loads(self.data) +class WSMessageTextBytes(NamedTuple): + """WebSocket TEXT message with raw bytes (no UTF-8 decoding).""" + + data: bytes + size: int + extra: str | None = None + type: Literal[WSMsgType.TEXT] = WSMsgType.TEXT + + def json( + self, *, loads: Callable[[str | bytes | bytearray], Any] = json.loads + ) -> Any: + """Return parsed JSON data.""" + return loads(self.data) + + class WSMessageBinary(NamedTuple): data: bytes size: int @@ -117,6 +132,7 @@ class WSMessageError(NamedTuple): WSMessage = Union[ WSMessageContinuation, WSMessageText, + WSMessageTextBytes, WSMessageBinary, WSMessagePing, WSMessagePong, diff --git a/aiohttp/_websocket/reader_c.pxd b/aiohttp/_websocket/reader_c.pxd index 9a6fdae3e97..7e5e46f13c7 100644 --- a/aiohttp/_websocket/reader_c.pxd +++ b/aiohttp/_websocket/reader_c.pxd @@ -27,6 +27,7 @@ cdef object TUPLE_NEW cdef object WSMsgType cdef object WSMessageText +cdef object WSMessageTextBytes cdef object WSMessageBinary cdef object WSMessagePing cdef object WSMessagePong @@ -66,6 +67,7 @@ cdef class WebSocketReader: cdef WebSocketDataQueue queue cdef unsigned int _max_msg_size + cdef bint _decode_text cdef Exception _exc cdef bytearray _partial diff --git a/aiohttp/_websocket/reader_py.py b/aiohttp/_websocket/reader_py.py index 5bcc2ecfb78..17d6aa1b16d 100644 --- a/aiohttp/_websocket/reader_py.py +++ b/aiohttp/_websocket/reader_py.py @@ -20,6 +20,7 @@ WSMessagePing, WSMessagePong, WSMessageText, + WSMessageTextBytes, WSMsgType, ) @@ -139,10 +140,15 @@ def _read_from_buffer(self) -> WSMessage: class WebSocketReader: def __init__( - self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True + self, + queue: WebSocketDataQueue, + max_msg_size: int, + compress: bool = True, + decode_text: bool = True, ) -> None: self.queue = queue self._max_msg_size = max_msg_size + self._decode_text = decode_text self._exc: Exception | None = None self._partial = bytearray() @@ -270,18 +276,24 @@ def _handle_frame( size = len(payload_merged) if opcode == OP_CODE_TEXT: - try: - text = payload_merged.decode("utf-8") - except UnicodeDecodeError as exc: - raise WebSocketError( - WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" - ) from exc - - # XXX: The Text and Binary messages here can be a performance - # bottleneck, so we use tuple.__new__ to improve performance. - # This is not type safe, but many tests should fail in - # test_client_ws_functional.py if this is wrong. - msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT)) + if self._decode_text: + try: + text = payload_merged.decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + + # XXX: The Text and Binary messages here can be a performance + # bottleneck, so we use tuple.__new__ to improve performance. + # This is not type safe, but many tests should fail in + # test_client_ws_functional.py if this is wrong. + msg = TUPLE_NEW(WSMessageText, (text, size, "", WS_MSG_TYPE_TEXT)) + else: + # Keep as bytes for performance (e.g., for orjson parsing) + msg = TUPLE_NEW( + WSMessageTextBytes, (payload_merged, size, "", WS_MSG_TYPE_TEXT) + ) else: msg = TUPLE_NEW( WSMessageBinary, (payload_merged, size, "", WS_MSG_TYPE_BINARY) diff --git a/aiohttp/client.py b/aiohttp/client.py index 026006023ce..99f9c9994dc 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -888,6 +888,7 @@ def ws_connect( proxy_headers: LooseHeaders | None = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, + decode_text: bool = True, ) -> "_WSRequestContextManager": """Initiate websocket connection.""" return _WSRequestContextManager( @@ -911,6 +912,7 @@ def ws_connect( proxy_headers=proxy_headers, compress=compress, max_msg_size=max_msg_size, + decode_text=decode_text, ) ) @@ -936,6 +938,7 @@ async def _ws_connect( proxy_headers: LooseHeaders | None = None, compress: int = 0, max_msg_size: int = 4 * 1024 * 1024, + decode_text: bool = True, ) -> ClientWebSocketResponse: if timeout is not sentinel: if isinstance(timeout, ClientWSTimeout): @@ -1098,7 +1101,9 @@ async def _ws_connect( transport = conn.transport assert transport is not None reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop) - conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader) + conn_proto.set_parser( + WebSocketReader(reader, max_msg_size, decode_text=decode_text), reader + ) writer = WebSocketWriter( conn_proto, transport, diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index 36959aae0c7..6901083fc22 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -384,12 +384,20 @@ async def receive(self, timeout: float | None = None) -> WSMessage: return msg async def receive_str(self, *, timeout: float | None = None) -> str: + """Receive TEXT message. + + Returns str when decode_text=True (default), bytes when decode_text=False. + + Note: The return type annotation is kept as str for backwards compatibility, + but this method will return bytes when the WebSocket connection was created + with decode_text=False. + """ msg = await self.receive(timeout) if msg.type is not WSMsgType.TEXT: raise WSMessageTypeError( f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT" ) - return msg.data + return msg.data # type: ignore[return-value] async def receive_bytes(self, *, timeout: float | None = None) -> bytes: msg = await self.receive(timeout) diff --git a/aiohttp/http_websocket.py b/aiohttp/http_websocket.py index f49d8aee287..830318c0b9a 100644 --- a/aiohttp/http_websocket.py +++ b/aiohttp/http_websocket.py @@ -17,6 +17,7 @@ WSMessagePing, WSMessagePong, WSMessageText, + WSMessageTextBytes, WSMsgType, ) from ._websocket.reader import WebSocketReader @@ -48,6 +49,7 @@ "WSMessagePong", "WSMessageBinary", "WSMessageText", + "WSMessageTextBytes", "WSMessagePing", "WSMessageContinuation", ) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index 8eee7e3ad71..82a89c049d5 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -95,6 +95,7 @@ def __init__( compress: bool = True, max_msg_size: int = 4 * 1024 * 1024, writer_limit: int = DEFAULT_LIMIT, + decode_text: bool = True, ) -> None: super().__init__(status=101) self._protocols = protocols @@ -108,6 +109,7 @@ def __init__( self._compress: bool | int = compress self._max_msg_size = max_msg_size self._writer_limit = writer_limit + self._decode_text = decode_text def _cancel_heartbeat(self) -> None: self._cancel_pong_response_cb() @@ -341,7 +343,10 @@ def _post_start( self._reader = WebSocketDataQueue(request._protocol, 2**16, loop=loop) request.protocol.set_parser( WebSocketReader( - self._reader, self._max_msg_size, compress=bool(self._compress) + self._reader, + self._max_msg_size, + compress=bool(self._compress), + decode_text=self._decode_text, ) ) # disable HTTP keepalive for WebSocket @@ -589,12 +594,20 @@ async def receive(self, timeout: float | None = None) -> WSMessage: return msg async def receive_str(self, *, timeout: float | None = None) -> str: + """Receive TEXT message. + + Returns str when decode_text=True (default), bytes when decode_text=False. + + Note: The return type annotation is kept as str for backwards compatibility, + but this method will return bytes when the WebSocket connection was created + with decode_text=False. + """ msg = await self.receive(timeout) if msg.type is not WSMsgType.TEXT: raise WSMessageTypeError( f"Received message {msg.type}:{msg.data!r} is not WSMsgType.TEXT" ) - return msg.data + return msg.data # type: ignore[return-value] async def receive_bytes(self, *, timeout: float | None = None) -> bytes: msg = await self.receive(timeout) diff --git a/examples/server_simple.py b/examples/server_simple.py index c602bff426d..b6300daf568 100644 --- a/examples/server_simple.py +++ b/examples/server_simple.py @@ -1,4 +1,6 @@ # server_simple.py +from typing import TYPE_CHECKING + from aiohttp import web @@ -14,7 +16,10 @@ async def wshandle(request: web.Request) -> web.StreamResponse: async for msg in ws: if msg.type is web.WSMsgType.TEXT: - await ws.send_str(f"Hello, {msg.data}") + data = msg.data + if TYPE_CHECKING: + assert isinstance(data, str) + await ws.send_str(f"Hello, {data}") elif msg.type is web.WSMsgType.BINARY: await ws.send_bytes(msg.data) elif msg.type is web.WSMsgType.CLOSE: diff --git a/examples/web_ws.py b/examples/web_ws.py index 4051ae12890..85224bf9e66 100755 --- a/examples/web_ws.py +++ b/examples/web_ws.py @@ -6,6 +6,7 @@ # mypy: disallow-any-expr, disallow-any-unimported, disallow-subclassing-any import os +from typing import TYPE_CHECKING from aiohttp import web @@ -32,9 +33,12 @@ async def wshandler(request: web.Request) -> web.WebSocketResponse | web.Respons async for msg in resp: if msg.type is web.WSMsgType.TEXT: + data = msg.data + if TYPE_CHECKING: + assert isinstance(data, str) for ws in request.app[sockets]: if ws is not resp: - await ws.send_str(msg.data) + await ws.send_str(data) else: return resp return resp diff --git a/tests/autobahn/client/client.py b/tests/autobahn/client/client.py index 5dbcd6ff654..1c230e11c7e 100644 --- a/tests/autobahn/client/client.py +++ b/tests/autobahn/client/client.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import asyncio +from typing import TYPE_CHECKING import aiohttp @@ -19,7 +20,10 @@ async def client(url: str, name: str) -> None: async with session.ws_connect(text_url) as ws: async for msg in ws: if msg.type is aiohttp.WSMsgType.TEXT: - await ws.send_str(msg.data) + data = msg.data + if TYPE_CHECKING: + assert isinstance(data, str) + await ws.send_str(data) elif msg.type is aiohttp.WSMsgType.BINARY: await ws.send_bytes(msg.data) else: diff --git a/tests/autobahn/server/server.py b/tests/autobahn/server/server.py index 47122ae2578..f63c0455d5d 100644 --- a/tests/autobahn/server/server.py +++ b/tests/autobahn/server/server.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import logging +from typing import TYPE_CHECKING from aiohttp import WSCloseCode, web @@ -21,7 +22,10 @@ async def wshandler(request: web.Request) -> web.WebSocketResponse: msg = await ws.receive() if msg.type is web.WSMsgType.TEXT: - await ws.send_str(msg.data) + data = msg.data + if TYPE_CHECKING: + assert isinstance(data, str) + await ws.send_str(data) elif msg.type is web.WSMsgType.BINARY: await ws.send_bytes(msg.data) elif msg.type is web.WSMsgType.CLOSE: diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index 0bc05f300d4..bddceaad0f2 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -1,7 +1,7 @@ import asyncio import json import sys -from typing import NoReturn +from typing import TYPE_CHECKING, NoReturn from unittest import mock import pytest @@ -1080,7 +1080,10 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await ws.prepare(request) msg = await ws.receive() assert msg.type is WSMsgType.TEXT - await ws.send_str(msg.data + "/answer") + data = msg.data + if TYPE_CHECKING: + assert isinstance(data, str) + await ws.send_str(data + "/answer") await ws.close() return ws @@ -1106,7 +1109,10 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await ws.prepare(request) msg = await ws.receive() assert msg.type is WSMsgType.TEXT - await ws.send_str(msg.data + "/answer") + data = msg.data + if TYPE_CHECKING: + assert isinstance(data, str) + await ws.send_str(data + "/answer") await ws.close() return ws @@ -1130,7 +1136,10 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await ws.prepare(request) msg = await ws.receive() assert msg.type is WSMsgType.TEXT - await ws.send_str(msg.data + "/answer") + data = msg.data + if TYPE_CHECKING: + assert isinstance(data, str) + await ws.send_str(data + "/answer") await ws.close() return ws diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index bbab015061f..27efc5229e3 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -3,7 +3,7 @@ import socket import sys from collections.abc import Iterator, Mapping -from typing import NoReturn +from typing import TYPE_CHECKING, NoReturn from unittest import mock import pytest @@ -45,7 +45,10 @@ async def websocket_handler(request: web.Request) -> web.WebSocketResponse: if msg.data == "close": await ws.close() else: - await ws.send_str(msg.data + "/answer") + data = msg.data + if TYPE_CHECKING: + assert isinstance(data, str) + await ws.send_str(data + "/answer") return ws diff --git a/tests/test_web_websocket_functional.py b/tests/test_web_websocket_functional.py index afa76e2d742..42eb748e6f4 100644 --- a/tests/test_web_websocket_functional.py +++ b/tests/test_web_websocket_functional.py @@ -4,7 +4,7 @@ import contextlib import sys import weakref -from typing import NoReturn +from typing import TYPE_CHECKING, NoReturn from unittest import mock import pytest @@ -913,6 +913,8 @@ async def handler(request: web.Request) -> web.WebSocketResponse: async for msg in ws: assert msg.type == aiohttp.WSMsgType.TEXT s = msg.data + if TYPE_CHECKING: + assert isinstance(s, str) await ws.send_str(s + "/answer") await ws.close() closed.set_result(1)