Skip to content

Commit 0ef0f4a

Browse files
committed
Add test
1 parent 55e91bc commit 0ef0f4a

File tree

2 files changed

+84
-0
lines changed

2 files changed

+84
-0
lines changed

edb/protocol/protocol.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ from typing import Any
2020

2121
from . import messages
2222

23+
class Protocol: ...
24+
2325
class Connection:
2426
async def connect(self) -> None:
2527
...

tests/test_protocol.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@
1616
# limitations under the License.
1717
#
1818

19+
from typing import Optional
1920

2021
import asyncio
22+
import asyncio.sslproto
2123
import contextlib
2224
import struct
2325

2426
import edgedb
27+
from gel import con_utils
2528

2629
from edb.server import args as srv_args
2730
from edb.server import compiler
@@ -30,6 +33,7 @@
3033
from edb.testbase import server as tb
3134
from edb.testbase import connection as tconn
3235
from edb.testbase.protocol.test import ProtocolTestCase
36+
from edb.tools import test
3337

3438

3539
def pack_i32s(*args):
@@ -896,6 +900,54 @@ async def _test_proto_discard_prepared_statement_in_script(self):
896900
finally:
897901
await self.con.recv_match(protocol.ReadyForCommand)
898902

903+
@test.xerror("FIXME")
904+
async def test_proto_tls_close_notify(self):
905+
# Setup connection with custom protocols
906+
args = self.get_connect_args(database=self.get_database_name())
907+
args.setdefault('dsn', None)
908+
args.setdefault('host', None)
909+
args.setdefault('port', None)
910+
args.setdefault('user', None)
911+
args.setdefault('password', None)
912+
args.setdefault('secret_key', None)
913+
args.setdefault('branch', None)
914+
args.setdefault('database', None)
915+
timeout = args.setdefault('timeout', 60)
916+
args.setdefault('tls_ca', None)
917+
args.setdefault('tls_ca_file', None)
918+
args.setdefault('tls_security', 'default')
919+
args.setdefault('credentials', None)
920+
args.setdefault('credentials_file', None)
921+
connect_config, client_config = con_utils.parse_connect_arguments(
922+
**args,
923+
command_timeout=None,
924+
server_settings=None,
925+
tls_server_name=None,
926+
wait_until_available=timeout,
927+
)
928+
loop = asyncio.get_running_loop()
929+
gel_protocol = GelProtocol(connect_config, loop)
930+
protocol_factory = lambda: StrictTlsClientProtocol(
931+
loop, gel_protocol, connect_config.ssl_ctx, None
932+
)
933+
addr = connect_config.address
934+
if isinstance(addr, str):
935+
connector = loop.create_unix_connection(protocol_factory, addr)
936+
else:
937+
connector = loop.create_connection(protocol_factory, *addr)
938+
tls_transport, tls_protocol = await connector
939+
940+
# Complete the Gel handshake
941+
await gel_protocol.connect()
942+
943+
# Now, close the connection without sending a `Terminate`, but with
944+
# only a TLS `close_notify`.
945+
tls_transport.close()
946+
947+
# We expect the server to reply with a `close_notify` too. If not,
948+
# this will fail with the error in StrictTlsClientProtocol.
949+
await gel_protocol.wait_closed()
950+
899951

900952
class TestServerCancellation(tb.TestCase):
901953
@contextlib.asynccontextmanager
@@ -1029,3 +1081,33 @@ async def test_proto_gh3170_connection_lost_error(self):
10291081
except Exception:
10301082
await con.aclose()
10311083
raise
1084+
1085+
1086+
class GelProtocol(protocol.protocol.Protocol):
1087+
_close_fut: Optional[asyncio.Future] = None
1088+
1089+
def connection_lost(self, exc):
1090+
if self._close_fut is not None:
1091+
if exc is None:
1092+
self._close_fut.set_result(None)
1093+
else:
1094+
self._close_fut.set_exception(exc)
1095+
super().connection_lost(exc)
1096+
1097+
async def wait_closed(self):
1098+
self._close_fut = asyncio.Future()
1099+
try:
1100+
await self._close_fut
1101+
finally:
1102+
self._close_fut = None
1103+
1104+
1105+
class StrictTlsClientProtocol(asyncio.sslproto.SSLProtocol):
1106+
def connection_lost(self, exc):
1107+
if self._state == asyncio.sslproto.SSLProtocolState.WRAPPED:
1108+
if exc is None:
1109+
exc = ConnectionResetError(
1110+
'peer closed connection without sending '
1111+
'TLS close_notify'
1112+
)
1113+
super().connection_lost(exc)

0 commit comments

Comments
 (0)