|
16 | 16 | # limitations under the License. |
17 | 17 | # |
18 | 18 |
|
| 19 | +from typing import Optional |
19 | 20 |
|
20 | 21 | import asyncio |
| 22 | +import asyncio.sslproto |
21 | 23 | import contextlib |
22 | 24 | import struct |
23 | 25 |
|
24 | 26 | import edgedb |
| 27 | +from gel import con_utils |
25 | 28 |
|
26 | 29 | from edb.server import args as srv_args |
27 | 30 | from edb.server import compiler |
|
30 | 33 | from edb.testbase import server as tb |
31 | 34 | from edb.testbase import connection as tconn |
32 | 35 | from edb.testbase.protocol.test import ProtocolTestCase |
| 36 | +from edb.tools import test |
33 | 37 |
|
34 | 38 |
|
35 | 39 | def pack_i32s(*args): |
@@ -896,6 +900,54 @@ async def _test_proto_discard_prepared_statement_in_script(self): |
896 | 900 | finally: |
897 | 901 | await self.con.recv_match(protocol.ReadyForCommand) |
898 | 902 |
|
| 903 | + @test.xerror("FIXME") |
| 904 | + async def test_proto_tls_close_notify(self): |
| 905 | + # Setup connection with custom protocols |
| 906 | + args = self.get_connect_args(database=self.get_database_name()) |
| 907 | + args.setdefault('dsn', None) |
| 908 | + args.setdefault('host', None) |
| 909 | + args.setdefault('port', None) |
| 910 | + args.setdefault('user', None) |
| 911 | + args.setdefault('password', None) |
| 912 | + args.setdefault('secret_key', None) |
| 913 | + args.setdefault('branch', None) |
| 914 | + args.setdefault('database', None) |
| 915 | + timeout = args.setdefault('timeout', 60) |
| 916 | + args.setdefault('tls_ca', None) |
| 917 | + args.setdefault('tls_ca_file', None) |
| 918 | + args.setdefault('tls_security', 'default') |
| 919 | + args.setdefault('credentials', None) |
| 920 | + args.setdefault('credentials_file', None) |
| 921 | + connect_config, client_config = con_utils.parse_connect_arguments( |
| 922 | + **args, |
| 923 | + command_timeout=None, |
| 924 | + server_settings=None, |
| 925 | + tls_server_name=None, |
| 926 | + wait_until_available=timeout, |
| 927 | + ) |
| 928 | + loop = asyncio.get_running_loop() |
| 929 | + gel_protocol = GelProtocol(connect_config, loop) |
| 930 | + protocol_factory = lambda: StrictTlsClientProtocol( |
| 931 | + loop, gel_protocol, connect_config.ssl_ctx, None |
| 932 | + ) |
| 933 | + addr = connect_config.address |
| 934 | + if isinstance(addr, str): |
| 935 | + connector = loop.create_unix_connection(protocol_factory, addr) |
| 936 | + else: |
| 937 | + connector = loop.create_connection(protocol_factory, *addr) |
| 938 | + tls_transport, tls_protocol = await connector |
| 939 | + |
| 940 | + # Complete the Gel handshake |
| 941 | + await gel_protocol.connect() |
| 942 | + |
| 943 | + # Now, close the connection without sending a `Terminate`, but with |
| 944 | + # only a TLS `close_notify`. |
| 945 | + tls_transport.close() |
| 946 | + |
| 947 | + # We expect the server to reply with a `close_notify` too. If not, |
| 948 | + # this will fail with the error in StrictTlsClientProtocol. |
| 949 | + await gel_protocol.wait_closed() |
| 950 | + |
899 | 951 |
|
900 | 952 | class TestServerCancellation(tb.TestCase): |
901 | 953 | @contextlib.asynccontextmanager |
@@ -1029,3 +1081,33 @@ async def test_proto_gh3170_connection_lost_error(self): |
1029 | 1081 | except Exception: |
1030 | 1082 | await con.aclose() |
1031 | 1083 | raise |
| 1084 | + |
| 1085 | + |
| 1086 | +class GelProtocol(protocol.protocol.Protocol): |
| 1087 | + _close_fut: Optional[asyncio.Future] = None |
| 1088 | + |
| 1089 | + def connection_lost(self, exc): |
| 1090 | + if self._close_fut is not None: |
| 1091 | + if exc is None: |
| 1092 | + self._close_fut.set_result(None) |
| 1093 | + else: |
| 1094 | + self._close_fut.set_exception(exc) |
| 1095 | + super().connection_lost(exc) |
| 1096 | + |
| 1097 | + async def wait_closed(self): |
| 1098 | + self._close_fut = asyncio.Future() |
| 1099 | + try: |
| 1100 | + await self._close_fut |
| 1101 | + finally: |
| 1102 | + self._close_fut = None |
| 1103 | + |
| 1104 | + |
| 1105 | +class StrictTlsClientProtocol(asyncio.sslproto.SSLProtocol): |
| 1106 | + def connection_lost(self, exc): |
| 1107 | + if self._state == asyncio.sslproto.SSLProtocolState.WRAPPED: |
| 1108 | + if exc is None: |
| 1109 | + exc = ConnectionResetError( |
| 1110 | + 'peer closed connection without sending ' |
| 1111 | + 'TLS close_notify' |
| 1112 | + ) |
| 1113 | + super().connection_lost(exc) |
0 commit comments