Skip to content

Commit 1d00bd2

Browse files
[PR #10744/23d3ee06 backport][3.11] Refactor WebSocket reader to avoid frequent realloc when frames are fragmented (#10747)
Co-authored-by: J. Nick Koston <[email protected]>
1 parent 099cc0c commit 1d00bd2

File tree

3 files changed

+72
-57
lines changed

3 files changed

+72
-57
lines changed

CHANGES/10744.misc.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Improved performance of the WebSocket reader with large messages -- by :user:`bdraco`.

aiohttp/_websocket/reader_c.pxd

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,14 @@ cdef class WebSocketReader:
6868
cdef int _opcode
6969
cdef bint _frame_fin
7070
cdef int _frame_opcode
71-
cdef object _frame_payload
72-
cdef unsigned long long _frame_payload_len
71+
cdef list _payload_fragments
72+
cdef Py_ssize_t _frame_payload_len
7373

7474
cdef bytes _tail
7575
cdef bint _has_mask
7676
cdef bytes _frame_mask
77-
cdef unsigned long long _payload_length
78-
cdef unsigned int _payload_length_flag
77+
cdef Py_ssize_t _payload_bytes_to_read
78+
cdef unsigned int _payload_len_flag
7979
cdef int _compressed
8080
cdef object _decompressobj
8181
cdef bint _compress
@@ -91,17 +91,20 @@ cdef class WebSocketReader:
9191
cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except *
9292

9393
@cython.locals(
94-
start_pos="unsigned int",
95-
data_len="unsigned int",
96-
length="unsigned int",
97-
chunk_size="unsigned int",
98-
chunk_len="unsigned int",
99-
data_length="unsigned int",
94+
start_pos=Py_ssize_t,
95+
data_len=Py_ssize_t,
96+
length=Py_ssize_t,
97+
chunk_size=Py_ssize_t,
98+
chunk_len=Py_ssize_t,
99+
data_len=Py_ssize_t,
100100
data_cstr="const unsigned char *",
101101
first_byte="unsigned char",
102102
second_byte="unsigned char",
103-
end_pos="unsigned int",
103+
f_start_pos=Py_ssize_t,
104+
f_end_pos=Py_ssize_t,
104105
has_mask=bint,
105106
fin=bint,
107+
had_fragments=Py_ssize_t,
108+
payload_bytearray=bytearray,
106109
)
107110
cpdef void _feed_data(self, bytes data) except *

aiohttp/_websocket/reader_py.py

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,14 @@ def __init__(
144144
self._opcode: int = OP_CODE_NOT_SET
145145
self._frame_fin = False
146146
self._frame_opcode: int = OP_CODE_NOT_SET
147-
self._frame_payload: Union[bytes, bytearray] = b""
147+
self._payload_fragments: list[bytes] = []
148148
self._frame_payload_len = 0
149149

150150
self._tail: bytes = b""
151151
self._has_mask = False
152152
self._frame_mask: Optional[bytes] = None
153-
self._payload_length = 0
154-
self._payload_length_flag = 0
153+
self._payload_bytes_to_read = 0
154+
self._payload_len_flag = 0
155155
self._compressed: int = COMPRESSED_NOT_SET
156156
self._decompressobj: Optional[ZLibDecompressor] = None
157157
self._compress = compress
@@ -317,13 +317,13 @@ def _feed_data(self, data: bytes) -> None:
317317
data, self._tail = self._tail + data, b""
318318

319319
start_pos: int = 0
320-
data_length = len(data)
320+
data_len = len(data)
321321
data_cstr = data
322322

323323
while True:
324324
# read header
325325
if self._state == READ_HEADER:
326-
if data_length - start_pos < 2:
326+
if data_len - start_pos < 2:
327327
break
328328
first_byte = data_cstr[start_pos]
329329
second_byte = data_cstr[start_pos + 1]
@@ -382,77 +382,88 @@ def _feed_data(self, data: bytes) -> None:
382382
self._frame_fin = bool(fin)
383383
self._frame_opcode = opcode
384384
self._has_mask = bool(has_mask)
385-
self._payload_length_flag = length
385+
self._payload_len_flag = length
386386
self._state = READ_PAYLOAD_LENGTH
387387

388388
# read payload length
389389
if self._state == READ_PAYLOAD_LENGTH:
390-
length_flag = self._payload_length_flag
391-
if length_flag == 126:
392-
if data_length - start_pos < 2:
390+
len_flag = self._payload_len_flag
391+
if len_flag == 126:
392+
if data_len - start_pos < 2:
393393
break
394394
first_byte = data_cstr[start_pos]
395395
second_byte = data_cstr[start_pos + 1]
396396
start_pos += 2
397-
self._payload_length = first_byte << 8 | second_byte
398-
elif length_flag > 126:
399-
if data_length - start_pos < 8:
397+
self._payload_bytes_to_read = first_byte << 8 | second_byte
398+
elif len_flag > 126:
399+
if data_len - start_pos < 8:
400400
break
401-
self._payload_length = UNPACK_LEN3(data, start_pos)[0]
401+
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
402402
start_pos += 8
403403
else:
404-
self._payload_length = length_flag
404+
self._payload_bytes_to_read = len_flag
405405

406406
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
407407

408408
# read payload mask
409409
if self._state == READ_PAYLOAD_MASK:
410-
if data_length - start_pos < 4:
410+
if data_len - start_pos < 4:
411411
break
412412
self._frame_mask = data_cstr[start_pos : start_pos + 4]
413413
start_pos += 4
414414
self._state = READ_PAYLOAD
415415

416416
if self._state == READ_PAYLOAD:
417-
chunk_len = data_length - start_pos
418-
if self._payload_length >= chunk_len:
419-
end_pos = data_length
420-
self._payload_length -= chunk_len
417+
chunk_len = data_len - start_pos
418+
if self._payload_bytes_to_read >= chunk_len:
419+
f_end_pos = data_len
420+
self._payload_bytes_to_read -= chunk_len
421421
else:
422-
end_pos = start_pos + self._payload_length
423-
self._payload_length = 0
424-
425-
if self._frame_payload_len:
426-
if type(self._frame_payload) is not bytearray:
427-
self._frame_payload = bytearray(self._frame_payload)
428-
self._frame_payload += data_cstr[start_pos:end_pos]
429-
else:
430-
# Fast path for the first frame
431-
self._frame_payload = data_cstr[start_pos:end_pos]
432-
433-
self._frame_payload_len += end_pos - start_pos
434-
start_pos = end_pos
435-
436-
if self._payload_length != 0:
422+
f_end_pos = start_pos + self._payload_bytes_to_read
423+
self._payload_bytes_to_read = 0
424+
425+
had_fragments = self._frame_payload_len
426+
self._frame_payload_len += f_end_pos - start_pos
427+
f_start_pos = start_pos
428+
start_pos = f_end_pos
429+
430+
if self._payload_bytes_to_read != 0:
431+
# If we don't have a complete frame, we need to save the
432+
# data for the next call to feed_data.
433+
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
437434
break
438435

439-
if self._has_mask:
436+
payload: Union[bytes, bytearray]
437+
if had_fragments:
438+
# We have to join the payload fragments get the payload
439+
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
440+
if self._has_mask:
441+
assert self._frame_mask is not None
442+
payload_bytearray = bytearray()
443+
payload_bytearray.join(self._payload_fragments)
444+
websocket_mask(self._frame_mask, payload_bytearray)
445+
payload = payload_bytearray
446+
else:
447+
payload = b"".join(self._payload_fragments)
448+
self._payload_fragments.clear()
449+
elif self._has_mask:
440450
assert self._frame_mask is not None
441-
if type(self._frame_payload) is not bytearray:
442-
self._frame_payload = bytearray(self._frame_payload)
443-
websocket_mask(self._frame_mask, self._frame_payload)
451+
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
452+
if type(payload_bytearray) is not bytearray: # pragma: no branch
453+
# Cython will do the conversion for us
454+
# but we need to do it for Python and we
455+
# will always get here in Python
456+
payload_bytearray = bytearray(payload_bytearray)
457+
websocket_mask(self._frame_mask, payload_bytearray)
458+
payload = payload_bytearray
459+
else:
460+
payload = data_cstr[f_start_pos:f_end_pos]
444461

445462
self._handle_frame(
446-
self._frame_fin,
447-
self._frame_opcode,
448-
self._frame_payload,
449-
self._compressed,
463+
self._frame_fin, self._frame_opcode, payload, self._compressed
450464
)
451-
self._frame_payload = b""
452465
self._frame_payload_len = 0
453466
self._state = READ_HEADER
454467

455468
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
456-
self._tail = (
457-
data_cstr[start_pos:data_length] if start_pos < data_length else b""
458-
)
469+
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""

0 commit comments

Comments
 (0)