Skip to content

Commit 993de69

Browse files
committed
PYTHON-4493 Add workaround for SSL ProtocolError issues
1 parent f5cbce0 commit 993de69

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

pymongo/network_layer.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -511,12 +511,8 @@ async def write(self, message: bytes) -> None:
511511
"""Write a message to this connection's transport."""
512512
if self.transport.is_closing():
513513
raise OSError("Connection is closed")
514-
try:
515-
self.transport.resume_reading()
516-
# Known bug in SSL Protocols, fixed in Python 3.11: https://github.com/python/cpython/issues/89322
517-
except AttributeError:
518-
raise OSError("connection is already closed") from None
519514
self.transport.write(message)
515+
self.transport.resume_reading()
520516

521517
async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[bytes, int]:
522518
"""Read a single MongoDB Wire Protocol message from this connection."""
@@ -553,6 +549,13 @@ def get_buffer(self, sizehint: int) -> memoryview:
553549
If any data does not fit into the returned buffer, this method will be called again until
554550
either no data remains or an empty buffer is returned.
555551
"""
552+
# Due to a bug, Python <=3.11 will call get_buffer() even after we raise
553+
# ProtocolError in buffer_updated() and call connection_lost(). We allocate
554+
# a temp buffer to drain the waiting data.
555+
if self._connection_lost:
556+
if not self._message:
557+
self._message = memoryview(bytearray(2**14))
558+
return self._message
556559
# TODO: optimize this by caching pointers to the buffers.
557560
# return self._buffer[self._index:]
558561
if self._expecting_header:
@@ -567,9 +570,12 @@ def buffer_updated(self, nbytes: int) -> None:
567570
if nbytes == 0:
568571
self.connection_lost(OSError("connection closed"))
569572
return
573+
if self._connection_lost:
574+
return
570575
if self._expecting_header:
571576
self._header_index += nbytes
572577
if self._header_index >= 16:
578+
self._expecting_header = False
573579
try:
574580
self._message_size, self._op_code = self.process_header()
575581
except ProtocolError as exc:
@@ -580,11 +586,13 @@ def buffer_updated(self, nbytes: int) -> None:
580586
if self._expecting_compression:
581587
self._compression_index += nbytes
582588
if self._compression_index >= 9:
583-
self._op_code = self.process_compression_header()
589+
self._expecting_compression = False
590+
self._op_code, self._compressor_id = self.process_compression_header()
584591
return
585592

586593
self._message_index += nbytes
587594
if self._message_index >= self._message_size:
595+
self._expecting_header = True
588596
# Pause reading to avoid storing an arbitrary number of messages in memory.
589597
self.transport.pause_reading()
590598
if self._pending_messages:
@@ -599,7 +607,6 @@ def buffer_updated(self, nbytes: int) -> None:
599607
result.set_result((self._op_code, self._compressor_id, self._message))
600608
self._done_messages.append(result)
601609
# Reset internal state to expect a new message
602-
self._expecting_header = True
603610
self._header_index = 0
604611
self._compression_index = 0
605612
self._message_index = 0
@@ -611,6 +618,13 @@ def buffer_updated(self, nbytes: int) -> None:
611618
def process_header(self) -> tuple[int, int]:
612619
"""Unpack a MongoDB Wire Protocol header."""
613620
length, _, response_to, op_code = _UNPACK_HEADER(self._header)
621+
if op_code == 2012: # OP_COMPRESSED
622+
if length <= 25:
623+
raise ProtocolError(
624+
f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)"
625+
)
626+
self._expecting_compression = True
627+
length -= 9
614628
# No request_id for exhaust cursor "getMore".
615629
if self._request_id is not None:
616630
if self._request_id != response_to:
@@ -626,25 +640,17 @@ def process_header(self) -> tuple[int, int]:
626640
f"Message length ({length!r}) is larger than server max "
627641
f"message size ({self._max_message_size!r})"
628642
)
629-
if op_code == 2012: # OP_COMPRESSED
630-
if length <= 25:
631-
raise ProtocolError(
632-
f"Message length ({length!r}) not longer than standard OP_COMPRESSED message header size (25)"
633-
)
634-
self._expecting_compression = True
635-
length -= 9
636643

637-
self._expecting_header = False
638644
return length - 16, op_code
639645

640-
def process_compression_header(self) -> int:
646+
def process_compression_header(self) -> tuple[int, int]:
641647
"""Unpack a MongoDB Wire Protocol compression header."""
642-
op_code, _, self._compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header)
643-
self._expecting_compression = False
644-
return op_code
648+
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(self._compression_header)
649+
return op_code, compressor_id
645650

646651
def connection_lost(self, exc: Exception | None) -> None:
647652
self._connection_lost = True
653+
super().connection_lost(exc)
648654
pending = list(self._pending_messages)
649655
for msg in pending:
650656
if not msg.done():

0 commit comments

Comments
 (0)