@@ -511,12 +511,8 @@ async def write(self, message: bytes) -> None:
511511 """Write a message to this connection's transport."""
512512 if self .transport .is_closing ():
513513 raise OSError ("Connection is closed" )
514- try :
515- self .transport .resume_reading ()
516- # Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
517- except AttributeError :
518- raise OSError ("connection is already closed" ) from None
519514 self .transport .write (message )
515+ self .transport .resume_reading ()
520516
521517 async def read (self , request_id : Optional [int ], max_message_size : int ) -> tuple [bytes , int ]:
522518 """Read a single MongoDB Wire Protocol message from this connection."""
@@ -553,6 +549,13 @@ def get_buffer(self, sizehint: int) -> memoryview:
553549 If any data does not fit into the returned buffer, this method will be called again until
554550 either no data remains or an empty buffer is returned.
555551 """
552+ # Due to a bug, Python <=3.11 will call get_buffer() even after we raise
553+ # ProtocolError in buffer_updated() and call connection_lost(). We allocate
554+ # a temp buffer to drain the waiting data.
555+ if self ._connection_lost :
556+ if not self ._message :
557+ self ._message = memoryview (bytearray (2 ** 14 ))
558+ return self ._message
556559 # TODO: optimize this by caching pointers to the buffers.
557560 # return self._buffer[self._index:]
558561 if self ._expecting_header :
@@ -567,9 +570,12 @@ def buffer_updated(self, nbytes: int) -> None:
567570 if nbytes == 0 :
568571 self .connection_lost (OSError ("connection closed" ))
569572 return
573+ if self ._connection_lost :
574+ return
570575 if self ._expecting_header :
571576 self ._header_index += nbytes
572577 if self ._header_index >= 16 :
578+ self ._expecting_header = False
573579 try :
574580 self ._message_size , self ._op_code = self .process_header ()
575581 except ProtocolError as exc :
@@ -580,11 +586,13 @@ def buffer_updated(self, nbytes: int) -> None:
580586 if self ._expecting_compression :
581587 self ._compression_index += nbytes
582588 if self ._compression_index >= 9 :
583- self ._op_code = self .process_compression_header ()
589+ self ._expecting_compression = False
590+ self ._op_code , self ._compressor_id = self .process_compression_header ()
584591 return
585592
586593 self ._message_index += nbytes
587594 if self ._message_index >= self ._message_size :
595+ self ._expecting_header = True
588596 # Pause reading to avoid storing an arbitrary number of messages in memory.
589597 self .transport .pause_reading ()
590598 if self ._pending_messages :
@@ -599,7 +607,6 @@ def buffer_updated(self, nbytes: int) -> None:
599607 result .set_result ((self ._op_code , self ._compressor_id , self ._message ))
600608 self ._done_messages .append (result )
601609 # Reset internal state to expect a new message
602- self ._expecting_header = True
603610 self ._header_index = 0
604611 self ._compression_index = 0
605612 self ._message_index = 0
@@ -611,6 +618,13 @@ def buffer_updated(self, nbytes: int) -> None:
611618 def process_header (self ) -> tuple [int , int ]:
612619 """Unpack a MongoDB Wire Protocol header."""
613620 length , _ , response_to , op_code = _UNPACK_HEADER (self ._header )
621+ if op_code == 2012 : # OP_COMPRESSED
622+ if length <= 25 :
623+ raise ProtocolError (
624+ f"Message length ({ length !r} ) not longer than standard OP_COMPRESSED message header size (25)"
625+ )
626+ self ._expecting_compression = True
627+ length -= 9
614628 # No request_id for exhaust cursor "getMore".
615629 if self ._request_id is not None :
616630 if self ._request_id != response_to :
@@ -626,25 +640,17 @@ def process_header(self) -> tuple[int, int]:
626640 f"Message length ({ length !r} ) is larger than server max "
627641 f"message size ({ self ._max_message_size !r} )"
628642 )
629- if op_code == 2012 : # OP_COMPRESSED
630- if length <= 25 :
631- raise ProtocolError (
632- f"Message length ({ length !r} ) not longer than standard OP_COMPRESSED message header size (25)"
633- )
634- self ._expecting_compression = True
635- length -= 9
636643
637- self ._expecting_header = False
638644 return length - 16 , op_code
639645
640- def process_compression_header (self ) -> int :
646+ def process_compression_header (self ) -> tuple [ int , int ] :
641647 """Unpack a MongoDB Wire Protocol compression header."""
642- op_code , _ , self ._compressor_id = _UNPACK_COMPRESSION_HEADER (self ._compression_header )
643- self ._expecting_compression = False
644- return op_code
648+ op_code , _ , compressor_id = _UNPACK_COMPRESSION_HEADER (self ._compression_header )
649+ return op_code , compressor_id
645650
646651 def connection_lost (self , exc : Exception | None ) -> None :
647652 self ._connection_lost = True
653+ super ().connection_lost (exc )
648654 pending = list (self ._pending_messages )
649655 for msg in pending :
650656 if not msg .done ():
0 commit comments