@@ -511,12 +511,8 @@ async def write(self, message: bytes) -> None:
511
511
"""Write a message to this connection's transport."""
512
512
if self .transport .is_closing ():
513
513
raise OSError ("Connection is closed" )
514
- try :
515
- self .transport .resume_reading ()
516
- # Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
517
- except AttributeError :
518
- raise OSError ("connection is already closed" ) from None
519
514
self .transport .write (message )
515
+ self .transport .resume_reading ()
520
516
521
517
async def read (self , request_id : Optional [int ], max_message_size : int ) -> tuple [bytes , int ]:
522
518
"""Read a single MongoDB Wire Protocol message from this connection."""
@@ -553,6 +549,13 @@ def get_buffer(self, sizehint: int) -> memoryview:
553
549
If any data does not fit into the returned buffer, this method will be called again until
554
550
either no data remains or an empty buffer is returned.
555
551
"""
552
+ # Due to a bug, Python <=3.11 will call get_buffer() even after we raise
553
+ # ProtocolError in buffer_updated() and call connection_lost(). We allocate
554
+ # a temp buffer to drain the waiting data.
555
+ if self ._connection_lost :
556
+ if not self ._message :
557
+ self ._message = memoryview (bytearray (2 ** 14 ))
558
+ return self ._message
556
559
# TODO: optimize this by caching pointers to the buffers.
557
560
# return self._buffer[self._index:]
558
561
if self ._expecting_header :
@@ -567,9 +570,12 @@ def buffer_updated(self, nbytes: int) -> None:
567
570
if nbytes == 0 :
568
571
self .connection_lost (OSError ("connection closed" ))
569
572
return
573
+ if self ._connection_lost :
574
+ return
570
575
if self ._expecting_header :
571
576
self ._header_index += nbytes
572
577
if self ._header_index >= 16 :
578
+ self ._expecting_header = False
573
579
try :
574
580
self ._message_size , self ._op_code = self .process_header ()
575
581
except ProtocolError as exc :
@@ -580,11 +586,13 @@ def buffer_updated(self, nbytes: int) -> None:
580
586
if self ._expecting_compression :
581
587
self ._compression_index += nbytes
582
588
if self ._compression_index >= 9 :
583
- self ._op_code = self .process_compression_header ()
589
+ self ._expecting_compression = False
590
+ self ._op_code , self ._compressor_id = self .process_compression_header ()
584
591
return
585
592
586
593
self ._message_index += nbytes
587
594
if self ._message_index >= self ._message_size :
595
+ self ._expecting_header = True
588
596
# Pause reading to avoid storing an arbitrary number of messages in memory.
589
597
self .transport .pause_reading ()
590
598
if self ._pending_messages :
@@ -599,7 +607,6 @@ def buffer_updated(self, nbytes: int) -> None:
599
607
result .set_result ((self ._op_code , self ._compressor_id , self ._message ))
600
608
self ._done_messages .append (result )
601
609
# Reset internal state to expect a new message
602
- self ._expecting_header = True
603
610
self ._header_index = 0
604
611
self ._compression_index = 0
605
612
self ._message_index = 0
@@ -611,6 +618,13 @@ def buffer_updated(self, nbytes: int) -> None:
611
618
def process_header (self ) -> tuple [int , int ]:
612
619
"""Unpack a MongoDB Wire Protocol header."""
613
620
length , _ , response_to , op_code = _UNPACK_HEADER (self ._header )
621
+ if op_code == 2012 : # OP_COMPRESSED
622
+ if length <= 25 :
623
+ raise ProtocolError (
624
+ f"Message length ({ length !r} ) not longer than standard OP_COMPRESSED message header size (25)"
625
+ )
626
+ self ._expecting_compression = True
627
+ length -= 9
614
628
# No request_id for exhaust cursor "getMore".
615
629
if self ._request_id is not None :
616
630
if self ._request_id != response_to :
@@ -626,25 +640,17 @@ def process_header(self) -> tuple[int, int]:
626
640
f"Message length ({ length !r} ) is larger than server max "
627
641
f"message size ({ self ._max_message_size !r} )"
628
642
)
629
- if op_code == 2012 : # OP_COMPRESSED
630
- if length <= 25 :
631
- raise ProtocolError (
632
- f"Message length ({ length !r} ) not longer than standard OP_COMPRESSED message header size (25)"
633
- )
634
- self ._expecting_compression = True
635
- length -= 9
636
643
637
- self ._expecting_header = False
638
644
return length - 16 , op_code
639
645
640
- def process_compression_header (self ) -> int :
646
+ def process_compression_header (self ) -> tuple [ int , int ] :
641
647
"""Unpack a MongoDB Wire Protocol compression header."""
642
- op_code , _ , self ._compressor_id = _UNPACK_COMPRESSION_HEADER (self ._compression_header )
643
- self ._expecting_compression = False
644
- return op_code
648
+ op_code , _ , compressor_id = _UNPACK_COMPRESSION_HEADER (self ._compression_header )
649
+ return op_code , compressor_id
645
650
646
651
def connection_lost (self , exc : Exception | None ) -> None :
647
652
self ._connection_lost = True
653
+ super ().connection_lost (exc )
648
654
pending = list (self ._pending_messages )
649
655
for msg in pending :
650
656
if not msg .done ():
0 commit comments