@@ -521,7 +521,8 @@ async def read(
521
521
self , request_id : Optional [int ], max_message_size : int , debug : bool = False
522
522
) -> tuple [bytes , int ]:
523
523
"""Read a single MongoDB Wire Protocol message from this connection."""
524
- self .transport .resume_reading ()
524
+ if self .transport :
525
+ self .transport .resume_reading ()
525
526
if self ._done_messages :
526
527
message = await self ._done_messages .popleft ()
527
528
else :
@@ -532,9 +533,10 @@ async def read(
532
533
self ._length = 0
533
534
self ._overflow_length = 0
534
535
self ._body_length = 0
536
+ self ._start = 0
535
537
self ._op_code = None # type: ignore[assignment]
536
538
self ._overflow = None
537
- if self .transport .is_closing ():
539
+ if self .transport and self . transport .is_closing ():
538
540
raise OSError ("Connection is closed" )
539
541
read_waiter = asyncio .get_running_loop ().create_future ()
540
542
self ._pending_messages .append (read_waiter )
@@ -545,10 +547,12 @@ async def read(
545
547
self ._done_messages .remove (read_waiter )
546
548
if message :
547
549
start , end , op_code = message [0 ], message [1 ], message [2 ]
548
- header_size = 16
550
+ if self ._is_compressed :
551
+ header_size = 25
552
+ else :
553
+ header_size = 16
549
554
if self ._body_length > self ._buffer_size and self ._overflow is not None :
550
555
if self ._is_compressed and self ._compressor_id is not None :
551
- header_size = 25
552
556
return decompress (
553
557
memoryview (
554
558
bytearray (self ._buffer [header_size : self ._length ])
@@ -563,7 +567,6 @@ async def read(
563
567
), op_code
564
568
else :
565
569
if self ._is_compressed and self ._compressor_id is not None :
566
- header_size = 25
567
570
return decompress (
568
571
memoryview (self ._buffer [start + header_size : end ]),
569
572
self ._compressor_id ,
@@ -594,24 +597,24 @@ def buffer_updated(self, nbytes: int) -> None:
594
597
self .connection_lost (exc )
595
598
return
596
599
self ._expecting_header = False
600
+ # TODO: account for multiple messages processed within a single read() call
597
601
if self ._body_length > self ._buffer_size :
598
602
self ._overflow = memoryview (
599
603
bytearray (self ._body_length - (self ._length + nbytes ) + 1024 )
600
604
)
601
605
self ._length += nbytes
602
- if self ._length + self ._overflow_length >= self ._body_length :
606
+ if self ._length + self ._overflow_length - self . _start >= self ._body_length :
603
607
if self ._pending_messages :
604
608
done = self ._pending_messages .popleft ()
605
609
else :
606
610
done = asyncio .get_running_loop ().create_future ()
607
- done .set_result ((self ._start , self ._body_length , self ._op_code ))
608
- self ._start = 0
611
+ done .set_result ((self ._start , self ._body_length + self . _start , self ._op_code ))
612
+ self ._start += self . _body_length
609
613
self ._done_messages .append (done )
610
- if self ._length > self ._body_length :
614
+ if self ._length - self . _start > self ._body_length :
611
615
self ._read_waiter = asyncio .get_running_loop ().create_future ()
612
616
self ._pending_messages .append (self ._read_waiter )
613
- self ._start = self ._body_length
614
- extra = self ._length - self ._body_length
617
+ extra = self ._length - self ._start
615
618
self ._length -= extra
616
619
self ._expecting_header = True
617
620
self ._body_length = 0
@@ -621,7 +624,9 @@ def buffer_updated(self, nbytes: int) -> None:
621
624
622
625
def process_header (self ) -> tuple [int , int ]:
623
626
"""Unpack a MongoDB Wire Protocol header."""
624
- length , _ , response_to , op_code = _UNPACK_HEADER (self ._buffer [self ._start : 16 ])
627
+ length , _ , response_to , op_code = _UNPACK_HEADER (
628
+ self ._buffer [self ._start : self ._start + 16 ]
629
+ )
625
630
# No request_id for exhaust cursor "getMore".
626
631
if self ._request_id is not None :
627
632
if self ._request_id != response_to :
@@ -640,7 +645,9 @@ def process_header(self) -> tuple[int, int]:
640
645
if op_code == 2012 :
641
646
self ._is_compressed = True
642
647
if self ._length >= 25 :
643
- op_code , _ , self ._compressor_id = _UNPACK_COMPRESSION_HEADER (self ._buffer [16 :25 ])
648
+ op_code , _ , self ._compressor_id = _UNPACK_COMPRESSION_HEADER (
649
+ self ._buffer [self ._start + 16 : self ._start + 25 ]
650
+ )
644
651
else :
645
652
self ._need_compression_header = True
646
653
0 commit comments