@@ -1724,6 +1724,72 @@ async def client(addr):
1724
1724
self .loop .run_until_complete (
1725
1725
asyncio .wait_for (client (srv .addr ), loop = self .loop , timeout = 10 ))
1726
1726
1727
+ def test_create_connection_memory_leak (self ):
1728
+ if self .implementation == 'asyncio' :
1729
+ raise unittest .SkipTest ()
1730
+
1731
+ HELLO_MSG = b'1' * self .PAYLOAD_SIZE
1732
+
1733
+ server_context = self ._create_server_ssl_context (
1734
+ self .ONLYCERT , self .ONLYKEY )
1735
+ client_context = self ._create_client_ssl_context ()
1736
+
1737
+ def serve (sock ):
1738
+ sock .settimeout (self .TIMEOUT )
1739
+
1740
+ sock .starttls (server_context , server_side = True )
1741
+
1742
+ sock .sendall (b'O' )
1743
+ data = sock .recv_all (len (HELLO_MSG ))
1744
+ self .assertEqual (len (data ), len (HELLO_MSG ))
1745
+
1746
+ sock .unwrap ()
1747
+ sock .close ()
1748
+
1749
+ class ClientProto (asyncio .Protocol ):
1750
+ def __init__ (self , on_data , on_eof ):
1751
+ self .on_data = on_data
1752
+ self .on_eof = on_eof
1753
+ self .con_made_cnt = 0
1754
+
1755
+ def connection_made (proto , tr ):
1756
+ # XXX: We assume user stores the transport in protocol
1757
+ proto .tr = tr
1758
+ proto .con_made_cnt += 1
1759
+ # Ensure connection_made gets called only once.
1760
+ self .assertEqual (proto .con_made_cnt , 1 )
1761
+
1762
+ def data_received (self , data ):
1763
+ self .on_data .set_result (data )
1764
+
1765
+ def eof_received (self ):
1766
+ self .on_eof .set_result (True )
1767
+
1768
+ async def client (addr ):
1769
+ await asyncio .sleep (0.5 , loop = self .loop )
1770
+
1771
+ on_data = self .loop .create_future ()
1772
+ on_eof = self .loop .create_future ()
1773
+
1774
+ tr , proto = await self .loop .create_connection (
1775
+ lambda : ClientProto (on_data , on_eof ), * addr ,
1776
+ ssl = client_context )
1777
+
1778
+ self .assertEqual (await on_data , b'O' )
1779
+ tr .write (HELLO_MSG )
1780
+ await on_eof
1781
+
1782
+ tr .close ()
1783
+
1784
+ with self .tcp_server (serve , timeout = self .TIMEOUT ) as srv :
1785
+ self .loop .run_until_complete (
1786
+ asyncio .wait_for (client (srv .addr ), loop = self .loop , timeout = 10 ))
1787
+
1788
+ # No garbage is left for SSL client from loop.create_connection, even
1789
+ # if user stores the SSLTransport in corresponding protocol instance
1790
+ client_context = weakref .ref (client_context )
1791
+ self .assertIsNone (client_context ())
1792
+
1727
1793
def test_start_tls_client_buf_proto_1 (self ):
1728
1794
if self .implementation == 'asyncio' :
1729
1795
raise unittest .SkipTest ()
0 commit comments