diff --git a/edb/protocol/protocol.pyi b/edb/protocol/protocol.pyi index a9764ff3b0e..d071058b7e4 100644 --- a/edb/protocol/protocol.pyi +++ b/edb/protocol/protocol.pyi @@ -20,6 +20,8 @@ from typing import Any from . import messages +class Protocol: ... + class Connection: async def connect(self) -> None: ... diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 5df0d10d83b..fc0537e11f2 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -16,12 +16,15 @@ # limitations under the License. # +from typing import Optional import asyncio +import asyncio.sslproto import contextlib import struct import edgedb +from gel import con_utils from edb.server import args as srv_args from edb.server import compiler @@ -30,6 +33,7 @@ from edb.testbase import server as tb from edb.testbase import connection as tconn from edb.testbase.protocol.test import ProtocolTestCase +from edb.tools import test def pack_i32s(*args): @@ -896,6 +900,54 @@ async def _test_proto_discard_prepared_statement_in_script(self): finally: await self.con.recv_match(protocol.ReadyForCommand) + @test.xerror("FIXME") + async def test_proto_tls_close_notify(self): + # Setup connection with custom protocols + args = self.get_connect_args(database=self.get_database_name()) + args.setdefault('dsn', None) + args.setdefault('host', None) + args.setdefault('port', None) + args.setdefault('user', None) + args.setdefault('password', None) + args.setdefault('secret_key', None) + args.setdefault('branch', None) + args.setdefault('database', None) + timeout = args.setdefault('timeout', 60) + args.setdefault('tls_ca', None) + args.setdefault('tls_ca_file', None) + args.setdefault('tls_security', 'default') + args.setdefault('credentials', None) + args.setdefault('credentials_file', None) + connect_config, client_config = con_utils.parse_connect_arguments( + **args, + command_timeout=None, + server_settings=None, + tls_server_name=None, + wait_until_available=timeout, + ) + loop = asyncio.get_running_loop() + gel_protocol = GelProtocol(connect_config, loop) + protocol_factory = lambda: StrictTlsClientProtocol( + loop, gel_protocol, connect_config.ssl_ctx, None + ) + addr = connect_config.address + if isinstance(addr, str): + connector = loop.create_unix_connection(protocol_factory, addr) + else: + connector = loop.create_connection(protocol_factory, *addr) + tls_transport, tls_protocol = await connector + + # Complete the Gel handshake + await gel_protocol.connect() + + # Now, close the connection without sending a `Terminate`, but with + # only a TLS `close_notify`. + tls_transport.close() + + # We expect the server to reply with a `close_notify` too. If not, + # this will fail with the error in StrictTlsClientProtocol. + await gel_protocol.wait_closed() + class TestServerCancellation(tb.TestCase): @contextlib.asynccontextmanager @@ -1029,3 +1081,33 @@ async def test_proto_gh3170_connection_lost_error(self): except Exception: await con.aclose() raise + + +class GelProtocol(protocol.protocol.Protocol): + _close_fut: Optional[asyncio.Future] = None + + def connection_lost(self, exc): + if self._close_fut is not None: + if exc is None: + self._close_fut.set_result(None) + else: + self._close_fut.set_exception(exc) + super().connection_lost(exc) + + async def wait_closed(self): + self._close_fut = asyncio.Future() + try: + await self._close_fut + finally: + self._close_fut = None + + +class StrictTlsClientProtocol(asyncio.sslproto.SSLProtocol): + def connection_lost(self, exc): + if self._state == asyncio.sslproto.SSLProtocolState.WRAPPED: + if exc is None: + exc = ConnectionResetError( + 'peer closed connection without sending ' + 'TLS close_notify' + ) + super().connection_lost(exc)