Skip to content

Commit 36fb9cf

Browse files
committed
Fix multiple message processing in buffer_updated
1 parent 02cf3c1 commit 36fb9cf

File tree

1 file changed

+20
-13
lines changed

1 file changed

+20
-13
lines changed

pymongo/network_layer.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,8 @@ async def read(
521521
self, request_id: Optional[int], max_message_size: int, debug: bool = False
522522
) -> tuple[bytes, int]:
523523
"""Read a single MongoDB Wire Protocol message from this connection."""
524-
self.transport.resume_reading()
524+
if self.transport:
525+
self.transport.resume_reading()
525526
if self._done_messages:
526527
message = await self._done_messages.popleft()
527528
else:
@@ -532,9 +533,10 @@ async def read(
532533
self._length = 0
533534
self._overflow_length = 0
534535
self._body_length = 0
536+
self._start = 0
535537
self._op_code = None # type: ignore[assignment]
536538
self._overflow = None
537-
if self.transport.is_closing():
539+
if self.transport and self.transport.is_closing():
538540
raise OSError("Connection is closed")
539541
read_waiter = asyncio.get_running_loop().create_future()
540542
self._pending_messages.append(read_waiter)
@@ -545,10 +547,12 @@ async def read(
545547
self._done_messages.remove(read_waiter)
546548
if message:
547549
start, end, op_code = message[0], message[1], message[2]
548-
header_size = 16
550+
if self._is_compressed:
551+
header_size = 25
552+
else:
553+
header_size = 16
549554
if self._body_length > self._buffer_size and self._overflow is not None:
550555
if self._is_compressed and self._compressor_id is not None:
551-
header_size = 25
552556
return decompress(
553557
memoryview(
554558
bytearray(self._buffer[header_size : self._length])
@@ -563,7 +567,6 @@ async def read(
563567
), op_code
564568
else:
565569
if self._is_compressed and self._compressor_id is not None:
566-
header_size = 25
567570
return decompress(
568571
memoryview(self._buffer[start + header_size : end]),
569572
self._compressor_id,
@@ -594,24 +597,24 @@ def buffer_updated(self, nbytes: int) -> None:
594597
self.connection_lost(exc)
595598
return
596599
self._expecting_header = False
600+
# TODO: account for multiple messages processed within a single read() call
597601
if self._body_length > self._buffer_size:
598602
self._overflow = memoryview(
599603
bytearray(self._body_length - (self._length + nbytes) + 1024)
600604
)
601605
self._length += nbytes
602-
if self._length + self._overflow_length >= self._body_length:
606+
if self._length + self._overflow_length - self._start >= self._body_length:
603607
if self._pending_messages:
604608
done = self._pending_messages.popleft()
605609
else:
606610
done = asyncio.get_running_loop().create_future()
607-
done.set_result((self._start, self._body_length, self._op_code))
608-
self._start = 0
611+
done.set_result((self._start, self._body_length + self._start, self._op_code))
612+
self._start += self._body_length
609613
self._done_messages.append(done)
610-
if self._length > self._body_length:
614+
if self._length - self._start > self._body_length:
611615
self._read_waiter = asyncio.get_running_loop().create_future()
612616
self._pending_messages.append(self._read_waiter)
613-
self._start = self._body_length
614-
extra = self._length - self._body_length
617+
extra = self._length - self._start
615618
self._length -= extra
616619
self._expecting_header = True
617620
self._body_length = 0
@@ -621,7 +624,9 @@ def buffer_updated(self, nbytes: int) -> None:
621624

622625
def process_header(self) -> tuple[int, int]:
623626
"""Unpack a MongoDB Wire Protocol header."""
624-
length, _, response_to, op_code = _UNPACK_HEADER(self._buffer[self._start : 16])
627+
length, _, response_to, op_code = _UNPACK_HEADER(
628+
self._buffer[self._start : self._start + 16]
629+
)
625630
# No request_id for exhaust cursor "getMore".
626631
if self._request_id is not None:
627632
if self._request_id != response_to:
@@ -640,7 +645,9 @@ def process_header(self) -> tuple[int, int]:
640645
if op_code == 2012:
641646
self._is_compressed = True
642647
if self._length >= 25:
643-
op_code, _, self._compressor_id = _UNPACK_COMPRESSION_HEADER(self._buffer[16:25])
648+
op_code, _, self._compressor_id = _UNPACK_COMPRESSION_HEADER(
649+
self._buffer[self._start + 16 : self._start + 25]
650+
)
644651
else:
645652
self._need_compression_header = True
646653

0 commit comments

Comments
 (0)