Skip to content

Commit 68dd733

Browse files
committed
Add ssl_handshake_timeout parameter to loop connection/server APIs
As part of this commit, fix error propagation in SSLProtocol implementation.
1 parent 5eb1fb3 commit 68dd733

File tree

12 files changed

+283
-69
lines changed

12 files changed

+283
-69
lines changed

tests/test_tcp.py

Lines changed: 121 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,16 @@ async def test():
318318

319319
self.loop.run_until_complete(test())
320320

321+
def test_create_server_8(self):
322+
if self.implementation == 'asyncio' and not self.PY37:
323+
raise unittest.SkipTest()
324+
325+
with self.assertRaisesRegex(
326+
ValueError, 'ssl_handshake_timeout is only meaningful'):
327+
self.loop.run_until_complete(
328+
self.loop.create_server(
329+
lambda: None, host='::', port=0, ssl_handshake_timeout=10))
330+
321331
def test_create_connection_open_con_addr(self):
322332
async def client(addr):
323333
reader, writer = await asyncio.open_connection(
@@ -501,6 +511,16 @@ async def client(addr):
501511
backlog=1) as srv:
502512
self.loop.run_until_complete(client(srv.addr))
503513

514+
def test_create_connection_6(self):
515+
if self.implementation == 'asyncio' and not self.PY37:
516+
raise unittest.SkipTest()
517+
518+
with self.assertRaisesRegex(
519+
ValueError, 'ssl_handshake_timeout is only meaningful'):
520+
self.loop.run_until_complete(
521+
self.loop.create_connection(
522+
lambda: None, host='::', port=0, ssl_handshake_timeout=10))
523+
504524
def test_transport_shutdown(self):
505525
CNT = 0 # number of clients that were successful
506526
TOTAL_CNT = 100 # total number of clients that test will create
@@ -855,6 +875,17 @@ async def runner():
855875
srv.close()
856876
self.loop.run_until_complete(srv.wait_closed())
857877

878+
def test_connect_accepted_socket_ssl_args(self):
879+
if self.implementation == 'asyncio' and not self.PY37:
880+
raise unittest.SkipTest()
881+
882+
with self.assertRaisesRegex(
883+
ValueError, 'ssl_handshake_timeout is only meaningful'):
884+
with socket.socket() as s:
885+
self.loop.run_until_complete(
886+
self.loop.connect_accepted_socket(
887+
(lambda: None), s, ssl_handshake_timeout=10.0))
888+
858889
def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None):
859890
loop = self.loop
860891

@@ -898,9 +929,15 @@ def client():
898929
conn, _ = lsock.accept()
899930
proto = MyProto(loop=loop)
900931
proto.loop = loop
932+
933+
extras = {}
934+
if server_ssl and (self.implementation != 'asyncio' or self.PY37):
935+
extras = dict(ssl_handshake_timeout=10.0)
936+
901937
f = loop.create_task(
902938
loop.connect_accepted_socket(
903-
(lambda: proto), conn, ssl=server_ssl))
939+
(lambda: proto), conn, ssl=server_ssl,
940+
**extras))
904941
loop.run_forever()
905942
conn.close()
906943
lsock.close()
@@ -1017,12 +1054,17 @@ def prog(sock):
10171054
await fut
10181055

10191056
async def start_server():
1057+
extras = {}
1058+
if self.implementation != 'asyncio' or self.PY37:
1059+
extras = dict(ssl_handshake_timeout=10.0)
1060+
10201061
srv = await asyncio.start_server(
10211062
handle_client,
10221063
'127.0.0.1', 0,
10231064
family=socket.AF_INET,
10241065
ssl=sslctx,
1025-
loop=self.loop)
1066+
loop=self.loop,
1067+
**extras)
10261068

10271069
try:
10281070
srv_socks = srv.sockets
@@ -1080,11 +1122,16 @@ def server(sock):
10801122
sock.close()
10811123

10821124
async def client(addr):
1125+
extras = {}
1126+
if self.implementation != 'asyncio' or self.PY37:
1127+
extras = dict(ssl_handshake_timeout=10.0)
1128+
10831129
reader, writer = await asyncio.open_connection(
10841130
*addr,
10851131
ssl=client_sslctx,
10861132
server_hostname='',
1087-
loop=self.loop)
1133+
loop=self.loop,
1134+
**extras)
10881135

10891136
writer.write(A_DATA)
10901137
self.assertEqual(await reader.readexactly(2), b'OK')
@@ -1140,6 +1187,77 @@ def run(coro):
11401187
with self._silence_eof_received_warning():
11411188
run(client_sock)
11421189

1190+
def test_create_connection_ssl_slow_handshake(self):
1191+
if self.implementation == 'asyncio':
1192+
raise unittest.SkipTest()
1193+
1194+
client_sslctx = self._create_client_ssl_context()
1195+
1196+
# silence error logger
1197+
self.loop.set_exception_handler(lambda *args: None)
1198+
1199+
def server(sock):
1200+
try:
1201+
sock.recv_all(1024 * 1024)
1202+
except ConnectionAbortedError:
1203+
pass
1204+
finally:
1205+
sock.close()
1206+
1207+
async def client(addr):
1208+
reader, writer = await asyncio.open_connection(
1209+
*addr,
1210+
ssl=client_sslctx,
1211+
server_hostname='',
1212+
loop=self.loop,
1213+
ssl_handshake_timeout=1.0)
1214+
1215+
with self.tcp_server(server,
1216+
max_clients=1,
1217+
backlog=1) as srv:
1218+
1219+
with self.assertRaisesRegex(
1220+
ConnectionAbortedError,
1221+
r'SSL handshake.*is taking longer'):
1222+
1223+
self.loop.run_until_complete(client(srv.addr))
1224+
1225+
def test_create_connection_ssl_failed_certificate(self):
1226+
if self.implementation == 'asyncio':
1227+
raise unittest.SkipTest()
1228+
1229+
# silence error logger
1230+
self.loop.set_exception_handler(lambda *args: None)
1231+
1232+
sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
1233+
client_sslctx = self._create_client_ssl_context(disable_verify=False)
1234+
1235+
def server(sock):
1236+
try:
1237+
sock.starttls(
1238+
sslctx,
1239+
server_side=True)
1240+
sock.connect()
1241+
except ssl.SSLError:
1242+
pass
1243+
finally:
1244+
sock.close()
1245+
1246+
async def client(addr):
1247+
reader, writer = await asyncio.open_connection(
1248+
*addr,
1249+
ssl=client_sslctx,
1250+
server_hostname='',
1251+
loop=self.loop,
1252+
ssl_handshake_timeout=1.0)
1253+
1254+
with self.tcp_server(server,
1255+
max_clients=1,
1256+
backlog=1) as srv:
1257+
1258+
with self.assertRaises(ssl.SSLCertVerificationError):
1259+
self.loop.run_until_complete(client(srv.addr))
1260+
11431261
def test_ssl_connect_accepted_socket(self):
11441262
server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
11451263
server_context.load_cert_chain(self.ONLYCERT, self.ONLYKEY)

tests/test_unix.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,16 @@ def test_create_unix_server_2(self):
155155
self.loop.run_until_complete(
156156
self.loop.create_unix_server(object, sock_name))
157157

158+
def test_create_unix_server_3(self):
159+
if self.implementation == 'asyncio' and not self.PY37:
160+
raise unittest.SkipTest()
161+
162+
with self.assertRaisesRegex(
163+
ValueError, 'ssl_handshake_timeout is only meaningful'):
164+
self.loop.run_until_complete(
165+
self.loop.create_unix_server(
166+
lambda: None, path='/tmp/a', ssl_handshake_timeout=10))
167+
158168
def test_create_unix_server_existing_path_sock(self):
159169
with self.unix_sock_name() as path:
160170
sock = socket.socket(socket.AF_UNIX)
@@ -363,7 +373,18 @@ async def client():
363373
self.assertIn(excs[0].__class__,
364374
(BrokenPipeError, ConnectionResetError))
365375

366-
@unittest.skipUnless(sys.version_info < (3, 7), 'Python version must be < 3.7')
376+
def test_create_unix_connection_6(self):
377+
if self.implementation == 'asyncio' and not self.PY37:
378+
raise unittest.SkipTest()
379+
380+
with self.assertRaisesRegex(
381+
ValueError, 'ssl_handshake_timeout is only meaningful'):
382+
self.loop.run_until_complete(
383+
self.loop.create_unix_connection(
384+
lambda: None, path='/tmp/a', ssl_handshake_timeout=10))
385+
386+
@unittest.skipUnless(sys.version_info < (3, 7),
387+
'Python version must be < 3.7')
367388
def test_transport_unclosed_warning(self):
368389
async def test(sock):
369390
return await self.loop.create_unix_connection(
@@ -523,14 +544,19 @@ def prog(sock):
523544
await fut
524545

525546
async def start_server():
547+
extras = {}
548+
if self.implementation != 'asyncio' or self.PY37:
549+
extras = dict(ssl_handshake_timeout=10.0)
550+
526551
with tempfile.TemporaryDirectory() as td:
527552
sock_name = os.path.join(td, 'sock')
528553

529554
srv = await asyncio.start_unix_server(
530555
handle_client,
531556
sock_name,
532557
ssl=sslctx,
533-
loop=self.loop)
558+
loop=self.loop,
559+
**extras)
534560

535561
try:
536562
tasks = []
@@ -584,11 +610,16 @@ def server(sock):
584610
sock.close()
585611

586612
async def client(addr):
613+
extras = {}
614+
if self.implementation != 'asyncio' or self.PY37:
615+
extras = dict(ssl_handshake_timeout=10.0)
616+
587617
reader, writer = await asyncio.open_unix_connection(
588618
addr,
589619
ssl=client_sslctx,
590620
server_hostname='',
591-
loop=self.loop)
621+
loop=self.loop,
622+
**extras)
592623

593624
writer.write(A_DATA)
594625
self.assertEqual(await reader.readexactly(2), b'OK')

uvloop/_testbase.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import select
1414
import socket
1515
import ssl
16+
import sys
1617
import tempfile
1718
import threading
1819
import time
@@ -89,6 +90,9 @@ def setUp(self):
8990
self._get_running_loop = asyncio.events._get_running_loop
9091
asyncio.events._get_running_loop = lambda: None
9192

93+
self.PY37 = sys.version_info[:2] >= (3, 7)
94+
self.PY36 = sys.version_info[:2] >= (3, 6)
95+
9296
def tearDown(self):
9397
self.loop.close()
9498

@@ -268,10 +272,11 @@ def _create_server_ssl_context(self, certfile, keyfile=None):
268272
sslcontext.load_cert_chain(certfile, keyfile)
269273
return sslcontext
270274

271-
def _create_client_ssl_context(self):
275+
def _create_client_ssl_context(self, *, disable_verify=True):
272276
sslcontext = ssl.create_default_context()
273277
sslcontext.check_hostname = False
274-
sslcontext.verify_mode = ssl.CERT_NONE
278+
if disable_verify:
279+
sslcontext.verify_mode = ssl.CERT_NONE
275280
return sslcontext
276281

277282
@contextlib.contextmanager

uvloop/handles/pipe.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ cdef class UnixServer(UVStreamServer):
44

55
@staticmethod
66
cdef UnixServer new(Loop loop, object protocol_factory, Server server,
7-
object ssl)
7+
object ssl, object ssl_handshake_timeout)
88

99

1010
cdef class UnixTransport(UVStream):

uvloop/handles/pipe.pyx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@ cdef class UnixServer(UVStreamServer):
3939

4040
@staticmethod
4141
cdef UnixServer new(Loop loop, object protocol_factory, Server server,
42-
object ssl):
42+
object ssl, object ssl_handshake_timeout):
4343

4444
cdef UnixServer handle
4545
handle = UnixServer.__new__(UnixServer)
46-
handle._init(loop, protocol_factory, server, ssl)
46+
handle._init(loop, protocol_factory, server,
47+
ssl, ssl_handshake_timeout)
4748
__pipe_init_uv_handle(<UVStream>handle, loop)
4849
return handle
4950

uvloop/handles/streamserver.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
cdef class UVStreamServer(UVSocketHandle):
22
cdef:
33
object ssl
4+
object ssl_handshake_timeout
45
object protocol_factory
56
bint opened
67
Server _server
78

89
# All "inline" methods are final
910

1011
cdef inline _init(self, Loop loop, object protocol_factory,
11-
Server server, object ssl)
12+
Server server, object ssl, object ssl_handshake_timeout)
1213

1314
cdef inline _mark_as_open(self)
1415

uvloop/handles/streamserver.pyx

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,24 @@ cdef class UVStreamServer(UVSocketHandle):
55
self.opened = 0
66
self._server = None
77
self.ssl = None
8+
self.ssl_handshake_timeout = None
89
self.protocol_factory = None
910

1011
cdef inline _init(self, Loop loop, object protocol_factory,
11-
Server server, object ssl):
12+
Server server, object ssl, object ssl_handshake_timeout):
13+
14+
if ssl is not None:
15+
if not isinstance(ssl, ssl_SSLContext):
16+
raise TypeError(
17+
'ssl is expected to be None or an instance of '
18+
'ssl.SSLContext, got {!r}'.format(ssl))
19+
else:
20+
if ssl_handshake_timeout is not None:
21+
raise ValueError(
22+
'ssl_handshake_timeout is only meaningful with ssl')
1223

13-
if ssl is not None and not isinstance(ssl, ssl_SSLContext):
14-
raise TypeError(
15-
'ssl is expected to be None or an instance of '
16-
'ssl.SSLContext, got {!r}'.format(ssl))
1724
self.ssl = ssl
25+
self.ssl_handshake_timeout = ssl_handshake_timeout
1826

1927
self._start_init(loop)
2028
self.protocol_factory = protocol_factory
@@ -57,8 +65,9 @@ cdef class UVStreamServer(UVSocketHandle):
5765
ssl_protocol = SSLProtocol(
5866
self._loop, protocol, self.ssl,
5967
waiter,
60-
True, # server_side
61-
None) # server_hostname
68+
server_side=True,
69+
server_hostname=None,
70+
ssl_handshake_timeout=self.ssl_handshake_timeout)
6271

6372
client = self._make_new_transport(ssl_protocol, None)
6473

uvloop/handles/tcp.pxd

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ cdef class TCPServer(UVStreamServer):
33

44
@staticmethod
55
cdef TCPServer new(Loop loop, object protocol_factory, Server server,
6-
object ssl, unsigned int flags)
6+
object ssl, unsigned int flags,
7+
object ssl_handshake_timeout)
78

89

910
cdef class TCPTransport(UVStream):

0 commit comments

Comments
 (0)