Skip to content

Commit 2ed225e

Browse files
committed
Implement SSL for create_connection and create_unix_connection
1 parent efee263 commit 2ed225e

File tree

3 files changed

+181
-12
lines changed

3 files changed

+181
-12
lines changed

tests/test_tcp.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,70 @@ async def start_server():
480480
for client in clients:
481481
client.stop()
482482

483+
def test_create_connection_ssl_1(self):
484+
CNT = 0
485+
TOTAL_CNT = 25
486+
487+
A_DATA = b'A' * 1024 * 1024
488+
B_DATA = b'B' * 1024 * 1024
489+
490+
sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
491+
client_sslctx = self._create_client_ssl_context()
492+
493+
def server():
494+
yield tb.starttls(
495+
sslctx,
496+
server_side=True)
497+
498+
data = yield tb.read(len(A_DATA))
499+
self.assertEqual(data, A_DATA)
500+
yield tb.write(b'OK')
501+
502+
data = yield tb.read(len(B_DATA))
503+
self.assertEqual(data, B_DATA)
504+
yield tb.write(b'SPAM')
505+
506+
yield tb.close()
507+
508+
async def client(addr):
509+
reader, writer = await asyncio.open_connection(
510+
*addr,
511+
ssl=client_sslctx,
512+
server_hostname='',
513+
loop=self.loop)
514+
515+
writer.write(A_DATA)
516+
self.assertEqual(await reader.readexactly(2), b'OK')
517+
518+
writer.write(B_DATA)
519+
self.assertEqual(await reader.readexactly(4), b'SPAM')
520+
521+
nonlocal CNT
522+
CNT += 1
523+
524+
writer.close()
525+
526+
def run(coro):
527+
nonlocal CNT
528+
CNT = 0
529+
530+
srv = tb.tcp_server(server,
531+
max_clients=TOTAL_CNT,
532+
backlog=TOTAL_CNT)
533+
srv.start()
534+
535+
tasks = []
536+
for _ in range(TOTAL_CNT):
537+
tasks.append(coro(srv.addr))
538+
539+
self.loop.run_until_complete(
540+
asyncio.gather(*tasks, loop=self.loop))
541+
srv.join()
542+
self.assertEqual(CNT, TOTAL_CNT)
543+
544+
with self._silence_eof_received_warning():
545+
run(client)
546+
483547

484548
class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase):
485549
pass

tests/test_unix.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,71 @@ async def start_server():
413413
for client in clients:
414414
client.stop()
415415

416+
def test_create_unix_connection_ssl_1(self):
417+
CNT = 0
418+
TOTAL_CNT = 25
419+
420+
A_DATA = b'A' * 1024 * 1024
421+
B_DATA = b'B' * 1024 * 1024
422+
423+
sslctx = self._create_server_ssl_context(self.ONLYCERT, self.ONLYKEY)
424+
client_sslctx = self._create_client_ssl_context()
425+
426+
def server():
427+
yield tb.starttls(
428+
sslctx,
429+
server_side=True)
430+
431+
data = yield tb.read(len(A_DATA))
432+
self.assertEqual(data, A_DATA)
433+
yield tb.write(b'OK')
434+
435+
data = yield tb.read(len(B_DATA))
436+
self.assertEqual(data, B_DATA)
437+
yield tb.write(b'SPAM')
438+
439+
yield tb.close()
440+
441+
async def client(addr):
442+
reader, writer = await asyncio.open_unix_connection(
443+
addr,
444+
ssl=client_sslctx,
445+
server_hostname='',
446+
loop=self.loop)
447+
448+
writer.write(A_DATA)
449+
self.assertEqual(await reader.readexactly(2), b'OK')
450+
451+
writer.write(B_DATA)
452+
self.assertEqual(await reader.readexactly(4), b'SPAM')
453+
454+
nonlocal CNT
455+
CNT += 1
456+
457+
writer.close()
458+
459+
def run(coro):
460+
nonlocal CNT
461+
CNT = 0
462+
463+
srv = tb.tcp_server(server,
464+
family=socket.AF_UNIX,
465+
max_clients=TOTAL_CNT,
466+
backlog=TOTAL_CNT)
467+
srv.start()
468+
469+
tasks = []
470+
for _ in range(TOTAL_CNT):
471+
tasks.append(coro(srv.addr))
472+
473+
self.loop.run_until_complete(
474+
asyncio.gather(*tasks, loop=self.loop))
475+
srv.join()
476+
self.assertEqual(CNT, TOTAL_CNT)
477+
478+
with self._silence_eof_received_warning():
479+
run(client)
480+
416481

417482
class Test_UV_UnixSSL(_TestSSL, tb.UVTestCase):
418483
pass

uvloop/loop.pyx

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,8 +1041,27 @@ cdef class Loop:
10411041
system.addrinfo *lai
10421042
UVTCPTransport tr
10431043

1044-
if ssl is not None:
1045-
raise NotImplementedError('SSL is not yet supported')
1044+
object app_protocol
1045+
object protocol
1046+
object ssl_waiter
1047+
1048+
app_protocol = protocol = protocol_factory()
1049+
ssl_waiter = None
1050+
if ssl:
1051+
if server_hostname is None:
1052+
if not host:
1053+
raise ValueError('You must set server_hostname '
1054+
'when using ssl without a host')
1055+
server_hostname = host
1056+
1057+
ssl_waiter = aio_Future(loop=self)
1058+
sslcontext = None if isinstance(ssl, bool) else ssl
1059+
protocol = aio_SSLProtocol(
1060+
self, app_protocol, sslcontext, ssl_waiter,
1061+
False, server_hostname)
1062+
else:
1063+
if server_hostname is not None:
1064+
raise ValueError('server_hostname is only meaningful with ssl')
10461065

10471066
if host is not None or port is not None:
10481067
f1 = self._getaddrinfo(host, port, family,
@@ -1075,8 +1094,6 @@ cdef class Loop:
10751094
raise OSError(
10761095
'getaddrinfo() returned empty list for local_addr')
10771096

1078-
protocol = protocol_factory()
1079-
10801097
exceptions = []
10811098
rai = ai_remote.data
10821099
while rai is not NULL:
@@ -1146,7 +1163,11 @@ cdef class Loop:
11461163
tr._close()
11471164
raise
11481165

1149-
return tr, protocol
1166+
if ssl:
1167+
await ssl_waiter
1168+
return protocol._app_transport, app_protocol
1169+
else:
1170+
return tr, protocol
11501171

11511172
@aio_coroutine
11521173
async def create_unix_server(self, protocol_factory, str path=None,
@@ -1204,10 +1225,27 @@ cdef class Loop:
12041225
ssl=None, sock=None,
12051226
server_hostname=None):
12061227

1207-
cdef UVPipeTransport tr
1208-
1209-
if ssl is not None:
1210-
raise NotImplementedError('SSL is not yet supported')
1228+
cdef:
1229+
UVPipeTransport tr
1230+
object app_protocol
1231+
object protocol
1232+
object ssl_waiter
1233+
1234+
app_protocol = protocol = protocol_factory()
1235+
ssl_waiter = None
1236+
if ssl:
1237+
if server_hostname is None:
1238+
raise ValueError('You must set server_hostname '
1239+
'when using ssl without a host')
1240+
1241+
ssl_waiter = aio_Future(loop=self)
1242+
sslcontext = None if isinstance(ssl, bool) else ssl
1243+
protocol = aio_SSLProtocol(
1244+
self, app_protocol, sslcontext, ssl_waiter,
1245+
False, server_hostname)
1246+
else:
1247+
if server_hostname is not None:
1248+
raise ValueError('server_hostname is only meaningful with ssl')
12111249

12121250
if path is not None:
12131251
if isinstance(path, str):
@@ -1217,7 +1255,6 @@ cdef class Loop:
12171255
raise ValueError(
12181256
'path and sock can not be specified at the same time')
12191257

1220-
protocol = protocol_factory()
12211258
waiter = aio_Future(loop=self)
12221259
tr = UVPipeTransport.new(self, protocol, None, waiter)
12231260
tr.connect(path)
@@ -1236,7 +1273,6 @@ cdef class Loop:
12361273
'A UNIX Domain Socket was expected, got {!r}'.format(sock))
12371274

12381275
waiter = aio_Future(loop=self)
1239-
protocol = protocol_factory()
12401276
tr = UVPipeTransport.new(self, protocol, None, waiter)
12411277
try:
12421278
# libuv will make socket non-blocking
@@ -1249,7 +1285,11 @@ cdef class Loop:
12491285
tr._close()
12501286
raise
12511287

1252-
return tr, protocol
1288+
if ssl:
1289+
await ssl_waiter
1290+
return protocol._app_transport, app_protocol
1291+
else:
1292+
return tr, protocol
12531293

12541294
def default_exception_handler(self, context):
12551295
message = context.get('message')

0 commit comments

Comments
 (0)