Skip to content

Commit 996c929

Browse files
authored
[PR #9685/5241897 backport][3.11] Fix WebSocket reader flow control calculations (#9793)
Co-authored-by: pre-commit-ci[bot]
1 parent bd3a3be commit 996c929

18 files changed

+318
-207
lines changed

CHANGES/9685.breaking.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
``FlowControlDataQueue`` has been replaced with the ``WebSocketDataQueue`` -- by :user:`bdraco`.

aiohttp/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@
9393
EMPTY_PAYLOAD as EMPTY_PAYLOAD,
9494
DataQueue as DataQueue,
9595
EofStream as EofStream,
96-
FlowControlDataQueue as FlowControlDataQueue,
9796
StreamReader as StreamReader,
9897
)
9998
from .tracing import (
@@ -216,7 +215,6 @@
216215
"DataQueue",
217216
"EMPTY_PAYLOAD",
218217
"EofStream",
219-
"FlowControlDataQueue",
220218
"StreamReader",
221219
# tracing
222220
"TraceConfig",

aiohttp/_websocket/reader.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,27 @@
55
from ..helpers import NO_EXTENSIONS
66

77
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
8-
from .reader_py import WebSocketReader as WebSocketReaderPython
8+
from .reader_py import (
9+
WebSocketDataQueue as WebSocketDataQueuePython,
10+
WebSocketReader as WebSocketReaderPython,
11+
)
912

1013
WebSocketReader = WebSocketReaderPython
14+
WebSocketDataQueue = WebSocketDataQueuePython
1115
else:
1216
try:
1317
from .reader_c import ( # type: ignore[import-not-found]
18+
WebSocketDataQueue as WebSocketDataQueueCython,
1419
WebSocketReader as WebSocketReaderCython,
1520
)
1621

1722
WebSocketReader = WebSocketReaderCython
23+
WebSocketDataQueue = WebSocketDataQueueCython
1824
except ImportError: # pragma: no cover
19-
from .reader_py import WebSocketReader as WebSocketReaderPython
25+
from .reader_py import (
26+
WebSocketDataQueue as WebSocketDataQueuePython,
27+
WebSocketReader as WebSocketReaderPython,
28+
)
2029

2130
WebSocketReader = WebSocketReaderPython
31+
WebSocketDataQueue = WebSocketDataQueuePython

aiohttp/_websocket/reader_c.pxd

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,29 @@ cdef set MESSAGE_TYPES_WITH_CONTENT
3131
cdef tuple EMPTY_FRAME
3232
cdef tuple EMPTY_FRAME_ERROR
3333

34+
cdef class WebSocketDataQueue:
35+
36+
cdef unsigned int _size
37+
cdef public object _protocol
38+
cdef unsigned int _limit
39+
cdef object _loop
40+
cdef bint _eof
41+
cdef object _waiter
42+
cdef object _exception
43+
cdef public object _buffer
44+
cdef object _get_buffer
45+
cdef object _put_buffer
46+
47+
cdef void _release_waiter(self)
48+
49+
cpdef void feed_data(self, object data, unsigned int size)
50+
51+
@cython.locals(size="unsigned int")
52+
cdef _read_from_buffer(self)
3453

3554
cdef class WebSocketReader:
3655

37-
cdef object queue
38-
cdef object _queue_feed_data
56+
cdef WebSocketDataQueue queue
3957
cdef unsigned int _max_msg_size
4058

4159
cdef Exception _exc

aiohttp/_websocket/reader_py.py

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
"""Reader for WebSocket protocol versions 13 and 8."""
22

3-
from typing import Final, List, Optional, Set, Tuple, Union
3+
import asyncio
4+
import builtins
5+
from collections import deque
6+
from typing import Deque, Final, List, Optional, Set, Tuple, Union
47

8+
from ..base_protocol import BaseProtocol
59
from ..compression_utils import ZLibDecompressor
6-
from ..helpers import set_exception
7-
from ..streams import FlowControlDataQueue
10+
from ..helpers import _EXC_SENTINEL, set_exception
11+
from ..streams import EofStream
812
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
913
from .models import (
1014
WS_DEFLATE_TRAILING,
@@ -39,16 +43,89 @@
3943

4044
TUPLE_NEW = tuple.__new__
4145

46+
int_ = int # Prevent Cython from converting to PyInt
47+
48+
49+
class WebSocketDataQueue:
50+
"""WebSocketDataQueue resumes and pauses an underlying stream.
51+
52+
It is a destination for WebSocket data.
53+
"""
4254

43-
class WebSocketReader:
4455
def __init__(
56+
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
57+
) -> None:
58+
self._size = 0
59+
self._protocol = protocol
60+
self._limit = limit * 2
61+
self._loop = loop
62+
self._eof = False
63+
self._waiter: Optional[asyncio.Future[None]] = None
64+
self._exception: Union[BaseException, None] = None
65+
self._buffer: Deque[Tuple[WSMessage, int]] = deque()
66+
self._get_buffer = self._buffer.popleft
67+
self._put_buffer = self._buffer.append
68+
69+
def exception(self) -> Optional[BaseException]:
70+
return self._exception
71+
72+
def set_exception(
4573
self,
46-
queue: FlowControlDataQueue[WSMessage],
47-
max_msg_size: int,
48-
compress: bool = True,
74+
exc: "BaseException",
75+
exc_cause: builtins.BaseException = _EXC_SENTINEL,
76+
) -> None:
77+
self._eof = True
78+
self._exception = exc
79+
if (waiter := self._waiter) is not None:
80+
self._waiter = None
81+
set_exception(waiter, exc, exc_cause)
82+
83+
def _release_waiter(self) -> None:
84+
if (waiter := self._waiter) is None:
85+
return
86+
self._waiter = None
87+
if not waiter.done():
88+
waiter.set_result(None)
89+
90+
def feed_eof(self) -> None:
91+
self._eof = True
92+
self._release_waiter()
93+
94+
def feed_data(self, data: "WSMessage", size: "int_") -> None:
95+
self._size += size
96+
self._put_buffer((data, size))
97+
self._release_waiter()
98+
if self._size > self._limit and not self._protocol._reading_paused:
99+
self._protocol.pause_reading()
100+
101+
async def read(self) -> WSMessage:
102+
if not self._buffer and not self._eof:
103+
assert not self._waiter
104+
self._waiter = self._loop.create_future()
105+
try:
106+
await self._waiter
107+
except (asyncio.CancelledError, asyncio.TimeoutError):
108+
self._waiter = None
109+
raise
110+
return self._read_from_buffer()
111+
112+
def _read_from_buffer(self) -> WSMessage:
113+
if self._buffer:
114+
data, size = self._get_buffer()
115+
self._size -= size
116+
if self._size < self._limit and self._protocol._reading_paused:
117+
self._protocol.resume_reading()
118+
return data
119+
if self._exception is not None:
120+
raise self._exception
121+
raise EofStream
122+
123+
124+
class WebSocketReader:
125+
def __init__(
126+
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
49127
) -> None:
50128
self.queue = queue
51-
self._queue_feed_data = queue.feed_data
52129
self._max_msg_size = max_msg_size
53130

54131
self._exc: Optional[Exception] = None
@@ -187,12 +264,12 @@ def _feed_data(self, data: bytes) -> None:
187264
# bottleneck, so we use tuple.__new__ to improve performance.
188265
# This is not type safe, but many tests should fail in
189266
# test_client_ws_functional.py if this is wrong.
190-
self._queue_feed_data(
267+
self.queue.feed_data(
191268
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
192269
len(payload_merged),
193270
)
194271
else:
195-
self._queue_feed_data(
272+
self.queue.feed_data(
196273
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
197274
len(payload_merged),
198275
)
@@ -221,14 +298,14 @@ def _feed_data(self, data: bytes) -> None:
221298
else:
222299
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
223300

224-
self._queue_feed_data(msg, 0)
301+
self.queue.feed_data(msg, 0)
225302
elif opcode == OP_CODE_PING:
226303
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
227-
self._queue_feed_data(msg, len(payload))
304+
self.queue.feed_data(msg, len(payload))
228305

229306
elif opcode == OP_CODE_PONG:
230307
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
231-
self._queue_feed_data(msg, len(payload))
308+
self.queue.feed_data(msg, len(payload))
232309

233310
else:
234311
raise WebSocketError(

aiohttp/client.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from yarl import URL
3838

3939
from . import hdrs, http, payload
40+
from ._websocket.reader import WebSocketDataQueue
4041
from .abc import AbstractCookieJar
4142
from .client_exceptions import (
4243
ClientConnectionError,
@@ -100,8 +101,7 @@
100101
strip_auth_from_url,
101102
)
102103
from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter
103-
from .http_websocket import WSHandshakeError, WSMessage, ws_ext_gen, ws_ext_parse
104-
from .streams import FlowControlDataQueue
104+
from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse
105105
from .tracing import Trace, TraceConfig
106106
from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, Query, StrOrURL
107107

@@ -1098,9 +1098,7 @@ async def _ws_connect(
10981098

10991099
transport = conn.transport
11001100
assert transport is not None
1101-
reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue(
1102-
conn_proto, 2**16, loop=self._loop
1103-
)
1101+
reader = WebSocketDataQueue(conn_proto, 2**16, loop=self._loop)
11041102
conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
11051103
writer = WebSocketWriter(
11061104
conn_proto,

aiohttp/client_proto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def set_exception(
151151
def set_parser(self, parser: Any, payload: Any) -> None:
152152
# TODO: actual types are:
153153
# parser: WebSocketReader
154-
# payload: FlowControlDataQueue
154+
# payload: WebSocketDataQueue
155155
# but they are not generi enough
156156
# Need an ABC for both types
157157
self._payload = payload

aiohttp/client_ws.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import attr
99

10+
from ._websocket.reader import WebSocketDataQueue
1011
from .client_exceptions import ClientError, ServerTimeoutError, WSMessageTypeError
1112
from .client_reqrep import ClientResponse
1213
from .helpers import calculate_timeout_when, set_result
@@ -19,7 +20,7 @@
1920
WSMsgType,
2021
)
2122
from .http_websocket import _INTERNAL_RECEIVE_TYPES, WebSocketWriter
22-
from .streams import EofStream, FlowControlDataQueue
23+
from .streams import EofStream
2324
from .typedefs import (
2425
DEFAULT_JSON_DECODER,
2526
DEFAULT_JSON_ENCODER,
@@ -45,7 +46,7 @@ class ClientWSTimeout:
4546
class ClientWebSocketResponse:
4647
def __init__(
4748
self,
48-
reader: "FlowControlDataQueue[WSMessage]",
49+
reader: WebSocketDataQueue,
4950
writer: WebSocketWriter,
5051
protocol: Optional[str],
5152
response: ClientResponse,
@@ -173,7 +174,7 @@ def _handle_ping_pong_exception(self, exc: BaseException) -> None:
173174
self._exception = exc
174175
self._response.close()
175176
if self._waiting and not self._closing:
176-
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None))
177+
self._reader.feed_data(WSMessage(WSMsgType.ERROR, exc, None), 0)
177178

178179
def _set_closed(self) -> None:
179180
"""Set the connection to closed.

aiohttp/streams.py

Lines changed: 7 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
"EofStream",
2929
"StreamReader",
3030
"DataQueue",
31-
"FlowControlDataQueue",
3231
)
3332

3433
_T = TypeVar("_T")
@@ -660,18 +659,15 @@ def feed_eof(self) -> None:
660659
self._waiter = None
661660
set_result(waiter, None)
662661

663-
async def _wait_for_data(self) -> None:
664-
assert not self._waiter
665-
self._waiter = self._loop.create_future()
666-
try:
667-
await self._waiter
668-
except (asyncio.CancelledError, asyncio.TimeoutError):
669-
self._waiter = None
670-
raise
671-
672662
async def read(self) -> _T:
673663
if not self._buffer and not self._eof:
674-
await self._wait_for_data()
664+
assert not self._waiter
665+
self._waiter = self._loop.create_future()
666+
try:
667+
await self._waiter
668+
except (asyncio.CancelledError, asyncio.TimeoutError):
669+
self._waiter = None
670+
raise
675671
if self._buffer:
676672
data, _ = self._buffer.popleft()
677673
return data
@@ -681,41 +677,3 @@ async def read(self) -> _T:
681677

682678
def __aiter__(self) -> AsyncStreamIterator[_T]:
683679
return AsyncStreamIterator(self.read)
684-
685-
686-
class FlowControlDataQueue(DataQueue[_T]):
687-
"""FlowControlDataQueue resumes and pauses an underlying stream.
688-
689-
It is a destination for parsed data.
690-
"""
691-
692-
def __init__(
693-
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
694-
) -> None:
695-
super().__init__(loop=loop)
696-
self._size = 0
697-
self._protocol = protocol
698-
self._limit = limit * 2
699-
self._buffer: Deque[Tuple[_T, int]] = collections.deque()
700-
701-
def feed_data(self, data: _T, size: int = 0) -> None:
702-
self._size += size
703-
self._buffer.append((data, size))
704-
if (waiter := self._waiter) is not None:
705-
self._waiter = None
706-
set_result(waiter, None)
707-
if self._size > self._limit and not self._protocol._reading_paused:
708-
self._protocol.pause_reading()
709-
710-
async def read(self) -> _T:
711-
if not self._buffer and not self._eof:
712-
await self._wait_for_data()
713-
if self._buffer:
714-
data, size = self._buffer.popleft()
715-
self._size -= size
716-
if self._size < self._limit and self._protocol._reading_paused:
717-
self._protocol.resume_reading()
718-
return data
719-
if self._exception is not None:
720-
raise self._exception
721-
raise EofStream

0 commit comments

Comments
 (0)