Skip to content

Commit 24aa733

Browse files
committed
wip
1 parent 18c51cb commit 24aa733

File tree

3 files changed

+23
-12
lines changed

3 files changed

+23
-12
lines changed

pymongo/asynchronous/encryption.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,13 +191,15 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
191191
conn = AsyncBaseConnection(interface, opts)
192192
try:
193193
await async_sendall(interface.get_conn, message)
194+
first = True
194195
while kms_context.bytes_needed > 0:
195196
# CSOT: update timeout.
196197
interface.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
197-
data = await async_receive_kms(conn, kms_context.bytes_needed)
198+
data = await async_receive_kms(conn, kms_context.bytes_needed, first)
198199
if not data:
199200
raise OSError("KMS connection closed")
200201
kms_context.feed(data)
202+
first = False
201203
except MongoCryptError:
202204
raise # Propagate MongoCryptError errors directly.
203205
except Exception as exc:

pymongo/network_layer.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -501,8 +501,13 @@ def connection_made(self, transport: BaseTransport) -> None:
501501
"""
502502
self.transport = transport # type: ignore[assignment]
503503

504-
async def read(self, bytes_needed: int) -> bytes:
504+
async def read(self, bytes_needed: int, first=False) -> bytes:
505505
"""Read the requested bytes from this connection."""
506+
if self._bytes_ready >= bytes_needed or (self._bytes_ready > 0 and first):
507+
# Wait for other listeners first.
508+
if len(self._pending_listeners):
509+
await asyncio.gather(*self._pending_listeners)
510+
return self._read(bytes_needed)
506511
if self.transport:
507512
try:
508513
self.transport.resume_reading()
@@ -511,9 +516,7 @@ async def read(self, bytes_needed: int) -> bytes:
511516
raise OSError("connection is already closed") from None
512517
if self.transport and self.transport.is_closing():
513518
raise OSError("connection is already closed")
514-
if self._bytes_ready >= bytes_needed:
515-
return self._read(bytes_needed)
516-
self._pending_reads.append(bytes_needed)
519+
self._pending_reads.append((bytes_needed, first))
517520
read_waiter = asyncio.get_running_loop().create_future()
518521
self._pending_listeners.append(read_waiter)
519522
return await read_waiter
@@ -543,18 +546,22 @@ def buffer_updated(self, nbytes: int) -> None:
543546

544547
# Bail we don't have the current requested number of bytes.
545548
bytes_needed = self._bytes_requested
549+
first = False
546550
if bytes_needed == 0 and self._pending_reads:
547-
bytes_needed = self._pending_reads.popleft()
548-
if bytes_needed == 0 or self._bytes_ready < bytes_needed:
551+
bytes_needed, first = self._pending_reads.popleft()
552+
read_first = first and self._bytes_ready > 0
553+
if not read_first and (bytes_needed == 0 or self._bytes_ready < bytes_needed):
549554
return
550555

551-
data = self._read(bytes_needed)
556+
data = self._read(bytes_needed, first)
552557
waiter = self._pending_listeners.popleft()
553558
waiter.set_result(data)
554559

555-
def _read(self, bytes_needed):
560+
def _read(self, bytes_needed, first=False):
556561
"""Read bytes from the buffer."""
557562
# Send the bytes to the listener.
563+
if first and self._bytes_ready < bytes_needed:
564+
bytes_needed = self._bytes_ready
558565
self._bytes_ready -= bytes_needed
559566
self._bytes_requested = 0
560567

@@ -591,13 +598,13 @@ async def async_sendall(conn: PyMongoBaseProtocol, buf: bytes) -> None:
591598
raise socket.timeout("timed out") from exc
592599

593600

594-
async def async_receive_kms(conn: AsyncBaseConnection, bytes_needed: int) -> bytes:
601+
async def async_receive_kms(conn: AsyncBaseConnection, bytes_needed: int, first=False) -> bytes:
595602
"""Receive raw bytes from the kms connection."""
596603

597604
def callback(result: Any) -> bytes:
598605
return result
599606

600-
return await _async_receive_data(conn, callback, bytes_needed)
607+
return await _async_receive_data(conn, callback, bytes_needed, first)
601608

602609

603610
async def _async_receive_data(

pymongo/synchronous/encryption.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,13 +190,15 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
190190
conn = BaseConnection(interface, opts)
191191
try:
192192
sendall(interface.get_conn, message)
193+
first = True
193194
while kms_context.bytes_needed > 0:
194195
# CSOT: update timeout.
195196
interface.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
196-
data = receive_kms(conn, kms_context.bytes_needed)
197+
data = receive_kms(conn, kms_context.bytes_needed, first)
197198
if not data:
198199
raise OSError("KMS connection closed")
199200
kms_context.feed(data)
201+
first = False
200202
except MongoCryptError:
201203
raise # Propagate MongoCryptError errors directly.
202204
except Exception as exc:

0 commit comments

Comments
 (0)