@@ -101,7 +101,7 @@ async def close(self):
101
101
await self .conn [1 ].wait_closed ()
102
102
103
103
def is_closing (self ):
104
- self .conn [0 ].is_closing ()
104
+ return self .conn [0 ].is_closing ()
105
105
106
106
@property
107
107
def get_conn (self ) -> PyMongoProtocol :
@@ -340,10 +340,11 @@ def connection_lost(self, exc):
340
340
self ._connection_lost = True
341
341
pending = list (self ._pending_messages )
342
342
for msg in pending :
343
- if exc is None :
344
- msg .set_result (None )
345
- else :
346
- msg .set_exception (exc )
343
+ if not msg .done ():
344
+ if exc is None :
345
+ msg .set_result (None )
346
+ else :
347
+ msg .set_exception (exc )
347
348
self ._done_messages .append (msg )
348
349
349
350
if not self ._closed .done ():
@@ -374,7 +375,7 @@ def data(self):
374
375
return self ._buffer
375
376
376
377
async def wait_closed (self ):
377
- await self ._closed
378
+ await asyncio . wait ([ self ._closed ])
378
379
379
380
380
381
async def async_sendall (conn : PyMongoProtocol , buf : bytes ) -> None :
@@ -500,10 +501,10 @@ async def async_receive_message(
500
501
) -> Union [_OpReply , _OpMsg ]:
501
502
"""Receive a raw BSON message or raise socket.error."""
502
503
timeout : Optional [Union [float , int ]]
504
+ timeout = conn .conn .gettimeout
503
505
if _csot .get_timeout ():
504
506
deadline = _csot .get_deadline ()
505
507
else :
506
- timeout = conn .conn .get_conn .gettimeout
507
508
if timeout :
508
509
deadline = time .monotonic () + timeout
509
510
else :
@@ -517,24 +518,31 @@ async def async_receive_message(
517
518
cancellation_task = create_task (_poll_cancellation (conn ))
518
519
read_task = create_task (conn .conn .get_conn .read (request_id , max_message_size , debug ))
519
520
tasks = [read_task , cancellation_task ]
520
- done , pending = await asyncio .wait (tasks , timeout = timeout , return_when = asyncio .FIRST_COMPLETED )
521
- for task in pending :
522
- task .cancel ()
523
- if pending :
524
- await asyncio .wait (pending )
525
- if len (done ) == 0 :
526
- raise socket .timeout ("timed out" )
527
- if read_task in done :
528
- data , op_code = read_task .result ()
529
-
530
- try :
531
- unpack_reply = _UNPACK_REPLY [op_code ]
532
- except KeyError :
533
- raise ProtocolError (
534
- f"Got opcode { op_code !r} but expected { _UNPACK_REPLY .keys ()!r} "
535
- ) from None
536
- return unpack_reply (data )
537
- raise _OperationCancelled ("operation cancelled" )
521
+ try :
522
+ done , pending = await asyncio .wait (
523
+ tasks , timeout = timeout , return_when = asyncio .FIRST_COMPLETED
524
+ )
525
+ for task in pending :
526
+ task .cancel ()
527
+ if pending :
528
+ await asyncio .wait (pending )
529
+ if len (done ) == 0 :
530
+ raise socket .timeout ("timed out" )
531
+ if read_task in done :
532
+ data , op_code = read_task .result ()
533
+ try :
534
+ unpack_reply = _UNPACK_REPLY [op_code ]
535
+ except KeyError :
536
+ raise ProtocolError (
537
+ f"Got opcode { op_code !r} but expected { _UNPACK_REPLY .keys ()!r} "
538
+ ) from None
539
+ return unpack_reply (data )
540
+ raise _OperationCancelled ("operation cancelled" )
541
+ except asyncio .CancelledError :
542
+ for task in tasks :
543
+ task .cancel ()
544
+ await asyncio .wait (tasks )
545
+ raise
538
546
539
547
540
548
def receive_message (
0 commit comments