@@ -499,13 +499,13 @@ def connection_made(self, transport: BaseTransport) -> None:
499
499
"""
500
500
self .transport = transport # type: ignore[assignment]
501
501
502
- async def read (self , bytes_needed : int , first = False ) -> bytes :
503
- """Read the requested bytes from this connection."""
504
- if self ._bytes_ready >= bytes_needed or ( self . _bytes_ready > 0 and first ) :
502
+ async def read (self , bytes_needed : int ) -> bytes :
503
+ """Read up to the requested bytes from this connection."""
504
+ if self ._bytes_ready > 0 :
505
505
# Wait for other listeners first.
506
506
if len (self ._pending_listeners ):
507
507
await asyncio .gather (* self ._pending_listeners )
508
- return self ._read (bytes_needed , first )
508
+ return self ._read (bytes_needed )
509
509
if self .transport :
510
510
try :
511
511
self .transport .resume_reading ()
@@ -514,7 +514,7 @@ async def read(self, bytes_needed: int, first=False) -> bytes:
514
514
raise OSError ("connection is already closed" ) from None
515
515
if self .transport and self .transport .is_closing ():
516
516
raise OSError ("connection is already closed" )
517
- self ._pending_reads .append (( bytes_needed , first ) )
517
+ self ._pending_reads .append (bytes_needed )
518
518
read_waiter = asyncio .get_running_loop ().create_future ()
519
519
self ._pending_listeners .append (read_waiter )
520
520
return await read_waiter
@@ -548,19 +548,15 @@ def buffer_updated(self, nbytes: int) -> None:
548
548
if not len (self ._pending_reads ):
549
549
return
550
550
551
- bytes_needed , first = self ._pending_reads .popleft ()
552
- if not first and (bytes_needed == 0 or self ._bytes_ready < bytes_needed ):
553
- self ._pending_reads .appendleft ((bytes_needed , first ))
554
- return
555
-
556
- data = self ._read (bytes_needed , first )
551
+ bytes_needed = self ._pending_reads .popleft ()
552
+ data = self ._read (bytes_needed )
557
553
waiter = self ._pending_listeners .popleft ()
558
554
waiter .set_result (data )
559
555
560
- def _read (self , bytes_needed , first = False ):
556
+ def _read (self , bytes_needed ):
561
557
"""Read bytes from the buffer."""
562
558
# Send the bytes to the listener.
563
- if first and self ._bytes_ready < bytes_needed :
559
+ if self ._bytes_ready < bytes_needed :
564
560
bytes_needed = self ._bytes_ready
565
561
self ._bytes_ready -= bytes_needed
566
562
@@ -596,13 +592,13 @@ async def async_sendall(conn: PyMongoBaseProtocol, buf: bytes) -> None:
596
592
raise socket .timeout ("timed out" ) from exc
597
593
598
594
599
- async def async_receive_kms (conn : AsyncBaseConnection , bytes_needed : int , first = False ) -> bytes :
595
+ async def async_receive_kms (conn : AsyncBaseConnection , bytes_needed : int ) -> bytes :
600
596
"""Receive raw bytes from the kms connection."""
601
597
602
598
def callback (result : Any ) -> bytes :
603
599
return result
604
600
605
- return await _async_receive_data (conn , callback , bytes_needed , first )
601
+ return await _async_receive_data (conn , callback , bytes_needed )
606
602
607
603
608
604
async def _async_receive_data (
0 commit comments