@@ -487,7 +487,7 @@ def __init__(self, timeout: Optional[float] = None):
487
487
self ._is_compressed = False
488
488
self ._compressor_id : Optional [int ] = None
489
489
self ._max_message_size = MAX_MESSAGE_SIZE
490
- self ._request_id : Optional [int ] = None
490
+ self ._response_to : Optional [int ] = None
491
491
self ._closed = asyncio .get_running_loop ().create_future ()
492
492
self ._pending_messages : collections .deque [Future ] = collections .deque ()
493
493
self ._done_messages : collections .deque [Future ] = collections .deque ()
@@ -512,11 +512,10 @@ async def write(self, message: bytes) -> None:
512
512
if self .transport .is_closing ():
513
513
raise OSError ("Connection is closed" )
514
514
self .transport .write (message )
515
- # self.transport.resume_reading()
515
+ self .transport .resume_reading ()
516
516
517
517
async def read (self , request_id : Optional [int ], max_message_size : int ) -> tuple [bytes , int ]:
518
518
"""Read a single MongoDB Wire Protocol message from this connection."""
519
- self ._request_id = request_id
520
519
if self .transport :
521
520
try :
522
521
self .transport .resume_reading ()
@@ -537,7 +536,13 @@ async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[
537
536
if read_waiter in self ._done_messages :
538
537
self ._done_messages .remove (read_waiter )
539
538
if message :
540
- op_code , compressor_id , data = message
539
+ op_code , compressor_id , response_to , data = message
540
+ # No request_id for exhaust cursor "getMore".
541
+ if request_id is not None :
542
+ if request_id != response_to :
543
+ raise ProtocolError (
544
+ f"Got response id { response_to !r} but expected { request_id !r} "
545
+ )
541
546
if compressor_id is not None :
542
547
data = decompress (data , compressor_id )
543
548
return data , op_code
@@ -577,7 +582,12 @@ def buffer_updated(self, nbytes: int) -> None:
577
582
if self ._header_index >= 16 :
578
583
self ._expecting_header = False
579
584
try :
580
- self ._message_size , self ._op_code = self .process_header ()
585
+ (
586
+ self ._message_size ,
587
+ self ._op_code ,
588
+ self ._response_to ,
589
+ self ._expecting_compression ,
590
+ ) = self .process_header ()
581
591
except ProtocolError as exc :
582
592
self .close (exc )
583
593
return
@@ -603,8 +613,10 @@ def buffer_updated(self, nbytes: int) -> None:
603
613
if result .done ():
604
614
self .close (None )
605
615
return
606
- # Necessary values to construct message from buffers
607
- result .set_result ((self ._op_code , self ._compressor_id , self ._message ))
616
+ # Necessary values to reconstruct and verify message
617
+ result .set_result (
618
+ (self ._op_code , self ._compressor_id , self ._response_to , self ._message )
619
+ )
608
620
self ._done_messages .append (result )
609
621
# Reset internal state to expect a new message
610
622
self ._header_index = 0
@@ -614,23 +626,19 @@ def buffer_updated(self, nbytes: int) -> None:
614
626
self ._message = None
615
627
self ._op_code = 0
616
628
self ._compressor_id = None
629
+ self ._response_to = None
617
630
618
- def process_header (self ) -> tuple [int , int ]:
631
+ def process_header (self ) -> tuple [int , int , int , bool ]:
619
632
"""Unpack a MongoDB Wire Protocol header."""
620
633
length , _ , response_to , op_code = _UNPACK_HEADER (self ._header )
634
+ expecting_compression = False
621
635
if op_code == 2012 : # OP_COMPRESSED
622
636
if length <= 25 :
623
637
raise ProtocolError (
624
638
f"Message length ({ length !r} ) not longer than standard OP_COMPRESSED message header size (25)"
625
639
)
626
- self . _expecting_compression = True
640
+ expecting_compression = True
627
641
length -= 9
628
- # No request_id for exhaust cursor "getMore".
629
- if self ._request_id is not None :
630
- if self ._request_id != response_to :
631
- raise ProtocolError (
632
- f"Got response id { response_to !r} but expected { self ._request_id !r} "
633
- )
634
642
if length <= 16 :
635
643
raise ProtocolError (
636
644
f"Message length ({ length !r} ) not longer than standard message header size (16)"
@@ -641,7 +649,7 @@ def process_header(self) -> tuple[int, int]:
641
649
f"message size ({ self ._max_message_size !r} )"
642
650
)
643
651
644
- return length - 16 , op_code
652
+ return length - 16 , op_code , response_to , expecting_compression
645
653
646
654
def process_compression_header (self ) -> tuple [int , int ]:
647
655
"""Unpack a MongoDB Wire Protocol compression header."""
0 commit comments