@@ -466,27 +466,29 @@ def recv_into(self, buffer: bytes) -> int:
466466
467467
468468class PyMongoProtocol (BufferedProtocol ):
469- def __init__ (self , timeout : Optional [float ] = None , buffer_size : int = 2 ** 14 ):
470- self ._buffer_size = buffer_size
469+ def __init__ (self , timeout : Optional [float ] = None , buffer_size : int = 2 ** 10 ):
471470 self .transport : Transport = None # type: ignore[assignment]
472- self ._buffer = memoryview (bytearray (self ._buffer_size ))
473- self ._overflow : Optional [memoryview ] = None
474- self ._start_index = 0
475- self ._end_index = 0
476- self ._overflow_index = 0
477- self ._body_size = 0
478- self ._op_code : int = None # type: ignore[assignment]
471+ # Each message is reader in 2-3 parts: header, compression header, and message body
472+ # The message buffer is allocated after the header is read.
473+ self ._header = memoryview (bytearray (16 ))
474+ self ._header_index = 0
475+ self ._compression_header = memoryview (bytearray (9 ))
476+ self ._compression_index = 0
477+ self ._message : Optional [memoryview ] = None
478+ self ._message_index = 0
479+ # State. TODO: replace booleans with an enum?
480+ self ._expecting_header = True
481+ self ._expecting_compression = False
482+ self ._message_size = 0
483+ self ._op_code = 0
479484 self ._connection_lost = False
480- self ._paused = False
481- self ._drain_waiter : Optional [Future ] = None
482485 self ._read_waiter : Optional [Future ] = None
483486 self ._timeout = timeout
484487 self ._is_compressed = False
485488 self ._compressor_id : Optional [int ] = None
486489 self ._max_message_size = MAX_MESSAGE_SIZE
487490 self ._request_id : Optional [int ] = None
488491 self ._closed = asyncio .get_running_loop ().create_future ()
489- self ._expecting_header = True
490492 self ._pending_messages : collections .deque [Future ] = collections .deque ()
491493 self ._done_messages : collections .deque [Future ] = collections .deque ()
492494
@@ -515,8 +517,6 @@ async def write(self, message: bytes) -> None:
515517 except AttributeError :
516518 raise OSError ("connection is already closed" ) from None
517519 self .transport .write (message )
518- # await self._drain_helper()
519- # self.transport.resume_reading()
520520
521521 async def read (self , request_id : Optional [int ], max_message_size : int ) -> tuple [bytes , int ]:
522522 """Read a single MongoDB Wire Protocol message from this connection."""
@@ -526,20 +526,11 @@ async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[
526526 # Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
527527 except AttributeError :
528528 raise OSError ("connection is already closed" ) from None
529+ self ._max_message_size = max_message_size
530+ self ._request_id = request_id
529531 if self ._done_messages :
530532 message = await self ._done_messages .popleft ()
531533 else :
532- self ._expecting_header = True
533- self ._max_message_size = max_message_size
534- self ._request_id = request_id
535- self ._end_index = 0
536- self ._overflow_index = 0
537- self ._body_size = 0
538- self ._start_index = 0
539- self ._op_code = None # type: ignore[assignment]
540- self ._overflow = None
541- self ._is_compressed = False
542- self ._compressor_id = None
543534 if self .transport and self .transport .is_closing ():
544535 raise OSError ("connection is already closed" )
545536 read_waiter = asyncio .get_running_loop ().create_future ()
@@ -550,39 +541,10 @@ async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[
550541 if read_waiter in self ._done_messages :
551542 self ._done_messages .remove (read_waiter )
552543 if message :
553- start , end , op_code , is_compressed , compressor_id , overflow , overflow_index = (
554- message [0 ],
555- message [1 ],
556- message [2 ],
557- message [3 ],
558- message [4 ],
559- message [5 ],
560- message [6 ],
561- )
562- if is_compressed :
563- header_size = 25
564- else :
565- header_size = 16
566- if overflow is not None :
567- if is_compressed and compressor_id is not None :
568- return decompress (
569- self ._buffer [start + header_size : self ._end_index ].tobytes ()
570- + overflow [:overflow_index ].tobytes (),
571- compressor_id ,
572- ), op_code
573- else :
574- return (
575- self ._buffer [start + header_size : self ._end_index ].tobytes ()
576- + overflow [:overflow_index ].tobytes ()
577- ), op_code
578- else :
579- if is_compressed and compressor_id is not None :
580- return decompress (
581- self ._buffer [start + header_size : end ],
582- compressor_id ,
583- ), op_code
584- else :
585- return self ._buffer [start + header_size : end ].tobytes (), op_code
544+ op_code , compressor_id , data = message
545+ if compressor_id is not None :
546+ data = decompress (data , compressor_id )
547+ return data , op_code
586548 raise OSError ("connection closed" )
587549
588550 def get_buffer (self , sizehint : int ) -> memoryview :
@@ -592,88 +554,67 @@ def get_buffer(self, sizehint: int) -> memoryview:
592554 either no data remains or an empty buffer is returned.
593555 """
594556 # Check for SSL EOF edge case, no data will be written to the buffer we return
557+ # TODO: is this needed?
595558 if sizehint == 0 :
596559 return memoryview (bytearray (16 ))
597- if self ._overflow is not None :
598- return self ._overflow [self ._overflow_index :]
599- return self ._buffer [self ._end_index :]
560+ # TODO: optimize this by caching pointers to the buffers.
561+ # return self._buffer[self._index:]
562+ if self ._expecting_header :
563+ return self ._header [self ._header_index :]
564+ if self ._expecting_compression :
565+ return self ._compression_header [self ._compression_index :]
566+ return self ._message [self ._message_index :]
600567
601568 def buffer_updated (self , nbytes : int ) -> None :
602569 """Called when the buffer was updated with the received data"""
603570 # Wrote 0 bytes into a non-empty buffer, signal connection closed
604571 if nbytes == 0 :
605572 self .connection_lost (OSError ("connection closed" ))
606573 return
607- else :
608- # Wrote data into overflow buffer
609- if self ._overflow is not None :
610- self ._overflow_index += nbytes
611- # Wrote data into default buffer
612- else :
613- if self ._expecting_header :
614- try :
615- self ._body_size , self ._op_code = self .process_header ()
616- self ._request_id = None
617- except ProtocolError as exc :
618- self .connection_lost (exc )
619- return
620- self ._expecting_header = False
621- # The new message's data is too large for the default buffer, allocate a new buffer for this message
622- if self ._body_size > self ._buffer_size - (self ._end_index + nbytes ):
623- self ._overflow = memoryview (bytearray (self ._body_size ))
624- self ._end_index += nbytes
625- # All data of the current message has been received
626- if self ._end_index + self ._overflow_index - self ._start_index >= self ._body_size :
627- # Pause reading to avoid storing an arbitrary number of messages in memory before necessary
628- self .transport .pause_reading ()
629- if self ._pending_messages :
630- result = self ._pending_messages .popleft ()
631- else :
632- result = asyncio .get_running_loop ().create_future ()
633- # Future has been cancelled, close this connection
634- if result .done ():
635- self .connection_lost (None )
574+ if self ._expecting_header :
575+ self ._header_index += nbytes
576+ if self ._header_index >= 16 :
577+ try :
578+ self ._message_size , self ._op_code = self .process_header ()
579+ except ProtocolError as exc :
580+ self .connection_lost (exc )
636581 return
637- # Necessary values to construct message from buffers
638- result .set_result (
639- (
640- self ._start_index ,
641- self ._body_size + self ._start_index ,
642- self ._op_code ,
643- self ._is_compressed ,
644- self ._compressor_id ,
645- self ._overflow ,
646- self ._overflow_index ,
647- )
648- )
649- # If the current message has an overflow buffer, then the entire default buffer is full
650- if self ._overflow :
651- self ._start_index = self ._end_index
652- # Update the buffer's first written offset to reflect this message's size
653- else :
654- self ._start_index += self ._body_size
655- self ._done_messages .append (result )
656- # Reset internal state to expect a new message
657- self ._expecting_header = True
658- self ._body_size = 0
659- self ._op_code = None # type: ignore[assignment]
660- self ._overflow = None
661- self ._overflow_index = 0
662- self ._is_compressed = False
663- self ._compressor_id = None
664- # If at least one header's worth of data remains after the current message, reprocess all leftover data
665- if self ._end_index - self ._start_index >= 16 :
666- self ._read_waiter = asyncio .get_running_loop ().create_future ()
667- self ._pending_messages .append (self ._read_waiter )
668- nbytes_reprocess = self ._end_index - self ._start_index
669- self ._end_index -= nbytes_reprocess
670- self .buffer_updated (nbytes_reprocess )
582+ self ._message = memoryview (bytearray (self ._message_size ))
583+ return
584+ if self ._expecting_compression :
585+ self ._compression_index += nbytes
586+ if self ._compression_index >= 9 :
587+ self ._op_code = self .process_compression_header ()
588+ return
589+
590+ self ._message_index += nbytes
591+ if self ._message_index >= self ._message_size :
592+ # Pause reading to avoid storing an arbitrary number of messages in memory.
593+ self .transport .pause_reading ()
594+ if self ._pending_messages :
595+ result = self ._pending_messages .popleft ()
596+ else :
597+ result = asyncio .get_running_loop ().create_future ()
598+ # Future has been cancelled, close this connection
599+ if result .done ():
600+ self .connection_lost (None )
601+ return
602+ # Necessary values to construct message from buffers
603+ result .set_result ((self ._op_code , self ._compressor_id , self ._message ))
604+ self ._done_messages .append (result )
605+ # Reset internal state to expect a new message
606+ self ._expecting_header = True
607+ self ._header_index = 0
608+ self ._compression_index = 0
609+ self ._message_index = 0
610+ self ._message_size = 0
611+ self ._message = None
612+ self ._op_code = 0
613+ self ._compressor_id = None
671614
672615 def process_header (self ) -> tuple [int , int ]:
673616 """Unpack a MongoDB Wire Protocol header."""
674- length , _ , response_to , op_code = _UNPACK_HEADER (
675- self ._buffer [self ._start_index : self ._start_index + 16 ]
676- )
617+ length , _ , response_to , op_code = _UNPACK_HEADER (self ._header )
677618 # No request_id for exhaust cursor "getMore".
678619 if self ._request_id is not None :
679620 if self ._request_id != response_to :
@@ -689,24 +630,22 @@ def process_header(self) -> tuple[int, int]:
689630 f"Message length ({ length !r} ) is larger than server max "
690631 f"message size ({ self ._max_message_size !r} )"
691632 )
692- if op_code == 2012 :
693- self . _is_compressed = True
694- op_code , _ , self . _compressor_id = _UNPACK_COMPRESSION_HEADER (
695- self . _buffer [ self . _start_index + 16 : self . _start_index + 25 ]
696- )
697-
698- return length , op_code
633+ if op_code == 2012 : # OP_COMPRESSED
634+ if length <= 25 :
635+ raise ProtocolError (
636+ f"Message length ( { length !r } ) not longer than standard OP_COMPRESSED message header size (25)"
637+ )
638+ self . _expecting_compression = True
639+ length -= 9
699640
700- def pause_writing (self ) -> None :
701- assert not self ._paused
702- self ._paused = True
641+ self ._expecting_header = False
642+ return length - 16 , op_code
703643
704- def resume_writing (self ) -> None :
705- assert self ._paused
706- self ._paused = False
707- #
708- # if self._drain_waiter and not self._drain_waiter.done():
709- # self._drain_waiter.set_result(None)
644+ def process_compression_header (self ) -> tuple [int , int ]:
645+ """Unpack a MongoDB Wire Protocol header."""
646+ op_code , _ , self ._compressor_id = _UNPACK_COMPRESSION_HEADER (self ._compression_header )
647+ self ._expecting_compression = False
648+ return op_code
710649
711650 def connection_lost (self , exc : Exception | None ) -> None :
712651 self ._connection_lost = True
@@ -725,27 +664,6 @@ def connection_lost(self, exc: Exception | None) -> None:
725664 else :
726665 self ._closed .set_exception (exc )
727666
728- # # Wake up the writer(s) if currently paused.
729- # if not self._paused:
730- # return
731- #
732- # if self._drain_waiter and not self._drain_waiter.done():
733- # if exc is None:
734- # self._drain_waiter.set_result(None)
735- # else:
736- # self._drain_waiter.set_exception(exc)
737-
738- async def _drain_helper (self ) -> None :
739- if self ._connection_lost :
740- raise ConnectionResetError ("Connection lost" )
741- if not self ._paused :
742- return
743- self ._drain_waiter = asyncio .get_running_loop ().create_future ()
744- await self ._drain_waiter
745-
746- def data (self ) -> bytes :
747- return self ._buffer .tobytes ()
748-
749667 async def wait_closed (self ) -> None :
750668 await asyncio .wait ([self ._closed ])
751669
0 commit comments