@@ -144,14 +144,14 @@ def __init__(
144
144
self ._opcode : int = OP_CODE_NOT_SET
145
145
self ._frame_fin = False
146
146
self ._frame_opcode : int = OP_CODE_NOT_SET
147
- self ._frame_payload : Union [bytes , bytearray ] = b""
147
+ self ._payload_fragments : list [bytes ] = []
148
148
self ._frame_payload_len = 0
149
149
150
150
self ._tail : bytes = b""
151
151
self ._has_mask = False
152
152
self ._frame_mask : Optional [bytes ] = None
153
- self ._payload_length = 0
154
- self ._payload_length_flag = 0
153
+ self ._payload_bytes_to_read = 0
154
+ self ._payload_len_flag = 0
155
155
self ._compressed : int = COMPRESSED_NOT_SET
156
156
self ._decompressobj : Optional [ZLibDecompressor ] = None
157
157
self ._compress = compress
@@ -317,13 +317,13 @@ def _feed_data(self, data: bytes) -> None:
317
317
data , self ._tail = self ._tail + data , b""
318
318
319
319
start_pos : int = 0
320
- data_length = len (data )
320
+ data_len = len (data )
321
321
data_cstr = data
322
322
323
323
while True :
324
324
# read header
325
325
if self ._state == READ_HEADER :
326
- if data_length - start_pos < 2 :
326
+ if data_len - start_pos < 2 :
327
327
break
328
328
first_byte = data_cstr [start_pos ]
329
329
second_byte = data_cstr [start_pos + 1 ]
@@ -382,77 +382,88 @@ def _feed_data(self, data: bytes) -> None:
382
382
self ._frame_fin = bool (fin )
383
383
self ._frame_opcode = opcode
384
384
self ._has_mask = bool (has_mask )
385
- self ._payload_length_flag = length
385
+ self ._payload_len_flag = length
386
386
self ._state = READ_PAYLOAD_LENGTH
387
387
388
388
# read payload length
389
389
if self ._state == READ_PAYLOAD_LENGTH :
390
- length_flag = self ._payload_length_flag
391
- if length_flag == 126 :
392
- if data_length - start_pos < 2 :
390
+ len_flag = self ._payload_len_flag
391
+ if len_flag == 126 :
392
+ if data_len - start_pos < 2 :
393
393
break
394
394
first_byte = data_cstr [start_pos ]
395
395
second_byte = data_cstr [start_pos + 1 ]
396
396
start_pos += 2
397
- self ._payload_length = first_byte << 8 | second_byte
398
- elif length_flag > 126 :
399
- if data_length - start_pos < 8 :
397
+ self ._payload_bytes_to_read = first_byte << 8 | second_byte
398
+ elif len_flag > 126 :
399
+ if data_len - start_pos < 8 :
400
400
break
401
- self ._payload_length = UNPACK_LEN3 (data , start_pos )[0 ]
401
+ self ._payload_bytes_to_read = UNPACK_LEN3 (data , start_pos )[0 ]
402
402
start_pos += 8
403
403
else :
404
- self ._payload_length = length_flag
404
+ self ._payload_bytes_to_read = len_flag
405
405
406
406
self ._state = READ_PAYLOAD_MASK if self ._has_mask else READ_PAYLOAD
407
407
408
408
# read payload mask
409
409
if self ._state == READ_PAYLOAD_MASK :
410
- if data_length - start_pos < 4 :
410
+ if data_len - start_pos < 4 :
411
411
break
412
412
self ._frame_mask = data_cstr [start_pos : start_pos + 4 ]
413
413
start_pos += 4
414
414
self ._state = READ_PAYLOAD
415
415
416
416
if self ._state == READ_PAYLOAD :
417
- chunk_len = data_length - start_pos
418
- if self ._payload_length >= chunk_len :
419
- end_pos = data_length
420
- self ._payload_length -= chunk_len
417
+ chunk_len = data_len - start_pos
418
+ if self ._payload_bytes_to_read >= chunk_len :
419
+ f_end_pos = data_len
420
+ self ._payload_bytes_to_read -= chunk_len
421
421
else :
422
- end_pos = start_pos + self ._payload_length
423
- self ._payload_length = 0
424
-
425
- if self ._frame_payload_len :
426
- if type (self ._frame_payload ) is not bytearray :
427
- self ._frame_payload = bytearray (self ._frame_payload )
428
- self ._frame_payload += data_cstr [start_pos :end_pos ]
429
- else :
430
- # Fast path for the first frame
431
- self ._frame_payload = data_cstr [start_pos :end_pos ]
432
-
433
- self ._frame_payload_len += end_pos - start_pos
434
- start_pos = end_pos
435
-
436
- if self ._payload_length != 0 :
422
+ f_end_pos = start_pos + self ._payload_bytes_to_read
423
+ self ._payload_bytes_to_read = 0
424
+
425
+ had_fragments = self ._frame_payload_len
426
+ self ._frame_payload_len += f_end_pos - start_pos
427
+ f_start_pos = start_pos
428
+ start_pos = f_end_pos
429
+
430
+ if self ._payload_bytes_to_read != 0 :
431
+ # If we don't have a complete frame, we need to save the
432
+ # data for the next call to feed_data.
433
+ self ._payload_fragments .append (data_cstr [f_start_pos :f_end_pos ])
437
434
break
438
435
439
- if self ._has_mask :
436
+ payload : Union [bytes , bytearray ]
437
+ if had_fragments :
438
+ # We have to join the payload fragments get the payload
439
+ self ._payload_fragments .append (data_cstr [f_start_pos :f_end_pos ])
440
+ if self ._has_mask :
441
+ assert self ._frame_mask is not None
442
+ payload_bytearray = bytearray ()
443
+ payload_bytearray .join (self ._payload_fragments )
444
+ websocket_mask (self ._frame_mask , payload_bytearray )
445
+ payload = payload_bytearray
446
+ else :
447
+ payload = b"" .join (self ._payload_fragments )
448
+ self ._payload_fragments .clear ()
449
+ elif self ._has_mask :
440
450
assert self ._frame_mask is not None
441
- if type (self ._frame_payload ) is not bytearray :
442
- self ._frame_payload = bytearray (self ._frame_payload )
443
- websocket_mask (self ._frame_mask , self ._frame_payload )
451
+ payload_bytearray = data_cstr [f_start_pos :f_end_pos ] # type: ignore[assignment]
452
+ if type (payload_bytearray ) is not bytearray : # pragma: no branch
453
+ # Cython will do the conversion for us
454
+ # but we need to do it for Python and we
455
+ # will always get here in Python
456
+ payload_bytearray = bytearray (payload_bytearray )
457
+ websocket_mask (self ._frame_mask , payload_bytearray )
458
+ payload = payload_bytearray
459
+ else :
460
+ payload = data_cstr [f_start_pos :f_end_pos ]
444
461
445
462
self ._handle_frame (
446
- self ._frame_fin ,
447
- self ._frame_opcode ,
448
- self ._frame_payload ,
449
- self ._compressed ,
463
+ self ._frame_fin , self ._frame_opcode , payload , self ._compressed
450
464
)
451
- self ._frame_payload = b""
452
465
self ._frame_payload_len = 0
453
466
self ._state = READ_HEADER
454
467
455
468
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
456
- self ._tail = (
457
- data_cstr [start_pos :data_length ] if start_pos < data_length else b""
458
- )
469
+ self ._tail = data_cstr [start_pos :data_len ] if start_pos < data_len else b""
0 commit comments