@@ -101,7 +101,7 @@ async def close(self):
101101 await self .conn [1 ].wait_closed ()
102102
103103 def is_closing (self ):
104- self .conn [0 ].is_closing ()
104+ return self .conn [0 ].is_closing ()
105105
106106 @property
107107 def get_conn (self ) -> PyMongoProtocol :
@@ -340,10 +340,11 @@ def connection_lost(self, exc):
340340 self ._connection_lost = True
341341 pending = list (self ._pending_messages )
342342 for msg in pending :
343- if exc is None :
344- msg .set_result (None )
345- else :
346- msg .set_exception (exc )
343+ if not msg .done ():
344+ if exc is None :
345+ msg .set_result (None )
346+ else :
347+ msg .set_exception (exc )
347348 self ._done_messages .append (msg )
348349
349350 if not self ._closed .done ():
@@ -374,7 +375,7 @@ def data(self):
374375 return self ._buffer
375376
376377 async def wait_closed (self ):
377- await self ._closed
378+ await asyncio . wait ([ self ._closed ])
378379
379380
380381async def async_sendall (conn : PyMongoProtocol , buf : bytes ) -> None :
@@ -500,10 +501,10 @@ async def async_receive_message(
500501) -> Union [_OpReply , _OpMsg ]:
501502 """Receive a raw BSON message or raise socket.error."""
502503 timeout : Optional [Union [float , int ]]
504+ timeout = conn .conn .gettimeout
503505 if _csot .get_timeout ():
504506 deadline = _csot .get_deadline ()
505507 else :
506- timeout = conn .conn .get_conn .gettimeout
507508 if timeout :
508509 deadline = time .monotonic () + timeout
509510 else :
@@ -517,24 +518,31 @@ async def async_receive_message(
517518 cancellation_task = create_task (_poll_cancellation (conn ))
518519 read_task = create_task (conn .conn .get_conn .read (request_id , max_message_size , debug ))
519520 tasks = [read_task , cancellation_task ]
520- done , pending = await asyncio .wait (tasks , timeout = timeout , return_when = asyncio .FIRST_COMPLETED )
521- for task in pending :
522- task .cancel ()
523- if pending :
524- await asyncio .wait (pending )
525- if len (done ) == 0 :
526- raise socket .timeout ("timed out" )
527- if read_task in done :
528- data , op_code = read_task .result ()
529-
530- try :
531- unpack_reply = _UNPACK_REPLY [op_code ]
532- except KeyError :
533- raise ProtocolError (
534- f"Got opcode { op_code !r} but expected { _UNPACK_REPLY .keys ()!r} "
535- ) from None
536- return unpack_reply (data )
537- raise _OperationCancelled ("operation cancelled" )
521+ try :
522+ done , pending = await asyncio .wait (
523+ tasks , timeout = timeout , return_when = asyncio .FIRST_COMPLETED
524+ )
525+ for task in pending :
526+ task .cancel ()
527+ if pending :
528+ await asyncio .wait (pending )
529+ if len (done ) == 0 :
530+ raise socket .timeout ("timed out" )
531+ if read_task in done :
532+ data , op_code = read_task .result ()
533+ try :
534+ unpack_reply = _UNPACK_REPLY [op_code ]
535+ except KeyError :
536+ raise ProtocolError (
537+ f"Got opcode { op_code !r} but expected { _UNPACK_REPLY .keys ()!r} "
538+ ) from None
539+ return unpack_reply (data )
540+ raise _OperationCancelled ("operation cancelled" )
541+ except asyncio .CancelledError :
542+ for task in tasks :
543+ task .cancel ()
544+ await asyncio .wait (tasks )
545+ raise
538546
539547
540548def receive_message (
0 commit comments