Skip to content

Commit 971139c

Browse files
committed
fix buffer handling and close handling
1 parent 6ed92bb commit 971139c

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

pymongo/network_layer.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
481481
class KMSBuffer:
482482
buffer: memoryview
483483
start_index: int
484-
length: int
484+
end_index: int
485485

486486

487487
class PyMongoKMSProtocol(PyMongoBaseProtocol):
@@ -524,11 +524,21 @@ def get_buffer(self, sizehint: int) -> memoryview:
524524
If any data does not fit into the returned buffer, this method will be called again until
525525
either no data remains or an empty buffer is returned.
526526
"""
527-
sizehint = max(sizehint, 1024)
528-
buffer = KMSBuffer(memoryview(bytearray(sizehint)), 0, 0)
527+
# Reuse the active buffer if it has space.
528+
if len(self._buffers):
529+
buffer = self._buffers[-1]
530+
if len(buffer.buffer) - buffer.end_index > sizehint:
531+
return buffer.buffer[buffer.end_index :]
532+
# Allocate a bit more than the max response size for an AWS KMS response.
533+
buffer = KMSBuffer(memoryview(bytearray(16384)), 0, 0)
529534
self._buffers.append(buffer)
530535
return buffer.buffer
531536

537+
def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
538+
while self._pending_listeners:
539+
fut = self._pending_listeners.popleft()
540+
fut.set_result(b"")
541+
532542
def buffer_updated(self, nbytes: int) -> None:
533543
"""Called when the buffer was updated with the received data"""
534544
# Wrote 0 bytes into a non-empty buffer, signal connection closed
@@ -540,9 +550,7 @@ def buffer_updated(self, nbytes: int) -> None:
540550
self._bytes_ready += nbytes
541551

542552
# Update the length of the current buffer.
543-
current_buffer = self._buffers.pop()
544-
current_buffer.length += nbytes
545-
self._buffers.append(current_buffer)
553+
self._buffers[-1].end_index += nbytes
546554

547555
if not len(self._pending_reads):
548556
return
@@ -564,7 +572,7 @@ def _read(self, bytes_needed: int) -> memoryview:
564572
out_index = 0
565573
while n_remaining > 0:
566574
buffer = self._buffers.popleft()
567-
buffer_remaining = buffer.length - buffer.start_index
575+
buffer_remaining = buffer.end_index - buffer.start_index
568576
# if we didn't exhaust the buffer, read the partial data and return the buffer.
569577
if buffer_remaining > n_remaining:
570578
output_buf[out_index : n_remaining + out_index] = buffer.buffer[
@@ -576,10 +584,14 @@ def _read(self, bytes_needed: int) -> memoryview:
576584
# otherwise exhaust the buffer.
577585
else:
578586
output_buf[out_index : out_index + buffer_remaining] = buffer.buffer[
579-
buffer.start_index : buffer.length
587+
buffer.start_index : buffer.end_index
580588
]
581589
out_index += buffer_remaining
582590
n_remaining -= buffer_remaining
591+
# if this is the only buffer, add it back to the queue.
592+
if not len(self._buffers):
593+
buffer.start_index = buffer.end_index
594+
self._buffers.appendleft(buffer)
583595
return memoryview(output_buf)
584596

585597

0 commit comments

Comments
 (0)