@@ -568,7 +568,7 @@ def buffer_updated(self, nbytes: int) -> None:
568
568
"""Called when the buffer was updated with the received data"""
569
569
# Wrote 0 bytes into a non-empty buffer, signal connection closed
570
570
if nbytes == 0 :
571
- self .connection_lost (OSError ("connection closed" ))
571
+ self .close (OSError ("connection closed" ))
572
572
return
573
573
if self ._connection_lost :
574
574
return
@@ -579,7 +579,7 @@ def buffer_updated(self, nbytes: int) -> None:
579
579
try :
580
580
self ._message_size , self ._op_code = self .process_header ()
581
581
except ProtocolError as exc :
582
- self .connection_lost (exc )
582
+ self .close (exc )
583
583
return
584
584
self ._message = memoryview (bytearray (self ._message_size ))
585
585
return
@@ -601,7 +601,7 @@ def buffer_updated(self, nbytes: int) -> None:
601
601
result = asyncio .get_running_loop ().create_future ()
602
602
# Future has been cancelled, close this connection
603
603
if result .done ():
604
- self .connection_lost (None )
604
+ self .close (None )
605
605
return
606
606
# Necessary values to construct message from buffers
607
607
result .set_result ((self ._op_code , self ._compressor_id , self ._message ))
@@ -648,9 +648,7 @@ def process_compression_header(self) -> tuple[int, int]:
648
648
op_code , _ , compressor_id = _UNPACK_COMPRESSION_HEADER (self ._compression_header )
649
649
return op_code , compressor_id
650
650
651
- def connection_lost (self , exc : Exception | None ) -> None :
652
- self ._connection_lost = True
653
- super ().connection_lost (exc )
651
+ def _resolve_pending_messages (self , exc : Exception | None ) -> None :
654
652
pending = list (self ._pending_messages )
655
653
for msg in pending :
656
654
if not msg .done ():
@@ -660,6 +658,13 @@ def connection_lost(self, exc: Exception | None) -> None:
660
658
msg .set_exception (exc )
661
659
self ._done_messages .append (msg )
662
660
661
+ def close (self , exc : Exception | None ) -> None :
662
+ self ._connection_lost = True
663
+ self ._resolve_pending_messages (exc )
664
+ self .transport .close ()
665
+
666
+ def connection_lost (self , exc : Exception | None ) -> None :
667
+ self ._resolve_pending_messages (exc )
663
668
if not self ._closed .done ():
664
669
self ._closed .set_result (None )
665
670
0 commit comments