Skip to content

Commit 4b1bdd6

Browse files
committed
fixup
1 parent b33f78e commit 4b1bdd6

File tree

1 file changed

+19
-22
lines changed

1 file changed

+19
-22
lines changed

pymongo/network_layer.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -482,16 +482,14 @@ def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
482482
class KMSBuffer:
483483
buffer: memoryview
484484
start_index: int
485+
length: int
485486

486487

487488
class PyMongoKMSProtocol(PyMongoBaseProtocol):
488489
def __init__(self, timeout: Optional[float] = None):
489490
super().__init__(timeout)
490491
self._buffers: collections.deque[KMSBuffer] = collections.deque()
491-
# pool for buffers that have been exhausted and can be reused.
492-
self._buffer_pool: collections.deque[KMSBuffer] = collections.deque(maxlen=3)
493492
self._bytes_ready = 0
494-
self._bytes_requested = 0
495493
self._pending_reads: collections.deque[int] = collections.deque()
496494
self._pending_listeners: collections.deque[Future[Any]] = collections.deque()
497495

@@ -507,7 +505,7 @@ async def read(self, bytes_needed: int, first=False) -> bytes:
507505
# Wait for other listeners first.
508506
if len(self._pending_listeners):
509507
await asyncio.gather(*self._pending_listeners)
510-
return self._read(bytes_needed)
508+
return self._read(bytes_needed, first)
511509
if self.transport:
512510
try:
513511
self.transport.resume_reading()
@@ -527,10 +525,8 @@ def get_buffer(self, sizehint: int) -> memoryview:
527525
If any data does not fit into the returned buffer, this method will be called again until
528526
either no data remains or an empty buffer is returned.
529527
"""
530-
if self._buffer_pool:
531-
buffer = self._buffer_pool.popleft()
532-
else:
533-
buffer = KMSBuffer(memoryview(bytearray(sizehint)), 0)
528+
sizehint = max(sizehint, 1024)
529+
buffer = KMSBuffer(memoryview(bytearray(sizehint)), 0, 0)
534530
self._buffers.append(buffer)
535531
return buffer.buffer
536532

@@ -544,13 +540,17 @@ def buffer_updated(self, nbytes: int) -> None:
544540
return
545541
self._bytes_ready += nbytes
546542

547-
# Bail we don't have the current requested number of bytes.
548-
bytes_needed = self._bytes_requested
549-
first = False
550-
if bytes_needed == 0 and self._pending_reads:
551-
bytes_needed, first = self._pending_reads.popleft()
552-
read_first = first and self._bytes_ready > 0
553-
if not read_first and (bytes_needed == 0 or self._bytes_ready < bytes_needed):
543+
# Update the length of the current buffer.
544+
current_buffer = self._buffers.pop()
545+
current_buffer.length += nbytes
546+
self._buffers.append(current_buffer)
547+
548+
if not len(self._pending_reads):
549+
return
550+
551+
bytes_needed, first = self._pending_reads.popleft()
552+
if not first and (bytes_needed == 0 or self._bytes_ready < bytes_needed):
553+
self._pending_reads.appendleft((bytes_needed, first))
554554
return
555555

556556
data = self._read(bytes_needed, first)
@@ -563,31 +563,28 @@ def _read(self, bytes_needed, first=False):
563563
if first and self._bytes_ready < bytes_needed:
564564
bytes_needed = self._bytes_ready
565565
self._bytes_ready -= bytes_needed
566-
self._bytes_requested = 0
567566

568567
output_buf = bytearray(bytes_needed)
569568
n_remaining = bytes_needed
570569
out_index = 0
571570
while n_remaining > 0:
572571
buffer = self._buffers.popleft()
573-
buffer_remaining = len(buffer.buffer) - buffer.start_index
574-
# if we didn't exhaust the buffer, read the partial data and put it back.
572+
buffer_remaining = buffer.length - buffer.start_index
573+
# if we didn't exhaust the buffer, read the partial data and return the buffer.
575574
if buffer_remaining > n_remaining:
576575
output_buf[out_index : n_remaining + out_index] = buffer.buffer[
577576
buffer.start_index : buffer.start_index + n_remaining
578577
]
579578
buffer.start_index += n_remaining
580579
n_remaining = 0
581580
self._buffers.appendleft(buffer)
582-
# otherwise exhaust the buffer and return it to the pool.
581+
# otherwise exhaust the buffer.
583582
else:
584583
output_buf[out_index : out_index + buffer_remaining] = buffer.buffer[
585-
buffer.start_index :
584+
buffer.start_index : buffer.length
586585
]
587586
out_index += buffer_remaining
588587
n_remaining -= buffer_remaining
589-
buffer.start_index = 0
590-
self._buffer_pool.append(buffer)
591588
return output_buf
592589

593590

0 commit comments

Comments
 (0)