Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions edb/protocol/protocol.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ from typing import Any

from . import messages

class Protocol: ...

class Connection:
async def connect(self) -> None:
...
Expand Down
82 changes: 82 additions & 0 deletions tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
# limitations under the License.
#

from typing import Optional

import asyncio
import asyncio.sslproto
import contextlib
import struct

import edgedb
from gel import con_utils

from edb.server import args as srv_args
from edb.server import compiler
Expand All @@ -30,6 +33,7 @@
from edb.testbase import server as tb
from edb.testbase import connection as tconn
from edb.testbase.protocol.test import ProtocolTestCase
from edb.tools import test


def pack_i32s(*args):
Expand Down Expand Up @@ -896,6 +900,54 @@ async def _test_proto_discard_prepared_statement_in_script(self):
finally:
await self.con.recv_match(protocol.ReadyForCommand)

@test.xerror("FIXME")
async def test_proto_tls_close_notify(self):
# Setup connection with custom protocols
args = self.get_connect_args(database=self.get_database_name())
args.setdefault('dsn', None)
args.setdefault('host', None)
args.setdefault('port', None)
args.setdefault('user', None)
args.setdefault('password', None)
args.setdefault('secret_key', None)
args.setdefault('branch', None)
args.setdefault('database', None)
timeout = args.setdefault('timeout', 60)
args.setdefault('tls_ca', None)
args.setdefault('tls_ca_file', None)
args.setdefault('tls_security', 'default')
args.setdefault('credentials', None)
args.setdefault('credentials_file', None)
connect_config, client_config = con_utils.parse_connect_arguments(
**args,
command_timeout=None,
server_settings=None,
tls_server_name=None,
wait_until_available=timeout,
)
loop = asyncio.get_running_loop()
gel_protocol = GelProtocol(connect_config, loop)
protocol_factory = lambda: StrictTlsClientProtocol(
loop, gel_protocol, connect_config.ssl_ctx, None
)
addr = connect_config.address
if isinstance(addr, str):
connector = loop.create_unix_connection(protocol_factory, addr)
else:
connector = loop.create_connection(protocol_factory, *addr)
tls_transport, tls_protocol = await connector

# Complete the Gel handshake
await gel_protocol.connect()

# Now, close the connection without sending a `Terminate`, but with
# only a TLS `close_notify`.
tls_transport.close()

# We expect the server to reply with a `close_notify` too. If not,
# this will fail with the error in StrictTlsClientProtocol.
await gel_protocol.wait_closed()


class TestServerCancellation(tb.TestCase):
@contextlib.asynccontextmanager
Expand Down Expand Up @@ -1029,3 +1081,33 @@ async def test_proto_gh3170_connection_lost_error(self):
except Exception:
await con.aclose()
raise


class GelProtocol(protocol.protocol.Protocol):
_close_fut: Optional[asyncio.Future] = None

def connection_lost(self, exc):
if self._close_fut is not None:
if exc is None:
self._close_fut.set_result(None)
else:
self._close_fut.set_exception(exc)
super().connection_lost(exc)

async def wait_closed(self):
self._close_fut = asyncio.Future()
try:
await self._close_fut
finally:
self._close_fut = None


class StrictTlsClientProtocol(asyncio.sslproto.SSLProtocol):
def connection_lost(self, exc):
if self._state == asyncio.sslproto.SSLProtocolState.WRAPPED:
if exc is None:
exc = ConnectionResetError(
'peer closed connection without sending '
'TLS close_notify'
)
super().connection_lost(exc)