|
3 | 3 | import socket
|
4 | 4 | import unittest.mock
|
5 | 5 | import uvloop
|
| 6 | +import ssl |
6 | 7 | import sys
|
| 8 | +import threading |
7 | 9 |
|
8 | 10 | from uvloop import _testbase as tb
|
9 | 11 |
|
10 | 12 |
|
| 13 | +class MyBaseProto(asyncio.Protocol): |
| 14 | + connected = None |
| 15 | + done = None |
| 16 | + |
| 17 | + def __init__(self, loop=None): |
| 18 | + self.transport = None |
| 19 | + self.state = 'INITIAL' |
| 20 | + self.nbytes = 0 |
| 21 | + if loop is not None: |
| 22 | + self.connected = asyncio.Future(loop=loop) |
| 23 | + self.done = asyncio.Future(loop=loop) |
| 24 | + |
| 25 | + def connection_made(self, transport): |
| 26 | + self.transport = transport |
| 27 | + assert self.state == 'INITIAL', self.state |
| 28 | + self.state = 'CONNECTED' |
| 29 | + if self.connected: |
| 30 | + self.connected.set_result(None) |
| 31 | + |
| 32 | + def data_received(self, data): |
| 33 | + assert self.state == 'CONNECTED', self.state |
| 34 | + self.nbytes += len(data) |
| 35 | + |
| 36 | + def eof_received(self): |
| 37 | + assert self.state == 'CONNECTED', self.state |
| 38 | + self.state = 'EOF' |
| 39 | + |
| 40 | + def connection_lost(self, exc): |
| 41 | + assert self.state in ('CONNECTED', 'EOF'), self.state |
| 42 | + self.state = 'CLOSED' |
| 43 | + if self.done: |
| 44 | + self.done.set_result(None) |
| 45 | + |
| 46 | + |
11 | 47 | class _TestTCP:
|
12 | 48 | def test_create_server_1(self):
|
13 | 49 | if self.is_asyncio_loop() and sys.version_info[:3] == (3, 5, 2):
|
@@ -699,6 +735,62 @@ async def runner():
|
699 | 735 | srv.close()
|
700 | 736 | self.loop.run_until_complete(srv.wait_closed())
|
701 | 737 |
|
| 738 | + def test_connect_accepted_socket(self, server_ssl=None, client_ssl=None): |
| 739 | + loop = self.loop |
| 740 | + |
| 741 | + class MyProto(MyBaseProto): |
| 742 | + |
| 743 | + def connection_lost(self, exc): |
| 744 | + super().connection_lost(exc) |
| 745 | + loop.call_soon(loop.stop) |
| 746 | + |
| 747 | + def data_received(self, data): |
| 748 | + super().data_received(data) |
| 749 | + self.transport.write(expected_response) |
| 750 | + |
| 751 | + lsock = socket.socket() |
| 752 | + lsock.bind(('127.0.0.1', 0)) |
| 753 | + lsock.listen(1) |
| 754 | + addr = lsock.getsockname() |
| 755 | + |
| 756 | + message = b'test data' |
| 757 | + response = None |
| 758 | + expected_response = b'roger' |
| 759 | + |
| 760 | + def client(): |
| 761 | + nonlocal response |
| 762 | + try: |
| 763 | + csock = socket.socket() |
| 764 | + if client_ssl is not None: |
| 765 | + csock = client_ssl.wrap_socket(csock) |
| 766 | + csock.connect(addr) |
| 767 | + csock.sendall(message) |
| 768 | + response = csock.recv(99) |
| 769 | + csock.close() |
| 770 | + except Exception as exc: |
| 771 | + print( |
| 772 | + "Failure in client thread in test_connect_accepted_socket", |
| 773 | + exc) |
| 774 | + |
| 775 | + thread = threading.Thread(target=client, daemon=True) |
| 776 | + thread.start() |
| 777 | + |
| 778 | + conn, _ = lsock.accept() |
| 779 | + proto = MyProto(loop=loop) |
| 780 | + proto.loop = loop |
| 781 | + loop.create_task( |
| 782 | + loop.connect_accepted_socket( |
| 783 | + (lambda: proto), conn, ssl=server_ssl)) |
| 784 | + loop.run_forever() |
| 785 | + conn.close() |
| 786 | + lsock.close() |
| 787 | + |
| 788 | + thread.join(1) |
| 789 | + self.assertFalse(thread.is_alive()) |
| 790 | + self.assertEqual(proto.state, 'CLOSED') |
| 791 | + self.assertEqual(proto.nbytes, len(message)) |
| 792 | + self.assertEqual(response, expected_response) |
| 793 | + |
702 | 794 |
|
703 | 795 | class Test_AIO_TCP(_TestTCP, tb.AIOTestCase):
|
704 | 796 | pass
|
@@ -864,7 +956,22 @@ def run(coro):
|
864 | 956 |
|
865 | 957 |
|
866 | 958 | class Test_UV_TCPSSL(_TestSSL, tb.UVTestCase):
|
867 |
| - pass |
| 959 | + |
| 960 | + def test_ssl_connect_accepted_socket(self): |
| 961 | + server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) |
| 962 | + server_context.load_cert_chain(self.ONLYCERT, self.ONLYKEY) |
| 963 | + if hasattr(server_context, 'check_hostname'): |
| 964 | + server_context.check_hostname = False |
| 965 | + server_context.verify_mode = ssl.CERT_NONE |
| 966 | + |
| 967 | + client_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) |
| 968 | + if hasattr(server_context, 'check_hostname'): |
| 969 | + client_context.check_hostname = False |
| 970 | + client_context.verify_mode = ssl.CERT_NONE |
| 971 | + |
| 972 | + Test_UV_TCP.test_connect_accepted_socket( |
| 973 | + self, server_context, client_context) |
| 974 | + |
868 | 975 |
|
869 | 976 |
|
870 | 977 | class Test_AIO_TCPSSL(_TestSSL, tb.AIOTestCase):
|
|
0 commit comments