@@ -481,7 +481,7 @@ def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
481
481
class KMSBuffer :
482
482
buffer : memoryview
483
483
start_index : int
484
- length : int
484
+ end_index : int
485
485
486
486
487
487
class PyMongoKMSProtocol (PyMongoBaseProtocol ):
@@ -524,11 +524,21 @@ def get_buffer(self, sizehint: int) -> memoryview:
524
524
If any data does not fit into the returned buffer, this method will be called again until
525
525
either no data remains or an empty buffer is returned.
526
526
"""
527
- sizehint = max (sizehint , 1024 )
528
- buffer = KMSBuffer (memoryview (bytearray (sizehint )), 0 , 0 )
527
+ # Reuse the active buffer if it has space.
528
+ if len (self ._buffers ):
529
+ buffer = self ._buffers [- 1 ]
530
+ if len (buffer .buffer ) - buffer .end_index > sizehint :
531
+ return buffer .buffer [buffer .end_index :]
532
+ # Allocate a bit more than the max response size for an AWS KMS response.
533
+ buffer = KMSBuffer (memoryview (bytearray (16384 )), 0 , 0 )
529
534
self ._buffers .append (buffer )
530
535
return buffer .buffer
531
536
537
+ def _resolve_pending (self , exc : Optional [Exception ] = None ) -> None :
538
+ while self ._pending_listeners :
539
+ fut = self ._pending_listeners .popleft ()
540
+ fut .set_result (b"" )
541
+
532
542
def buffer_updated (self , nbytes : int ) -> None :
533
543
"""Called when the buffer was updated with the received data"""
534
544
# Wrote 0 bytes into a non-empty buffer, signal connection closed
@@ -540,9 +550,7 @@ def buffer_updated(self, nbytes: int) -> None:
540
550
self ._bytes_ready += nbytes
541
551
542
552
# Update the length of the current buffer.
543
- current_buffer = self ._buffers .pop ()
544
- current_buffer .length += nbytes
545
- self ._buffers .append (current_buffer )
553
+ self ._buffers [- 1 ].end_index += nbytes
546
554
547
555
if not len (self ._pending_reads ):
548
556
return
@@ -564,7 +572,7 @@ def _read(self, bytes_needed: int) -> memoryview:
564
572
out_index = 0
565
573
while n_remaining > 0 :
566
574
buffer = self ._buffers .popleft ()
567
- buffer_remaining = buffer .length - buffer .start_index
575
+ buffer_remaining = buffer .end_index - buffer .start_index
568
576
# if we didn't exhaust the buffer, read the partial data and return the buffer.
569
577
if buffer_remaining > n_remaining :
570
578
output_buf [out_index : n_remaining + out_index ] = buffer .buffer [
@@ -576,10 +584,14 @@ def _read(self, bytes_needed: int) -> memoryview:
576
584
# otherwise exhaust the buffer.
577
585
else :
578
586
output_buf [out_index : out_index + buffer_remaining ] = buffer .buffer [
579
- buffer .start_index : buffer .length
587
+ buffer .start_index : buffer .end_index
580
588
]
581
589
out_index += buffer_remaining
582
590
n_remaining -= buffer_remaining
591
+ # if this is the only buffer, add it back to the queue.
592
+ if not len (self ._buffers ):
593
+ buffer .start_index = buffer .end_index
594
+ self ._buffers .appendleft (buffer )
583
595
return memoryview (output_buf )
584
596
585
597
0 commit comments