Skip to content

Commit 9fd9176

Browse files
committed
Cleanup + comments
1 parent b5ac22d commit 9fd9176

File tree

2 files changed

+70
-47
lines changed

2 files changed

+70
-47
lines changed

bson/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1109,7 +1109,9 @@ def _decode_all(data: _ReadableBuffer, opts: CodecOptions[_DocumentType]) -> lis
11091109
while position < end:
11101110
obj_size = _UNPACK_INT_FROM(data, position)[0]
11111111
if data_len - position < obj_size:
1112-
raise InvalidBSON("invalid object size")
1112+
raise InvalidBSON(
1113+
f"invalid object size: expected {obj_size}, got {data_len - position}"
1114+
)
11131115
obj_end = position + obj_size - 1
11141116
if data[obj_end] != 0:
11151117
raise InvalidBSON("bad eoo")

pymongo/network_layer.py

Lines changed: 67 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -474,10 +474,10 @@ def __init__(self, timeout: Optional[float] = None, buffer_size: int = 2**14):
474474
self.transport: Transport = None # type: ignore[assignment]
475475
self._buffer = memoryview(bytearray(self._buffer_size))
476476
self._overflow: Optional[memoryview] = None
477-
self._start = 0
478-
self._length = 0
479-
self._overflow_length = 0
480-
self._body_length = 0
477+
self._start_index = 0
478+
self._end_index = 0
479+
self._overflow_index = 0
480+
self._body_size = 0
481481
self._op_code: int = None # type: ignore[assignment]
482482
self._connection_lost = False
483483
self._paused = False
@@ -490,7 +490,6 @@ def __init__(self, timeout: Optional[float] = None, buffer_size: int = 2**14):
490490
self._max_message_size = MAX_MESSAGE_SIZE
491491
self._request_id: Optional[int] = None
492492
self._closed = asyncio.get_running_loop().create_future()
493-
self._debug = False
494493
self._expecting_header = True
495494
self._pending_messages: collections.deque[Future] = collections.deque()
496495
self._done_messages: collections.deque[Future] = collections.deque()
@@ -517,23 +516,20 @@ async def write(self, message: bytes) -> None:
517516
await self._drain_helper()
518517
self.transport.resume_reading()
519518

520-
async def read(
521-
self, request_id: Optional[int], max_message_size: int, debug: bool = False
522-
) -> tuple[bytes, int]:
519+
async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]:
523520
"""Read a single MongoDB Wire Protocol message from this connection."""
524521
if self.transport:
525522
self.transport.resume_reading()
526523
if self._done_messages:
527524
message = await self._done_messages.popleft()
528525
else:
529526
self._expecting_header = True
530-
self._debug = debug
531527
self._max_message_size = max_message_size
532528
self._request_id = request_id
533-
self._length = 0
534-
self._overflow_length = 0
535-
self._body_length = 0
536-
self._start = 0
529+
self._end_index = 0
530+
self._overflow_index = 0
531+
self._body_size = 0
532+
self._start_index = 0
537533
self._op_code = None # type: ignore[assignment]
538534
self._overflow = None
539535
if self.transport and self.transport.is_closing():
@@ -546,24 +542,30 @@ async def read(
546542
if read_waiter in self._done_messages:
547543
self._done_messages.remove(read_waiter)
548544
if message:
549-
start, end, op_code = message[0], message[1], message[2]
545+
start, end, op_code, overflow, overflow_index = (
546+
message[0],
547+
message[1],
548+
message[2],
549+
message[3],
550+
message[4],
551+
)
550552
if self._is_compressed:
551553
header_size = 25
552554
else:
553555
header_size = 16
554-
if self._body_length > self._buffer_size and self._overflow is not None:
556+
if overflow is not None:
555557
if self._is_compressed and self._compressor_id is not None:
556558
return decompress(
557559
memoryview(
558-
bytearray(self._buffer[header_size : self._length])
559-
+ bytearray(self._overflow[: self._overflow_length])
560+
bytearray(self._buffer[start + header_size : self._end_index])
561+
+ bytearray(overflow[:overflow_index])
560562
),
561563
self._compressor_id,
562564
), op_code
563565
else:
564566
return memoryview(
565-
bytearray(self._buffer[header_size : self._length])
566-
+ bytearray(self._overflow[: self._overflow_length])
567+
bytearray(self._buffer[start + header_size : self._end_index])
568+
+ bytearray(overflow[:overflow_index])
567569
), op_code
568570
else:
569571
if self._is_compressed and self._compressor_id is not None:
@@ -578,55 +580,75 @@ async def read(
578580
def get_buffer(self, sizehint: int) -> memoryview:
579581
"""Called to allocate a new receive buffer."""
580582
if self._overflow is not None:
581-
return self._overflow[self._overflow_length :]
582-
return self._buffer[self._length :]
583+
return self._overflow[self._overflow_index :]
584+
return self._buffer[self._end_index :]
583585

584586
def buffer_updated(self, nbytes: int) -> None:
585587
"""Called when the buffer was updated with the received data"""
588+
# Wrote 0 bytes into a non-empty buffer, signal connection closed
586589
if nbytes == 0:
587590
self.connection_lost(OSError("connection closed"))
588591
return
589592
else:
593+
# Wrote data into overflow buffer
590594
if self._overflow is not None:
591-
self._overflow_length += nbytes
595+
self._overflow_index += nbytes
596+
# Wrote data into default buffer
592597
else:
593598
if self._expecting_header:
594599
try:
595-
self._body_length, self._op_code = self.process_header()
600+
self._body_size, self._op_code = self.process_header()
596601
except ProtocolError as exc:
597602
self.connection_lost(exc)
598603
return
599604
self._expecting_header = False
600-
# TODO: account for multiple messages processed within a single read() call
601-
if self._body_length > self._buffer_size:
602-
self._overflow = memoryview(
603-
bytearray(self._body_length - (self._length + nbytes) + 1024)
604-
)
605-
self._length += nbytes
606-
if self._length + self._overflow_length - self._start >= self._body_length:
605+
# The new message's data is too large for the default buffer, allocate a new buffer for this message
606+
if self._body_size > self._buffer_size - (self._end_index + nbytes):
607+
self._overflow = memoryview(bytearray(self._body_size))
608+
self._end_index += nbytes
609+
# All data of the current message has been received
610+
if self._end_index + self._overflow_index - self._start_index >= self._body_size:
607611
if self._pending_messages:
608-
done = self._pending_messages.popleft()
612+
result = self._pending_messages.popleft()
609613
else:
610-
done = asyncio.get_running_loop().create_future()
611-
done.set_result((self._start, self._body_length + self._start, self._op_code))
612-
self._start += self._body_length
613-
self._done_messages.append(done)
614-
# If we have more data after processing the last message, start processing a new message
615-
if self._length - self._start > 0:
614+
result = asyncio.get_running_loop().create_future()
615+
# Future has been cancelled, close this connection
616+
if result.done():
617+
self.connection_lost(None)
618+
return
619+
# Necessary values to construct message from buffers
620+
result.set_result(
621+
(
622+
self._start_index,
623+
self._body_size + self._start_index,
624+
self._op_code,
625+
self._overflow,
626+
self._overflow_index,
627+
)
628+
)
629+
# Update the buffer's first written offset to reflect this message's size
630+
self._start_index += self._body_size
631+
self._done_messages.append(result)
632+
# If at least one header's worth of data remains after the current message, reprocess all leftover data
633+
if self._end_index - self._start_index >= 16:
616634
self._read_waiter = asyncio.get_running_loop().create_future()
617635
self._pending_messages.append(self._read_waiter)
618-
extra = self._length - self._start
619-
self._length -= extra
636+
nbytes_reprocess = self._end_index - self._start_index
637+
self._end_index -= nbytes_reprocess
638+
# Reset internal state to expect a new message
620639
self._expecting_header = True
621-
self._body_length = 0
640+
self._body_size = 0
622641
self._op_code = None # type: ignore[assignment]
623-
self.buffer_updated(extra)
642+
self._overflow = None
643+
self._overflow_index = 0
644+
self.buffer_updated(nbytes_reprocess)
645+
# Pause reading to avoid storing an arbitrary number of messages in memory before necessary
624646
self.transport.pause_reading()
625647

626648
def process_header(self) -> tuple[int, int]:
627649
"""Unpack a MongoDB Wire Protocol header."""
628650
length, _, response_to, op_code = _UNPACK_HEADER(
629-
self._buffer[self._start : self._start + 16]
651+
self._buffer[self._start_index : self._start_index + 16]
630652
)
631653
# No request_id for exhaust cursor "getMore".
632654
if self._request_id is not None:
@@ -645,9 +667,9 @@ def process_header(self) -> tuple[int, int]:
645667
)
646668
if op_code == 2012:
647669
self._is_compressed = True
648-
if self._length >= 25:
670+
if self._end_index >= 25:
649671
op_code, _, self._compressor_id = _UNPACK_COMPRESSION_HEADER(
650-
self._buffer[self._start + 16 : self._start + 25]
672+
self._buffer[self._start_index + 16 : self._start_index + 25]
651673
)
652674
else:
653675
self._need_compression_header = True
@@ -719,7 +741,6 @@ async def async_receive_message(
719741
conn: AsyncConnection,
720742
request_id: Optional[int],
721743
max_message_size: int = MAX_MESSAGE_SIZE,
722-
debug: bool = False,
723744
) -> Union[_OpReply, _OpMsg]:
724745
"""Receive a raw BSON message or raise socket.error."""
725746
timeout: Optional[Union[float, int]]
@@ -738,7 +759,7 @@ async def async_receive_message(
738759
timeout = max(deadline - time.monotonic(), 0)
739760

740761
cancellation_task = create_task(_poll_cancellation(conn))
741-
read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size, debug))
762+
read_task = create_task(conn.conn.get_conn.read(request_id, max_message_size))
742763
tasks = [read_task, cancellation_task]
743764
try:
744765
done, pending = await asyncio.wait(

0 commit comments

Comments
 (0)