@@ -474,10 +474,10 @@ def __init__(self, timeout: Optional[float] = None, buffer_size: int = 2**14):
474
474
self .transport : Transport = None # type: ignore[assignment]
475
475
self ._buffer = memoryview (bytearray (self ._buffer_size ))
476
476
self ._overflow : Optional [memoryview ] = None
477
- self ._start = 0
478
- self ._length = 0
479
- self ._overflow_length = 0
480
- self ._body_length = 0
477
+ self ._start_index = 0
478
+ self ._end_index = 0
479
+ self ._overflow_index = 0
480
+ self ._body_size = 0
481
481
self ._op_code : int = None # type: ignore[assignment]
482
482
self ._connection_lost = False
483
483
self ._paused = False
@@ -490,7 +490,6 @@ def __init__(self, timeout: Optional[float] = None, buffer_size: int = 2**14):
490
490
self ._max_message_size = MAX_MESSAGE_SIZE
491
491
self ._request_id : Optional [int ] = None
492
492
self ._closed = asyncio .get_running_loop ().create_future ()
493
- self ._debug = False
494
493
self ._expecting_header = True
495
494
self ._pending_messages : collections .deque [Future ] = collections .deque ()
496
495
self ._done_messages : collections .deque [Future ] = collections .deque ()
@@ -517,23 +516,20 @@ async def write(self, message: bytes) -> None:
517
516
await self ._drain_helper ()
518
517
self .transport .resume_reading ()
519
518
520
- async def read (
521
- self , request_id : Optional [int ], max_message_size : int , debug : bool = False
522
- ) -> tuple [bytes , int ]:
519
+ async def read (self , request_id : Optional [int ], max_message_size : int ) -> tuple [bytes , int ]:
523
520
"""Read a single MongoDB Wire Protocol message from this connection."""
524
521
if self .transport :
525
522
self .transport .resume_reading ()
526
523
if self ._done_messages :
527
524
message = await self ._done_messages .popleft ()
528
525
else :
529
526
self ._expecting_header = True
530
- self ._debug = debug
531
527
self ._max_message_size = max_message_size
532
528
self ._request_id = request_id
533
- self ._length = 0
534
- self ._overflow_length = 0
535
- self ._body_length = 0
536
- self ._start = 0
529
+ self ._end_index = 0
530
+ self ._overflow_index = 0
531
+ self ._body_size = 0
532
+ self ._start_index = 0
537
533
self ._op_code = None # type: ignore[assignment]
538
534
self ._overflow = None
539
535
if self .transport and self .transport .is_closing ():
@@ -546,24 +542,30 @@ async def read(
546
542
if read_waiter in self ._done_messages :
547
543
self ._done_messages .remove (read_waiter )
548
544
if message :
549
- start , end , op_code = message [0 ], message [1 ], message [2 ]
545
+ start , end , op_code , overflow , overflow_index = (
546
+ message [0 ],
547
+ message [1 ],
548
+ message [2 ],
549
+ message [3 ],
550
+ message [4 ],
551
+ )
550
552
if self ._is_compressed :
551
553
header_size = 25
552
554
else :
553
555
header_size = 16
554
- if self . _body_length > self . _buffer_size and self . _overflow is not None :
556
+ if overflow is not None :
555
557
if self ._is_compressed and self ._compressor_id is not None :
556
558
return decompress (
557
559
memoryview (
558
- bytearray (self ._buffer [header_size : self ._length ])
559
- + bytearray (self . _overflow [: self . _overflow_length ])
560
+ bytearray (self ._buffer [start + header_size : self ._end_index ])
561
+ + bytearray (overflow [: overflow_index ])
560
562
),
561
563
self ._compressor_id ,
562
564
), op_code
563
565
else :
564
566
return memoryview (
565
- bytearray (self ._buffer [header_size : self ._length ])
566
- + bytearray (self . _overflow [: self . _overflow_length ])
567
+ bytearray (self ._buffer [start + header_size : self ._end_index ])
568
+ + bytearray (overflow [: overflow_index ])
567
569
), op_code
568
570
else :
569
571
if self ._is_compressed and self ._compressor_id is not None :
@@ -578,55 +580,75 @@ async def read(
578
580
def get_buffer (self , sizehint : int ) -> memoryview :
579
581
"""Called to allocate a new receive buffer."""
580
582
if self ._overflow is not None :
581
- return self ._overflow [self ._overflow_length :]
582
- return self ._buffer [self ._length :]
583
+ return self ._overflow [self ._overflow_index :]
584
+ return self ._buffer [self ._end_index :]
583
585
584
586
def buffer_updated (self , nbytes : int ) -> None :
585
587
"""Called when the buffer was updated with the received data"""
588
+ # Wrote 0 bytes into a non-empty buffer, signal connection closed
586
589
if nbytes == 0 :
587
590
self .connection_lost (OSError ("connection closed" ))
588
591
return
589
592
else :
593
+ # Wrote data into overflow buffer
590
594
if self ._overflow is not None :
591
- self ._overflow_length += nbytes
595
+ self ._overflow_index += nbytes
596
+ # Wrote data into default buffer
592
597
else :
593
598
if self ._expecting_header :
594
599
try :
595
- self ._body_length , self ._op_code = self .process_header ()
600
+ self ._body_size , self ._op_code = self .process_header ()
596
601
except ProtocolError as exc :
597
602
self .connection_lost (exc )
598
603
return
599
604
self ._expecting_header = False
600
- # TODO: account for multiple messages processed within a single read() call
601
- if self ._body_length > self ._buffer_size :
602
- self ._overflow = memoryview (
603
- bytearray (self ._body_length - (self ._length + nbytes ) + 1024 )
604
- )
605
- self ._length += nbytes
606
- if self ._length + self ._overflow_length - self ._start >= self ._body_length :
605
+ # The new message's data is too large for the default buffer, allocate a new buffer for this message
606
+ if self ._body_size > self ._buffer_size - (self ._end_index + nbytes ):
607
+ self ._overflow = memoryview (bytearray (self ._body_size ))
608
+ self ._end_index += nbytes
609
+ # All data of the current message has been received
610
+ if self ._end_index + self ._overflow_index - self ._start_index >= self ._body_size :
607
611
if self ._pending_messages :
608
- done = self ._pending_messages .popleft ()
612
+ result = self ._pending_messages .popleft ()
609
613
else :
610
- done = asyncio .get_running_loop ().create_future ()
611
- done .set_result ((self ._start , self ._body_length + self ._start , self ._op_code ))
612
- self ._start += self ._body_length
613
- self ._done_messages .append (done )
614
- # If we have more data after processing the last message, start processing a new message
615
- if self ._length - self ._start > 0 :
614
+ result = asyncio .get_running_loop ().create_future ()
615
+ # Future has been cancelled, close this connection
616
+ if result .done ():
617
+ self .connection_lost (None )
618
+ return
619
+ # Necessary values to construct message from buffers
620
+ result .set_result (
621
+ (
622
+ self ._start_index ,
623
+ self ._body_size + self ._start_index ,
624
+ self ._op_code ,
625
+ self ._overflow ,
626
+ self ._overflow_index ,
627
+ )
628
+ )
629
+ # Update the buffer's first written offset to reflect this message's size
630
+ self ._start_index += self ._body_size
631
+ self ._done_messages .append (result )
632
+ # If at least one header's worth of data remains after the current message, reprocess all leftover data
633
+ if self ._end_index - self ._start_index >= 16 :
616
634
self ._read_waiter = asyncio .get_running_loop ().create_future ()
617
635
self ._pending_messages .append (self ._read_waiter )
618
- extra = self ._length - self ._start
619
- self ._length -= extra
636
+ nbytes_reprocess = self ._end_index - self ._start_index
637
+ self ._end_index -= nbytes_reprocess
638
+ # Reset internal state to expect a new message
620
639
self ._expecting_header = True
621
- self ._body_length = 0
640
+ self ._body_size = 0
622
641
self ._op_code = None # type: ignore[assignment]
623
- self .buffer_updated (extra )
642
+ self ._overflow = None
643
+ self ._overflow_index = 0
644
+ self .buffer_updated (nbytes_reprocess )
645
+ # Pause reading to avoid storing an arbitrary number of messages in memory before necessary
624
646
self .transport .pause_reading ()
625
647
626
648
def process_header (self ) -> tuple [int , int ]:
627
649
"""Unpack a MongoDB Wire Protocol header."""
628
650
length , _ , response_to , op_code = _UNPACK_HEADER (
629
- self ._buffer [self ._start : self ._start + 16 ]
651
+ self ._buffer [self ._start_index : self ._start_index + 16 ]
630
652
)
631
653
# No request_id for exhaust cursor "getMore".
632
654
if self ._request_id is not None :
@@ -645,9 +667,9 @@ def process_header(self) -> tuple[int, int]:
645
667
)
646
668
if op_code == 2012 :
647
669
self ._is_compressed = True
648
- if self ._length >= 25 :
670
+ if self ._end_index >= 25 :
649
671
op_code , _ , self ._compressor_id = _UNPACK_COMPRESSION_HEADER (
650
- self ._buffer [self ._start + 16 : self ._start + 25 ]
672
+ self ._buffer [self ._start_index + 16 : self ._start_index + 25 ]
651
673
)
652
674
else :
653
675
self ._need_compression_header = True
@@ -719,7 +741,6 @@ async def async_receive_message(
719
741
conn : AsyncConnection ,
720
742
request_id : Optional [int ],
721
743
max_message_size : int = MAX_MESSAGE_SIZE ,
722
- debug : bool = False ,
723
744
) -> Union [_OpReply , _OpMsg ]:
724
745
"""Receive a raw BSON message or raise socket.error."""
725
746
timeout : Optional [Union [float , int ]]
@@ -738,7 +759,7 @@ async def async_receive_message(
738
759
timeout = max (deadline - time .monotonic (), 0 )
739
760
740
761
cancellation_task = create_task (_poll_cancellation (conn ))
741
- read_task = create_task (conn .conn .get_conn .read (request_id , max_message_size , debug ))
762
+ read_task = create_task (conn .conn .get_conn .read (request_id , max_message_size ))
742
763
tasks = [read_task , cancellation_task ]
743
764
try :
744
765
done , pending = await asyncio .wait (
0 commit comments