@@ -483,7 +483,6 @@ def __init__(self, timeout: Optional[float] = None, buffer_size: int = 2**14):
483
483
self ._timeout = timeout
484
484
self ._is_compressed = False
485
485
self ._compressor_id : Optional [int ] = None
486
- self ._need_compression_header = False
487
486
self ._max_message_size = MAX_MESSAGE_SIZE
488
487
self ._request_id : Optional [int ] = None
489
488
self ._closed = asyncio .get_running_loop ().create_future ()
@@ -539,36 +538,38 @@ async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[
539
538
if read_waiter in self ._done_messages :
540
539
self ._done_messages .remove (read_waiter )
541
540
if message :
542
- start , end , op_code , overflow , overflow_index = (
541
+ start , end , op_code , is_compressed , compressor_id , overflow , overflow_index = (
543
542
message [0 ],
544
543
message [1 ],
545
544
message [2 ],
546
545
message [3 ],
547
546
message [4 ],
547
+ message [5 ],
548
+ message [6 ],
548
549
)
549
- if self . _is_compressed :
550
+ if is_compressed :
550
551
header_size = 25
551
552
else :
552
553
header_size = 16
553
554
if overflow is not None :
554
- if self . _is_compressed and self . _compressor_id is not None :
555
+ if is_compressed and compressor_id is not None :
555
556
return decompress (
556
557
memoryview (
557
558
bytearray (self ._buffer [start + header_size : self ._end_index ])
558
559
+ bytearray (overflow [:overflow_index ])
559
560
),
560
- self . _compressor_id ,
561
+ compressor_id ,
561
562
), op_code
562
563
else :
563
564
return memoryview (
564
565
bytearray (self ._buffer [start + header_size : self ._end_index ])
565
566
+ bytearray (overflow [:overflow_index ])
566
567
), op_code
567
568
else :
568
- if self . _is_compressed and self . _compressor_id is not None :
569
+ if is_compressed and compressor_id is not None :
569
570
return decompress (
570
571
memoryview (self ._buffer [start + header_size : end ]),
571
- self . _compressor_id ,
572
+ compressor_id ,
572
573
), op_code
573
574
else :
574
575
return memoryview (self ._buffer [start + header_size : end ]), op_code
@@ -624,6 +625,8 @@ def buffer_updated(self, nbytes: int) -> None:
624
625
self ._start_index ,
625
626
self ._body_size + self ._start_index ,
626
627
self ._op_code ,
628
+ self ._is_compressed ,
629
+ self ._compressor_id ,
627
630
self ._overflow ,
628
631
self ._overflow_index ,
629
632
)
@@ -635,18 +638,20 @@ def buffer_updated(self, nbytes: int) -> None:
635
638
else :
636
639
self ._start_index += self ._body_size
637
640
self ._done_messages .append (result )
641
+ # Reset internal state to expect a new message
642
+ self ._expecting_header = True
643
+ self ._body_size = 0
644
+ self ._op_code = None # type: ignore[assignment]
645
+ self ._overflow = None
646
+ self ._overflow_index = 0
647
+ self ._is_compressed = False
648
+ self ._compressor_id = None
638
649
# If at least one header's worth of data remains after the current message, reprocess all leftover data
639
650
if self ._end_index - self ._start_index >= 16 :
640
651
self ._read_waiter = asyncio .get_running_loop ().create_future ()
641
652
self ._pending_messages .append (self ._read_waiter )
642
653
nbytes_reprocess = self ._end_index - self ._start_index
643
654
self ._end_index -= nbytes_reprocess
644
- # Reset internal state to expect a new message
645
- self ._expecting_header = True
646
- self ._body_size = 0
647
- self ._op_code = None # type: ignore[assignment]
648
- self ._overflow = None
649
- self ._overflow_index = 0
650
655
self .buffer_updated (nbytes_reprocess )
651
656
# Pause reading to avoid storing an arbitrary number of messages in memory before necessary
652
657
self .transport .pause_reading ()
@@ -673,12 +678,9 @@ def process_header(self) -> tuple[int, int]:
673
678
)
674
679
if op_code == 2012 :
675
680
self ._is_compressed = True
676
- if self ._end_index >= 25 :
677
- op_code , _ , self ._compressor_id = _UNPACK_COMPRESSION_HEADER (
678
- self ._buffer [self ._start_index + 16 : self ._start_index + 25 ]
679
- )
680
- else :
681
- self ._need_compression_header = True
681
+ op_code , _ , self ._compressor_id = _UNPACK_COMPRESSION_HEADER (
682
+ self ._buffer [self ._start_index + 16 : self ._start_index + 25 ]
683
+ )
682
684
683
685
return length , op_code
684
686
0 commit comments