Skip to content

Commit 7006cf4

Browse files
committed
Don't store request_id
1 parent 988ee60 commit 7006cf4

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

pymongo/network_layer.py

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def __init__(self, timeout: Optional[float] = None):
487487
self._is_compressed = False
488488
self._compressor_id: Optional[int] = None
489489
self._max_message_size = MAX_MESSAGE_SIZE
490-
self._request_id: Optional[int] = None
490+
self._response_to: Optional[int] = None
491491
self._closed = asyncio.get_running_loop().create_future()
492492
self._pending_messages: collections.deque[Future] = collections.deque()
493493
self._done_messages: collections.deque[Future] = collections.deque()
@@ -512,11 +512,10 @@ async def write(self, message: bytes) -> None:
512512
if self.transport.is_closing():
513513
raise OSError("Connection is closed")
514514
self.transport.write(message)
515-
# self.transport.resume_reading()
515+
self.transport.resume_reading()
516516

517517
async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]:
518518
"""Read a single MongoDB Wire Protocol message from this connection."""
519-
self._request_id = request_id
520519
if self.transport:
521520
try:
522521
self.transport.resume_reading()
@@ -537,7 +536,13 @@ async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[
537536
if read_waiter in self._done_messages:
538537
self._done_messages.remove(read_waiter)
539538
if message:
540-
op_code, compressor_id, data = message
539+
op_code, compressor_id, response_to, data = message
540+
# No request_id for exhaust cursor "getMore".
541+
if request_id is not None:
542+
if request_id != response_to:
543+
raise ProtocolError(
544+
f"Got response id {response_to!r} but expected {request_id!r}"
545+
)
541546
if compressor_id is not None:
542547
data = decompress(data, compressor_id)
543548
return data, op_code
@@ -577,7 +582,12 @@ def buffer_updated(self, nbytes: int) -> None:
577582
if self._header_index >= 16:
578583
self._expecting_header = False
579584
try:
580-
self._message_size, self._op_code = self.process_header()
585+
(
586+
self._message_size,
587+
self._op_code,
588+
self._response_to,
589+
self._expecting_compression,
590+
) = self.process_header()
581591
except ProtocolError as exc:
582592
self.close(exc)
583593
return
@@ -603,8 +613,10 @@ def buffer_updated(self, nbytes: int) -> None:
603613
if result.done():
604614
self.close(None)
605615
return
606-
# Necessary values to construct message from buffers
607-
result.set_result((self._op_code, self._compressor_id, self._message))
616+
# Necessary values to reconstruct and verify message
617+
result.set_result(
618+
(self._op_code, self._compressor_id, self._response_to, self._message)
619+
)
608620
self._done_messages.append(result)
609621
# Reset internal state to expect a new message
610622
self._header_index = 0
@@ -614,23 +626,19 @@ def buffer_updated(self, nbytes: int) -> None:
614626
self._message = None
615627
self._op_code = 0
616628
self._compressor_id = None
629+
self._response_to = None
617630

618-
def process_header(self) -> tuple[int, int]:
631+
def process_header(self) -> tuple[int, int, int, bool]:
619632
"""Unpack a MongoDB Wire Protocol header."""
620633
length, _, response_to, op_code = _UNPACK_HEADER(self._header)
634+
expecting_compression = False
621635
if op_code == 2012: # OP_COMPRESSED
622636
if length <= 25:
623637
raise ProtocolError(
624638
f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)"
625639
)
626-
self._expecting_compression = True
640+
expecting_compression = True
627641
length -= 9
628-
# No request_id for exhaust cursor "getMore".
629-
if self._request_id is not None:
630-
if self._request_id != response_to:
631-
raise ProtocolError(
632-
f"Got response id {response_to!r} but expected {self._request_id!r}"
633-
)
634642
if length <= 16:
635643
raise ProtocolError(
636644
f"Message length ({length!r}) not longer than standard message header size (16)"
@@ -641,7 +649,7 @@ def process_header(self) -> tuple[int, int]:
641649
f"message size ({self._max_message_size!r})"
642650
)
643651

644-
return length - 16, op_code
652+
return length - 16, op_code, response_to, expecting_compression
645653

646654
def process_compression_header(self) -> tuple[int, int]:
647655
"""Unpack a MongoDB Wire Protocol compression header."""

0 commit comments

Comments
 (0)