@@ -568,7 +568,7 @@ def buffer_updated(self, nbytes: int) -> None:
568568 """Called when the buffer was updated with the received data"""
569569 # Wrote 0 bytes into a non-empty buffer, signal connection closed
570570 if nbytes == 0 :
571- self .connection_lost (OSError ("connection closed" ))
571+ self .close (OSError ("connection closed" ))
572572 return
573573 if self ._connection_lost :
574574 return
@@ -579,7 +579,7 @@ def buffer_updated(self, nbytes: int) -> None:
579579 try :
580580 self ._message_size , self ._op_code = self .process_header ()
581581 except ProtocolError as exc :
582- self .connection_lost (exc )
582+ self .close (exc )
583583 return
584584 self ._message = memoryview (bytearray (self ._message_size ))
585585 return
@@ -601,7 +601,7 @@ def buffer_updated(self, nbytes: int) -> None:
601601 result = asyncio .get_running_loop ().create_future ()
602602 # Future has been cancelled, close this connection
603603 if result .done ():
604- self .connection_lost (None )
604+ self .close (None )
605605 return
606606 # Necessary values to construct message from buffers
607607 result .set_result ((self ._op_code , self ._compressor_id , self ._message ))
@@ -648,9 +648,7 @@ def process_compression_header(self) -> tuple[int, int]:
648648 op_code , _ , compressor_id = _UNPACK_COMPRESSION_HEADER (self ._compression_header )
649649 return op_code , compressor_id
650650
651- def connection_lost (self , exc : Exception | None ) -> None :
652- self ._connection_lost = True
653- super ().connection_lost (exc )
651+ def _resolve_pending_messages (self , exc : Exception | None ) -> None :
654652 pending = list (self ._pending_messages )
655653 for msg in pending :
656654 if not msg .done ():
@@ -660,6 +658,13 @@ def connection_lost(self, exc: Exception | None) -> None:
660658 msg .set_exception (exc )
661659 self ._done_messages .append (msg )
662660
661+ def close (self , exc : Exception | None ) -> None :
662+ self ._connection_lost = True
663+ self ._resolve_pending_messages (exc )
664+ self .transport .close ()
665+
666+ def connection_lost (self , exc : Exception | None ) -> None :
667+ self ._resolve_pending_messages (exc )
663668 if not self ._closed .done ():
664669 self ._closed .set_result (None )
665670
0 commit comments