Skip to content

Commit fa0dd8d

Browse files
committed
always allow partial reads
1 parent 4b1bdd6 commit fa0dd8d

File tree

3 files changed

+13
-21
lines changed

3 files changed

+13
-21
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,13 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
191191
conn = AsyncBaseConnection(interface, opts)
192192
try:
193193
await async_sendall(interface.get_conn, message)
194-
first = True
195194
while kms_context.bytes_needed > 0:
196195
# CSOT: update timeout.
197196
interface.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
198-
data = await async_receive_kms(conn, kms_context.bytes_needed, first)
197+
data = await async_receive_kms(conn, kms_context.bytes_needed)
199198
if not data:
200199
raise OSError("KMS connection closed")
201200
kms_context.feed(data)
202-
first = False
203201
except MongoCryptError:
204202
raise # Propagate MongoCryptError errors directly.
205203
except Exception as exc:

pymongo/network_layer.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -499,13 +499,13 @@ def connection_made(self, transport: BaseTransport) -> None:
499499
"""
500500
self.transport = transport # type: ignore[assignment]
501501

502-
async def read(self, bytes_needed: int, first=False) -> bytes:
503-
"""Read the requested bytes from this connection."""
504-
if self._bytes_ready >= bytes_needed or (self._bytes_ready > 0 and first):
502+
async def read(self, bytes_needed: int) -> bytes:
503+
"""Read up to the requested bytes from this connection."""
504+
if self._bytes_ready > 0:
505505
# Wait for other listeners first.
506506
if len(self._pending_listeners):
507507
await asyncio.gather(*self._pending_listeners)
508-
return self._read(bytes_needed, first)
508+
return self._read(bytes_needed)
509509
if self.transport:
510510
try:
511511
self.transport.resume_reading()
@@ -514,7 +514,7 @@ async def read(self, bytes_needed: int, first=False) -> bytes:
514514
raise OSError("connection is already closed") from None
515515
if self.transport and self.transport.is_closing():
516516
raise OSError("connection is already closed")
517-
self._pending_reads.append((bytes_needed, first))
517+
self._pending_reads.append(bytes_needed)
518518
read_waiter = asyncio.get_running_loop().create_future()
519519
self._pending_listeners.append(read_waiter)
520520
return await read_waiter
@@ -548,19 +548,15 @@ def buffer_updated(self, nbytes: int) -> None:
548548
if not len(self._pending_reads):
549549
return
550550

551-
bytes_needed, first = self._pending_reads.popleft()
552-
if not first and (bytes_needed == 0 or self._bytes_ready < bytes_needed):
553-
self._pending_reads.appendleft((bytes_needed, first))
554-
return
555-
556-
data = self._read(bytes_needed, first)
551+
bytes_needed = self._pending_reads.popleft()
552+
data = self._read(bytes_needed)
557553
waiter = self._pending_listeners.popleft()
558554
waiter.set_result(data)
559555

560-
def _read(self, bytes_needed, first=False):
556+
def _read(self, bytes_needed):
561557
"""Read bytes from the buffer."""
562558
# Send the bytes to the listener.
563-
if first and self._bytes_ready < bytes_needed:
559+
if self._bytes_ready < bytes_needed:
564560
bytes_needed = self._bytes_ready
565561
self._bytes_ready -= bytes_needed
566562

@@ -596,13 +592,13 @@ async def async_sendall(conn: PyMongoBaseProtocol, buf: bytes) -> None:
596592
raise socket.timeout("timed out") from exc
597593

598594

599-
async def async_receive_kms(conn: AsyncBaseConnection, bytes_needed: int, first=False) -> bytes:
595+
async def async_receive_kms(conn: AsyncBaseConnection, bytes_needed: int) -> bytes:
600596
"""Receive raw bytes from the kms connection."""
601597

602598
def callback(result: Any) -> bytes:
603599
return result
604600

605-
return await _async_receive_data(conn, callback, bytes_needed, first)
601+
return await _async_receive_data(conn, callback, bytes_needed)
606602

607603

608604
async def _async_receive_data(

pymongo/synchronous/encryption.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -190,15 +190,13 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
190190
conn = BaseConnection(interface, opts)
191191
try:
192192
sendall(interface.get_conn, message)
193-
first = True
194193
while kms_context.bytes_needed > 0:
195194
# CSOT: update timeout.
196195
interface.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
197-
data = receive_kms(conn, kms_context.bytes_needed, first)
196+
data = receive_kms(conn, kms_context.bytes_needed)
198197
if not data:
199198
raise OSError("KMS connection closed")
200199
kms_context.feed(data)
201-
first = False
202200
except MongoCryptError:
203201
raise # Propagate MongoCryptError errors directly.
204202
except Exception as exc:

0 commit comments

Comments
 (0)