@@ -482,16 +482,14 @@ def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
482
482
class KMSBuffer :
483
483
buffer : memoryview
484
484
start_index : int
485
+ length : int
485
486
486
487
487
488
class PyMongoKMSProtocol (PyMongoBaseProtocol ):
488
489
def __init__ (self , timeout : Optional [float ] = None ):
489
490
super ().__init__ (timeout )
490
491
self ._buffers : collections .deque [KMSBuffer ] = collections .deque ()
491
- # pool for buffers that have been exhausted and can be reused.
492
- self ._buffer_pool : collections .deque [KMSBuffer ] = collections .deque (maxlen = 3 )
493
492
self ._bytes_ready = 0
494
- self ._bytes_requested = 0
495
493
self ._pending_reads : collections .deque [int ] = collections .deque ()
496
494
self ._pending_listeners : collections .deque [Future [Any ]] = collections .deque ()
497
495
@@ -507,7 +505,7 @@ async def read(self, bytes_needed: int, first=False) -> bytes:
507
505
# Wait for other listeners first.
508
506
if len (self ._pending_listeners ):
509
507
await asyncio .gather (* self ._pending_listeners )
510
- return self ._read (bytes_needed )
508
+ return self ._read (bytes_needed , first )
511
509
if self .transport :
512
510
try :
513
511
self .transport .resume_reading ()
@@ -527,10 +525,8 @@ def get_buffer(self, sizehint: int) -> memoryview:
527
525
If any data does not fit into the returned buffer, this method will be called again until
528
526
either no data remains or an empty buffer is returned.
529
527
"""
530
- if self ._buffer_pool :
531
- buffer = self ._buffer_pool .popleft ()
532
- else :
533
- buffer = KMSBuffer (memoryview (bytearray (sizehint )), 0 )
528
+ sizehint = max (sizehint , 1024 )
529
+ buffer = KMSBuffer (memoryview (bytearray (sizehint )), 0 , 0 )
534
530
self ._buffers .append (buffer )
535
531
return buffer .buffer
536
532
@@ -544,13 +540,17 @@ def buffer_updated(self, nbytes: int) -> None:
544
540
return
545
541
self ._bytes_ready += nbytes
546
542
547
- # Bail we don't have the current requested number of bytes.
548
- bytes_needed = self ._bytes_requested
549
- first = False
550
- if bytes_needed == 0 and self ._pending_reads :
551
- bytes_needed , first = self ._pending_reads .popleft ()
552
- read_first = first and self ._bytes_ready > 0
553
- if not read_first and (bytes_needed == 0 or self ._bytes_ready < bytes_needed ):
543
+ # Update the length of the current buffer.
544
+ current_buffer = self ._buffers .pop ()
545
+ current_buffer .length += nbytes
546
+ self ._buffers .append (current_buffer )
547
+
548
+ if not len (self ._pending_reads ):
549
+ return
550
+
551
+ bytes_needed , first = self ._pending_reads .popleft ()
552
+ if not first and (bytes_needed == 0 or self ._bytes_ready < bytes_needed ):
553
+ self ._pending_reads .appendleft ((bytes_needed , first ))
554
554
return
555
555
556
556
data = self ._read (bytes_needed , first )
@@ -563,31 +563,28 @@ def _read(self, bytes_needed, first=False):
563
563
if first and self ._bytes_ready < bytes_needed :
564
564
bytes_needed = self ._bytes_ready
565
565
self ._bytes_ready -= bytes_needed
566
- self ._bytes_requested = 0
567
566
568
567
output_buf = bytearray (bytes_needed )
569
568
n_remaining = bytes_needed
570
569
out_index = 0
571
570
while n_remaining > 0 :
572
571
buffer = self ._buffers .popleft ()
573
- buffer_remaining = len ( buffer .buffer ) - buffer .start_index
574
- # if we didn't exhaust the buffer, read the partial data and put it back .
572
+ buffer_remaining = buffer .length - buffer .start_index
573
+ # if we didn't exhaust the buffer, read the partial data and return the buffer .
575
574
if buffer_remaining > n_remaining :
576
575
output_buf [out_index : n_remaining + out_index ] = buffer .buffer [
577
576
buffer .start_index : buffer .start_index + n_remaining
578
577
]
579
578
buffer .start_index += n_remaining
580
579
n_remaining = 0
581
580
self ._buffers .appendleft (buffer )
582
- # otherwise exhaust the buffer and return it to the pool .
581
+ # otherwise exhaust the buffer.
583
582
else :
584
583
output_buf [out_index : out_index + buffer_remaining ] = buffer .buffer [
585
- buffer .start_index :
584
+ buffer .start_index : buffer . length
586
585
]
587
586
out_index += buffer_remaining
588
587
n_remaining -= buffer_remaining
589
- buffer .start_index = 0
590
- self ._buffer_pool .append (buffer )
591
588
return output_buf
592
589
593
590
0 commit comments