Skip to content

Commit 90e6720

Browse files
committed
PYTHON-4493 Fix read perf
1 parent 9b00835 commit 90e6720

File tree

1 file changed

+82
-164
lines changed

1 file changed

+82
-164
lines changed

pymongo/network_layer.py

Lines changed: 82 additions & 164 deletions
Original file line numberDiff line numberDiff line change
@@ -466,27 +466,29 @@ def recv_into(self, buffer: bytes) -> int:
466466

467467

468468
class PyMongoProtocol(BufferedProtocol):
469-
def __init__(self, timeout: Optional[float] = None, buffer_size: int = 2**14):
470-
self._buffer_size = buffer_size
469+
def __init__(self, timeout: Optional[float] = None, buffer_size: int = 2**10):
471470
self.transport: Transport = None # type: ignore[assignment]
472-
self._buffer = memoryview(bytearray(self._buffer_size))
473-
self._overflow: Optional[memoryview] = None
474-
self._start_index = 0
475-
self._end_index = 0
476-
self._overflow_index = 0
477-
self._body_size = 0
478-
self._op_code: int = None # type: ignore[assignment]
471+
# Each message is read in 2-3 parts: header, compression header, and message body.
472+
# The message buffer is allocated after the header is read.
473+
self._header = memoryview(bytearray(16))
474+
self._header_index = 0
475+
self._compression_header = memoryview(bytearray(9))
476+
self._compression_index = 0
477+
self._message: Optional[memoryview] = None
478+
self._message_index = 0
479+
# State. TODO: replace booleans with an enum?
480+
self._expecting_header = True
481+
self._expecting_compression = False
482+
self._message_size = 0
483+
self._op_code = 0
479484
self._connection_lost = False
480-
self._paused = False
481-
self._drain_waiter: Optional[Future] = None
482485
self._read_waiter: Optional[Future] = None
483486
self._timeout = timeout
484487
self._is_compressed = False
485488
self._compressor_id: Optional[int] = None
486489
self._max_message_size = MAX_MESSAGE_SIZE
487490
self._request_id: Optional[int] = None
488491
self._closed = asyncio.get_running_loop().create_future()
489-
self._expecting_header = True
490492
self._pending_messages: collections.deque[Future] = collections.deque()
491493
self._done_messages: collections.deque[Future] = collections.deque()
492494

@@ -515,8 +517,6 @@ async def write(self, message: bytes) -> None:
515517
except AttributeError:
516518
raise OSError("connection is already closed") from None
517519
self.transport.write(message)
518-
# await self._drain_helper()
519-
# self.transport.resume_reading()
520520

521521
async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]:
522522
"""Read a single MongoDB Wire Protocol message from this connection."""
@@ -526,20 +526,11 @@ async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[
526526
# Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
527527
except AttributeError:
528528
raise OSError("connection is already closed") from None
529+
self._max_message_size = max_message_size
530+
self._request_id = request_id
529531
if self._done_messages:
530532
message = await self._done_messages.popleft()
531533
else:
532-
self._expecting_header = True
533-
self._max_message_size = max_message_size
534-
self._request_id = request_id
535-
self._end_index = 0
536-
self._overflow_index = 0
537-
self._body_size = 0
538-
self._start_index = 0
539-
self._op_code = None # type: ignore[assignment]
540-
self._overflow = None
541-
self._is_compressed = False
542-
self._compressor_id = None
543534
if self.transport and self.transport.is_closing():
544535
raise OSError("connection is already closed")
545536
read_waiter = asyncio.get_running_loop().create_future()
@@ -550,39 +541,10 @@ async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[
550541
if read_waiter in self._done_messages:
551542
self._done_messages.remove(read_waiter)
552543
if message:
553-
start, end, op_code, is_compressed, compressor_id, overflow, overflow_index = (
554-
message[0],
555-
message[1],
556-
message[2],
557-
message[3],
558-
message[4],
559-
message[5],
560-
message[6],
561-
)
562-
if is_compressed:
563-
header_size = 25
564-
else:
565-
header_size = 16
566-
if overflow is not None:
567-
if is_compressed and compressor_id is not None:
568-
return decompress(
569-
self._buffer[start + header_size : self._end_index].tobytes()
570-
+ overflow[:overflow_index].tobytes(),
571-
compressor_id,
572-
), op_code
573-
else:
574-
return (
575-
self._buffer[start + header_size : self._end_index].tobytes()
576-
+ overflow[:overflow_index].tobytes()
577-
), op_code
578-
else:
579-
if is_compressed and compressor_id is not None:
580-
return decompress(
581-
self._buffer[start + header_size : end],
582-
compressor_id,
583-
), op_code
584-
else:
585-
return self._buffer[start + header_size : end].tobytes(), op_code
544+
op_code, compressor_id, data = message
545+
if compressor_id is not None:
546+
data = decompress(data, compressor_id)
547+
return data, op_code
586548
raise OSError("connection closed")
587549

588550
def get_buffer(self, sizehint: int) -> memoryview:
@@ -592,88 +554,67 @@ def get_buffer(self, sizehint: int) -> memoryview:
592554
either no data remains or an empty buffer is returned.
593555
"""
594556
# Check for SSL EOF edge case, no data will be written to the buffer we return
557+
# TODO: is this needed?
595558
if sizehint == 0:
596559
return memoryview(bytearray(16))
597-
if self._overflow is not None:
598-
return self._overflow[self._overflow_index :]
599-
return self._buffer[self._end_index :]
560+
# TODO: optimize this by caching pointers to the buffers.
561+
# return self._buffer[self._index:]
562+
if self._expecting_header:
563+
return self._header[self._header_index :]
564+
if self._expecting_compression:
565+
return self._compression_header[self._compression_index :]
566+
return self._message[self._message_index :]
600567

601568
def buffer_updated(self, nbytes: int) -> None:
602569
"""Called when the buffer was updated with the received data"""
603570
# Wrote 0 bytes into a non-empty buffer, signal connection closed
604571
if nbytes == 0:
605572
self.connection_lost(OSError("connection closed"))
606573
return
607-
else:
608-
# Wrote data into overflow buffer
609-
if self._overflow is not None:
610-
self._overflow_index += nbytes
611-
# Wrote data into default buffer
612-
else:
613-
if self._expecting_header:
614-
try:
615-
self._body_size, self._op_code = self.process_header()
616-
self._request_id = None
617-
except ProtocolError as exc:
618-
self.connection_lost(exc)
619-
return
620-
self._expecting_header = False
621-
# The new message's data is too large for the default buffer, allocate a new buffer for this message
622-
if self._body_size > self._buffer_size - (self._end_index + nbytes):
623-
self._overflow = memoryview(bytearray(self._body_size))
624-
self._end_index += nbytes
625-
# All data of the current message has been received
626-
if self._end_index + self._overflow_index - self._start_index >= self._body_size:
627-
# Pause reading to avoid storing an arbitrary number of messages in memory before necessary
628-
self.transport.pause_reading()
629-
if self._pending_messages:
630-
result = self._pending_messages.popleft()
631-
else:
632-
result = asyncio.get_running_loop().create_future()
633-
# Future has been cancelled, close this connection
634-
if result.done():
635-
self.connection_lost(None)
574+
if self._expecting_header:
575+
self._header_index += nbytes
576+
if self._header_index >= 16:
577+
try:
578+
self._message_size, self._op_code = self.process_header()
579+
except ProtocolError as exc:
580+
self.connection_lost(exc)
636581
return
637-
# Necessary values to construct message from buffers
638-
result.set_result(
639-
(
640-
self._start_index,
641-
self._body_size + self._start_index,
642-
self._op_code,
643-
self._is_compressed,
644-
self._compressor_id,
645-
self._overflow,
646-
self._overflow_index,
647-
)
648-
)
649-
# If the current message has an overflow buffer, then the entire default buffer is full
650-
if self._overflow:
651-
self._start_index = self._end_index
652-
# Update the buffer's first written offset to reflect this message's size
653-
else:
654-
self._start_index += self._body_size
655-
self._done_messages.append(result)
656-
# Reset internal state to expect a new message
657-
self._expecting_header = True
658-
self._body_size = 0
659-
self._op_code = None # type: ignore[assignment]
660-
self._overflow = None
661-
self._overflow_index = 0
662-
self._is_compressed = False
663-
self._compressor_id = None
664-
# If at least one header's worth of data remains after the current message, reprocess all leftover data
665-
if self._end_index - self._start_index >= 16:
666-
self._read_waiter = asyncio.get_running_loop().create_future()
667-
self._pending_messages.append(self._read_waiter)
668-
nbytes_reprocess = self._end_index - self._start_index
669-
self._end_index -= nbytes_reprocess
670-
self.buffer_updated(nbytes_reprocess)
582+
self._message = memoryview(bytearray(self._message_size))
583+
return
584+
if self._expecting_compression:
585+
self._compression_index += nbytes
586+
if self._compression_index >= 9:
587+
self._op_code = self.process_compression_header()
588+
return
589+
590+
self._message_index += nbytes
591+
if self._message_index >= self._message_size:
592+
# Pause reading to avoid storing an arbitrary number of messages in memory.
593+
self.transport.pause_reading()
594+
if self._pending_messages:
595+
result = self._pending_messages.popleft()
596+
else:
597+
result = asyncio.get_running_loop().create_future()
598+
# Future has been cancelled, close this connection
599+
if result.done():
600+
self.connection_lost(None)
601+
return
602+
# Necessary values to construct message from buffers
603+
result.set_result((self._op_code, self._compressor_id, self._message))
604+
self._done_messages.append(result)
605+
# Reset internal state to expect a new message
606+
self._expecting_header = True
607+
self._header_index = 0
608+
self._compression_index = 0
609+
self._message_index = 0
610+
self._message_size = 0
611+
self._message = None
612+
self._op_code = 0
613+
self._compressor_id = None
671614

672615
def process_header(self) -> tuple[int, int]:
673616
"""Unpack a MongoDB Wire Protocol header."""
674-
length, _, response_to, op_code = _UNPACK_HEADER(
675-
self._buffer[self._start_index : self._start_index + 16]
676-
)
617+
length, _, response_to, op_code = _UNPACK_HEADER(self._header)
677618
# No request_id for exhaust cursor "getMore".
678619
if self._request_id is not None:
679620
if self._request_id != response_to:
@@ -689,24 +630,22 @@ def process_header(self) -> tuple[int, int]:
689630
f"Message length ({length!r}) is larger than server max "
690631
f"message size ({self._max_message_size!r})"
691632
)
692-
if op_code == 2012:
693-
self._is_compressed = True
694-
op_code, _, self._compressor_id = _UNPACK_COMPRESSION_HEADER(
695-
self._buffer[self._start_index + 16 : self._start_index + 25]
696-
)
697-
698-
return length, op_code
633+
if op_code == 2012: # OP_COMPRESSED
634+
if length <= 25:
635+
raise ProtocolError(
636+
f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)"
637+
)
638+
self._expecting_compression = True
639+
length -= 9
699640

700-
def pause_writing(self) -> None:
701-
assert not self._paused
702-
self._paused = True
641+
self._expecting_header = False
642+
return length - 16, op_code
703643

704-
def resume_writing(self) -> None:
705-
assert self._paused
706-
self._paused = False
707-
#
708-
# if self._drain_waiter and not self._drain_waiter.done():
709-
# self._drain_waiter.set_result(None)
644+
def process_compression_header(self) -> tuple[int, int]:
645+
"""Unpack a MongoDB Wire Protocol OP_COMPRESSED (compression) header."""
646+
op_code, _, self._compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header)
647+
self._expecting_compression = False
648+
return op_code
710649

711650
def connection_lost(self, exc: Exception | None) -> None:
712651
self._connection_lost = True
@@ -725,27 +664,6 @@ def connection_lost(self, exc: Exception | None) -> None:
725664
else:
726665
self._closed.set_exception(exc)
727666

728-
# # Wake up the writer(s) if currently paused.
729-
# if not self._paused:
730-
# return
731-
#
732-
# if self._drain_waiter and not self._drain_waiter.done():
733-
# if exc is None:
734-
# self._drain_waiter.set_result(None)
735-
# else:
736-
# self._drain_waiter.set_exception(exc)
737-
738-
async def _drain_helper(self) -> None:
739-
if self._connection_lost:
740-
raise ConnectionResetError("Connection lost")
741-
if not self._paused:
742-
return
743-
self._drain_waiter = asyncio.get_running_loop().create_future()
744-
await self._drain_waiter
745-
746-
def data(self) -> bytes:
747-
return self._buffer.tobytes()
748-
749667
async def wait_closed(self) -> None:
750668
await asyncio.wait([self._closed])
751669

0 commit comments

Comments
 (0)