@@ -501,8 +501,13 @@ def connection_made(self, transport: BaseTransport) -> None:
501
501
"""
502
502
self .transport = transport # type: ignore[assignment]
503
503
504
- async def read (self , bytes_needed : int ) -> bytes :
504
+ async def read (self , bytes_needed : int , first = False ) -> bytes :
505
505
"""Read the requested bytes from this connection."""
506
+ if self ._bytes_ready >= bytes_needed or (self ._bytes_ready > 0 and first ):
507
+ # Wait for other listeners first.
508
+ if len (self ._pending_listeners ):
509
+ await asyncio .gather (* self ._pending_listeners )
510
+ return self ._read (bytes_needed )
506
511
if self .transport :
507
512
try :
508
513
self .transport .resume_reading ()
@@ -511,9 +516,7 @@ async def read(self, bytes_needed: int) -> bytes:
511
516
raise OSError ("connection is already closed" ) from None
512
517
if self .transport and self .transport .is_closing ():
513
518
raise OSError ("connection is already closed" )
514
- if self ._bytes_ready >= bytes_needed :
515
- return self ._read (bytes_needed )
516
- self ._pending_reads .append (bytes_needed )
519
+ self ._pending_reads .append ((bytes_needed , first ))
517
520
read_waiter = asyncio .get_running_loop ().create_future ()
518
521
self ._pending_listeners .append (read_waiter )
519
522
return await read_waiter
@@ -543,18 +546,22 @@ def buffer_updated(self, nbytes: int) -> None:
543
546
544
547
# Bail we don't have the current requested number of bytes.
545
548
bytes_needed = self ._bytes_requested
549
+ first = False
546
550
if bytes_needed == 0 and self ._pending_reads :
547
- bytes_needed = self ._pending_reads .popleft ()
548
- if bytes_needed == 0 or self ._bytes_ready < bytes_needed :
551
+ bytes_needed , first = self ._pending_reads .popleft ()
552
+ read_first = first and self ._bytes_ready > 0
553
+ if not read_first and (bytes_needed == 0 or self ._bytes_ready < bytes_needed ):
549
554
return
550
555
551
- data = self ._read (bytes_needed )
556
+ data = self ._read (bytes_needed , first )
552
557
waiter = self ._pending_listeners .popleft ()
553
558
waiter .set_result (data )
554
559
555
- def _read (self , bytes_needed ):
560
+ def _read (self , bytes_needed , first = False ):
556
561
"""Read bytes from the buffer."""
557
562
# Send the bytes to the listener.
563
+ if first and self ._bytes_ready < bytes_needed :
564
+ bytes_needed = self ._bytes_ready
558
565
self ._bytes_ready -= bytes_needed
559
566
self ._bytes_requested = 0
560
567
@@ -591,13 +598,13 @@ async def async_sendall(conn: PyMongoBaseProtocol, buf: bytes) -> None:
591
598
raise socket .timeout ("timed out" ) from exc
592
599
593
600
594
- async def async_receive_kms (conn : AsyncBaseConnection , bytes_needed : int ) -> bytes :
601
+ async def async_receive_kms (conn : AsyncBaseConnection , bytes_needed : int , first = False ) -> bytes :
595
602
"""Receive raw bytes from the kms connection."""
596
603
597
604
def callback (result : Any ) -> bytes :
598
605
return result
599
606
600
- return await _async_receive_data (conn , callback , bytes_needed )
607
+ return await _async_receive_data (conn , callback , bytes_needed , first )
601
608
602
609
603
610
async def _async_receive_data (
0 commit comments