@@ -487,7 +487,7 @@ def __init__(self, timeout: Optional[float] = None):
487487 self ._is_compressed = False
488488 self ._compressor_id : Optional [int ] = None
489489 self ._max_message_size = MAX_MESSAGE_SIZE
490- self ._request_id : Optional [int ] = None
490+ self ._response_to : Optional [int ] = None
491491 self ._closed = asyncio .get_running_loop ().create_future ()
492492 self ._pending_messages : collections .deque [Future ] = collections .deque ()
493493 self ._done_messages : collections .deque [Future ] = collections .deque ()
@@ -512,11 +512,10 @@ async def write(self, message: bytes) -> None:
512512 if self .transport .is_closing ():
513513 raise OSError ("Connection is closed" )
514514 self .transport .write (message )
515- # self.transport.resume_reading()
515+ self .transport .resume_reading ()
516516
517517 async def read (self , request_id : Optional [int ], max_message_size : int ) -> tuple [bytes , int ]:
518518 """Read a single MongoDB Wire Protocol message from this connection."""
519- self ._request_id = request_id
520519 if self .transport :
521520 try :
522521 self .transport .resume_reading ()
@@ -537,7 +536,13 @@ async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[
537536 if read_waiter in self ._done_messages :
538537 self ._done_messages .remove (read_waiter )
539538 if message :
540- op_code , compressor_id , data = message
539+ op_code , compressor_id , response_to , data = message
540+ # No request_id for exhaust cursor "getMore".
541+ if request_id is not None :
542+ if request_id != response_to :
543+ raise ProtocolError (
544+ f"Got response id { response_to !r} but expected { request_id !r} "
545+ )
541546 if compressor_id is not None :
542547 data = decompress (data , compressor_id )
543548 return data , op_code
@@ -577,7 +582,12 @@ def buffer_updated(self, nbytes: int) -> None:
577582 if self ._header_index >= 16 :
578583 self ._expecting_header = False
579584 try :
580- self ._message_size , self ._op_code = self .process_header ()
585+ (
586+ self ._message_size ,
587+ self ._op_code ,
588+ self ._response_to ,
589+ self ._expecting_compression ,
590+ ) = self .process_header ()
581591 except ProtocolError as exc :
582592 self .close (exc )
583593 return
@@ -603,8 +613,10 @@ def buffer_updated(self, nbytes: int) -> None:
603613 if result .done ():
604614 self .close (None )
605615 return
606- # Necessary values to construct message from buffers
607- result .set_result ((self ._op_code , self ._compressor_id , self ._message ))
616+ # Necessary values to reconstruct and verify message
617+ result .set_result (
618+ (self ._op_code , self ._compressor_id , self ._response_to , self ._message )
619+ )
608620 self ._done_messages .append (result )
609621 # Reset internal state to expect a new message
610622 self ._header_index = 0
@@ -614,23 +626,19 @@ def buffer_updated(self, nbytes: int) -> None:
614626 self ._message = None
615627 self ._op_code = 0
616628 self ._compressor_id = None
629+ self ._response_to = None
617630
618- def process_header (self ) -> tuple [int , int ]:
631+ def process_header (self ) -> tuple [int , int , int , bool ]:
619632 """Unpack a MongoDB Wire Protocol header."""
620633 length , _ , response_to , op_code = _UNPACK_HEADER (self ._header )
634+ expecting_compression = False
621635 if op_code == 2012 : # OP_COMPRESSED
622636 if length <= 25 :
623637 raise ProtocolError (
624638 f"Message length ({ length !r} ) not longer than standard OP_COMPRESSED message header size (25)"
625639 )
626- self . _expecting_compression = True
640+ expecting_compression = True
627641 length -= 9
628- # No request_id for exhaust cursor "getMore".
629- if self ._request_id is not None :
630- if self ._request_id != response_to :
631- raise ProtocolError (
632- f"Got response id { response_to !r} but expected { self ._request_id !r} "
633- )
634642 if length <= 16 :
635643 raise ProtocolError (
636644 f"Message length ({ length !r} ) not longer than standard message header size (16)"
@@ -641,7 +649,7 @@ def process_header(self) -> tuple[int, int]:
641649 f"message size ({ self ._max_message_size !r} )"
642650 )
643651
644- return length - 16 , op_code
652+ return length - 16 , op_code , response_to , expecting_compression
645653
646654 def process_compression_header (self ) -> tuple [int , int ]:
647655 """Unpack a MongoDB Wire Protocol compression header."""
0 commit comments