@@ -466,27 +466,29 @@ def recv_into(self, buffer: bytes) -> int:
466
466
467
467
468
468
class PyMongoProtocol (BufferedProtocol ):
469
- def __init__ (self , timeout : Optional [float ] = None , buffer_size : int = 2 ** 14 ):
470
- self ._buffer_size = buffer_size
469
+ def __init__ (self , timeout : Optional [float ] = None , buffer_size : int = 2 ** 10 ):
471
470
self .transport : Transport = None # type: ignore[assignment]
472
- self ._buffer = memoryview (bytearray (self ._buffer_size ))
473
- self ._overflow : Optional [memoryview ] = None
474
- self ._start_index = 0
475
- self ._end_index = 0
476
- self ._overflow_index = 0
477
- self ._body_size = 0
478
- self ._op_code : int = None # type: ignore[assignment]
471
+ # Each message is read in 2-3 parts: header, compression header, and message body.
472
+ # The message buffer is allocated after the header is read.
473
+ self ._header = memoryview (bytearray (16 ))
474
+ self ._header_index = 0
475
+ self ._compression_header = memoryview (bytearray (9 ))
476
+ self ._compression_index = 0
477
+ self ._message : Optional [memoryview ] = None
478
+ self ._message_index = 0
479
+ # State. TODO: replace booleans with an enum?
480
+ self ._expecting_header = True
481
+ self ._expecting_compression = False
482
+ self ._message_size = 0
483
+ self ._op_code = 0
479
484
self ._connection_lost = False
480
- self ._paused = False
481
- self ._drain_waiter : Optional [Future ] = None
482
485
self ._read_waiter : Optional [Future ] = None
483
486
self ._timeout = timeout
484
487
self ._is_compressed = False
485
488
self ._compressor_id : Optional [int ] = None
486
489
self ._max_message_size = MAX_MESSAGE_SIZE
487
490
self ._request_id : Optional [int ] = None
488
491
self ._closed = asyncio .get_running_loop ().create_future ()
489
- self ._expecting_header = True
490
492
self ._pending_messages : collections .deque [Future ] = collections .deque ()
491
493
self ._done_messages : collections .deque [Future ] = collections .deque ()
492
494
@@ -515,8 +517,6 @@ async def write(self, message: bytes) -> None:
515
517
except AttributeError :
516
518
raise OSError ("connection is already closed" ) from None
517
519
self .transport .write (message )
518
- # await self._drain_helper()
519
- # self.transport.resume_reading()
520
520
521
521
async def read (self , request_id : Optional [int ], max_message_size : int ) -> tuple [bytes , int ]:
522
522
"""Read a single MongoDB Wire Protocol message from this connection."""
@@ -526,20 +526,11 @@ async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[
526
526
# Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
527
527
except AttributeError :
528
528
raise OSError ("connection is already closed" ) from None
529
+ self ._max_message_size = max_message_size
530
+ self ._request_id = request_id
529
531
if self ._done_messages :
530
532
message = await self ._done_messages .popleft ()
531
533
else :
532
- self ._expecting_header = True
533
- self ._max_message_size = max_message_size
534
- self ._request_id = request_id
535
- self ._end_index = 0
536
- self ._overflow_index = 0
537
- self ._body_size = 0
538
- self ._start_index = 0
539
- self ._op_code = None # type: ignore[assignment]
540
- self ._overflow = None
541
- self ._is_compressed = False
542
- self ._compressor_id = None
543
534
if self .transport and self .transport .is_closing ():
544
535
raise OSError ("connection is already closed" )
545
536
read_waiter = asyncio .get_running_loop ().create_future ()
@@ -550,39 +541,10 @@ async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[
550
541
if read_waiter in self ._done_messages :
551
542
self ._done_messages .remove (read_waiter )
552
543
if message :
553
- start , end , op_code , is_compressed , compressor_id , overflow , overflow_index = (
554
- message [0 ],
555
- message [1 ],
556
- message [2 ],
557
- message [3 ],
558
- message [4 ],
559
- message [5 ],
560
- message [6 ],
561
- )
562
- if is_compressed :
563
- header_size = 25
564
- else :
565
- header_size = 16
566
- if overflow is not None :
567
- if is_compressed and compressor_id is not None :
568
- return decompress (
569
- self ._buffer [start + header_size : self ._end_index ].tobytes ()
570
- + overflow [:overflow_index ].tobytes (),
571
- compressor_id ,
572
- ), op_code
573
- else :
574
- return (
575
- self ._buffer [start + header_size : self ._end_index ].tobytes ()
576
- + overflow [:overflow_index ].tobytes ()
577
- ), op_code
578
- else :
579
- if is_compressed and compressor_id is not None :
580
- return decompress (
581
- self ._buffer [start + header_size : end ],
582
- compressor_id ,
583
- ), op_code
584
- else :
585
- return self ._buffer [start + header_size : end ].tobytes (), op_code
544
+ op_code , compressor_id , data = message
545
+ if compressor_id is not None :
546
+ data = decompress (data , compressor_id )
547
+ return data , op_code
586
548
raise OSError ("connection closed" )
587
549
588
550
def get_buffer (self , sizehint : int ) -> memoryview :
@@ -592,88 +554,67 @@ def get_buffer(self, sizehint: int) -> memoryview:
592
554
either no data remains or an empty buffer is returned.
593
555
"""
594
556
# Check for SSL EOF edge case, no data will be written to the buffer we return
557
+ # TODO: is this needed?
595
558
if sizehint == 0 :
596
559
return memoryview (bytearray (16 ))
597
- if self ._overflow is not None :
598
- return self ._overflow [self ._overflow_index :]
599
- return self ._buffer [self ._end_index :]
560
+ # TODO: optimize this by caching pointers to the buffers.
561
+ # return self._buffer[self._index:]
562
+ if self ._expecting_header :
563
+ return self ._header [self ._header_index :]
564
+ if self ._expecting_compression :
565
+ return self ._compression_header [self ._compression_index :]
566
+ return self ._message [self ._message_index :]
600
567
601
568
def buffer_updated (self , nbytes : int ) -> None :
602
569
"""Called when the buffer was updated with the received data"""
603
570
# Wrote 0 bytes into a non-empty buffer, signal connection closed
604
571
if nbytes == 0 :
605
572
self .connection_lost (OSError ("connection closed" ))
606
573
return
607
- else :
608
- # Wrote data into overflow buffer
609
- if self ._overflow is not None :
610
- self ._overflow_index += nbytes
611
- # Wrote data into default buffer
612
- else :
613
- if self ._expecting_header :
614
- try :
615
- self ._body_size , self ._op_code = self .process_header ()
616
- self ._request_id = None
617
- except ProtocolError as exc :
618
- self .connection_lost (exc )
619
- return
620
- self ._expecting_header = False
621
- # The new message's data is too large for the default buffer, allocate a new buffer for this message
622
- if self ._body_size > self ._buffer_size - (self ._end_index + nbytes ):
623
- self ._overflow = memoryview (bytearray (self ._body_size ))
624
- self ._end_index += nbytes
625
- # All data of the current message has been received
626
- if self ._end_index + self ._overflow_index - self ._start_index >= self ._body_size :
627
- # Pause reading to avoid storing an arbitrary number of messages in memory before necessary
628
- self .transport .pause_reading ()
629
- if self ._pending_messages :
630
- result = self ._pending_messages .popleft ()
631
- else :
632
- result = asyncio .get_running_loop ().create_future ()
633
- # Future has been cancelled, close this connection
634
- if result .done ():
635
- self .connection_lost (None )
574
+ if self ._expecting_header :
575
+ self ._header_index += nbytes
576
+ if self ._header_index >= 16 :
577
+ try :
578
+ self ._message_size , self ._op_code = self .process_header ()
579
+ except ProtocolError as exc :
580
+ self .connection_lost (exc )
636
581
return
637
- # Necessary values to construct message from buffers
638
- result .set_result (
639
- (
640
- self ._start_index ,
641
- self ._body_size + self ._start_index ,
642
- self ._op_code ,
643
- self ._is_compressed ,
644
- self ._compressor_id ,
645
- self ._overflow ,
646
- self ._overflow_index ,
647
- )
648
- )
649
- # If the current message has an overflow buffer, then the entire default buffer is full
650
- if self ._overflow :
651
- self ._start_index = self ._end_index
652
- # Update the buffer's first written offset to reflect this message's size
653
- else :
654
- self ._start_index += self ._body_size
655
- self ._done_messages .append (result )
656
- # Reset internal state to expect a new message
657
- self ._expecting_header = True
658
- self ._body_size = 0
659
- self ._op_code = None # type: ignore[assignment]
660
- self ._overflow = None
661
- self ._overflow_index = 0
662
- self ._is_compressed = False
663
- self ._compressor_id = None
664
- # If at least one header's worth of data remains after the current message, reprocess all leftover data
665
- if self ._end_index - self ._start_index >= 16 :
666
- self ._read_waiter = asyncio .get_running_loop ().create_future ()
667
- self ._pending_messages .append (self ._read_waiter )
668
- nbytes_reprocess = self ._end_index - self ._start_index
669
- self ._end_index -= nbytes_reprocess
670
- self .buffer_updated (nbytes_reprocess )
582
+ self ._message = memoryview (bytearray (self ._message_size ))
583
+ return
584
+ if self ._expecting_compression :
585
+ self ._compression_index += nbytes
586
+ if self ._compression_index >= 9 :
587
+ self ._op_code = self .process_compression_header ()
588
+ return
589
+
590
+ self ._message_index += nbytes
591
+ if self ._message_index >= self ._message_size :
592
+ # Pause reading to avoid storing an arbitrary number of messages in memory.
593
+ self .transport .pause_reading ()
594
+ if self ._pending_messages :
595
+ result = self ._pending_messages .popleft ()
596
+ else :
597
+ result = asyncio .get_running_loop ().create_future ()
598
+ # Future has been cancelled, close this connection
599
+ if result .done ():
600
+ self .connection_lost (None )
601
+ return
602
+ # Necessary values to construct message from buffers
603
+ result .set_result ((self ._op_code , self ._compressor_id , self ._message ))
604
+ self ._done_messages .append (result )
605
+ # Reset internal state to expect a new message
606
+ self ._expecting_header = True
607
+ self ._header_index = 0
608
+ self ._compression_index = 0
609
+ self ._message_index = 0
610
+ self ._message_size = 0
611
+ self ._message = None
612
+ self ._op_code = 0
613
+ self ._compressor_id = None
671
614
672
615
def process_header (self ) -> tuple [int , int ]:
673
616
"""Unpack a MongoDB Wire Protocol header."""
674
- length , _ , response_to , op_code = _UNPACK_HEADER (
675
- self ._buffer [self ._start_index : self ._start_index + 16 ]
676
- )
617
+ length , _ , response_to , op_code = _UNPACK_HEADER (self ._header )
677
618
# No request_id for exhaust cursor "getMore".
678
619
if self ._request_id is not None :
679
620
if self ._request_id != response_to :
@@ -689,24 +630,22 @@ def process_header(self) -> tuple[int, int]:
689
630
f"Message length ({ length !r} ) is larger than server max "
690
631
f"message size ({ self ._max_message_size !r} )"
691
632
)
692
- if op_code == 2012 :
693
- self . _is_compressed = True
694
- op_code , _ , self . _compressor_id = _UNPACK_COMPRESSION_HEADER (
695
- self . _buffer [ self . _start_index + 16 : self . _start_index + 25 ]
696
- )
697
-
698
- return length , op_code
633
+ if op_code == 2012 : # OP_COMPRESSED
634
+ if length <= 25 :
635
+ raise ProtocolError (
636
+ f"Message length ( { length !r } ) not longer than standard OP_COMPRESSED message header size (25)"
637
+ )
638
+ self . _expecting_compression = True
639
+ length -= 9
699
640
700
- def pause_writing (self ) -> None :
701
- assert not self ._paused
702
- self ._paused = True
641
+ self ._expecting_header = False
642
+ return length - 16 , op_code
703
643
704
- def resume_writing (self ) -> None :
705
- assert self ._paused
706
- self ._paused = False
707
- #
708
- # if self._drain_waiter and not self._drain_waiter.done():
709
- # self._drain_waiter.set_result(None)
644
+ def process_compression_header (self ) -> tuple [int , int ]:
645
+ """Unpack a MongoDB Wire Protocol compression (OP_COMPRESSED) header."""
646
+ op_code , _ , self ._compressor_id = _UNPACK_COMPRESSION_HEADER (self ._compression_header )
647
+ self ._expecting_compression = False
648
+ return op_code
710
649
711
650
def connection_lost (self , exc : Exception | None ) -> None :
712
651
self ._connection_lost = True
@@ -725,27 +664,6 @@ def connection_lost(self, exc: Exception | None) -> None:
725
664
else :
726
665
self ._closed .set_exception (exc )
727
666
728
- # # Wake up the writer(s) if currently paused.
729
- # if not self._paused:
730
- # return
731
- #
732
- # if self._drain_waiter and not self._drain_waiter.done():
733
- # if exc is None:
734
- # self._drain_waiter.set_result(None)
735
- # else:
736
- # self._drain_waiter.set_exception(exc)
737
-
738
- async def _drain_helper (self ) -> None :
739
- if self ._connection_lost :
740
- raise ConnectionResetError ("Connection lost" )
741
- if not self ._paused :
742
- return
743
- self ._drain_waiter = asyncio .get_running_loop ().create_future ()
744
- await self ._drain_waiter
745
-
746
- def data (self ) -> bytes :
747
- return self ._buffer .tobytes ()
748
-
749
667
async def wait_closed (self ) -> None :
750
668
await asyncio .wait ([self ._closed ])
751
669
0 commit comments