Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Lib/asyncio/sslproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_protocol(self):
return self._ssl_protocol._app_protocol

def is_closing(self):
return self._closed
return self._closed or self._ssl_protocol._is_transport_closing()

def close(self):
"""Close the transport.
Expand Down Expand Up @@ -379,6 +379,9 @@ def _get_app_transport(self):
self._app_transport_created = True
return self._app_transport

def _is_transport_closing(self):
return self._transport is not None and self._transport.is_closing()

def connection_made(self, transport):
"""Called when the low-level connection is made.
Expand Down
45 changes: 45 additions & 0 deletions Lib/test/test_asyncio/test_sslproto.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,51 @@ def test_connection_lost(self):
test_utils.run_briefly(self.loop)
self.assertIsInstance(waiter.exception(), ConnectionAbortedError)

def test_connection_lost_when_busy(self):
# gh-118950: SSLProtocol.connection_lost not being called when OSError
# is thrown on asyncio.write.
sock = mock.Mock()
sock.fileno = mock.Mock(return_value=12345)
sock.send = mock.Mock(side_effect=BrokenPipeError)

# construct StreamWriter chain that contains loop dependant logic this emulates
# what _make_ssl_transport() does in BaseSelectorEventLoop
reader = asyncio.StreamReader(limit=2 ** 16, loop=self.loop)
protocol = asyncio.StreamReaderProtocol(reader, loop=self.loop)
ssl_proto = self.ssl_protocol(proto=protocol)

# emulate reading decompressed data
sslobj = mock.Mock()
sslobj.read.side_effect = ssl.SSLWantReadError
sslobj.write.side_effect = ssl.SSLWantReadError
ssl_proto._sslobj = sslobj

# emulate outgoing data
data = b'An interesting message'

outgoing = mock.Mock()
outgoing.read = mock.Mock(return_value=data)
outgoing.pending = len(data)
ssl_proto._outgoing = outgoing

# use correct socket transport to initialize the SSLProtocol
self.loop._make_socket_transport(sock, ssl_proto)

transport = ssl_proto._app_transport
writer = asyncio.StreamWriter(transport, protocol, reader, self.loop)

# Write data to the transport n times in a task that blocks the
# asyncio event loop from a user perspective.
async def _write_loop(n):
for i in range(n):
writer.write(data)
await writer.drain()

# The test is successful if we raise the error the next time
# we try to write to the transport.
with self.assertRaises(ConnectionResetError):
self.loop.run_until_complete(_write_loop(2))

def test_close_during_handshake(self):
# bpo-29743 Closing transport during handshake process leaks socket
waiter = self.loop.create_future()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Let _SSLProtocolTransport.is_closing reflect SSLProtocol internal transport
closing state so that StreamWriter.drain will invoke sleep(0) which calls
connection_lost and correctly notifies waiters of connection lost.