Skip to content

Commit 5e73fc5

Browse files
test: update server used for tests to use SSL/TLS (#1262)
This PR improves the local server used for unit tests. The previous server fixture was just a plain socket unable to perform SSL/TLS. This meant in certain tests we were skipping the SSL/TLS handshake because it would fail. The new proxy_server fixture more closely behaves to that of the real Proxy server. It allows SSL/TLS by loading in the server CA used for testing, the client can then use the new context fixture to properly use SSL to connect to the server. This will provide a foundation for testing the metadata exchange which will be introduced in the near future.
1 parent be6b154 commit 5e73fc5

File tree

7 files changed

+98
-73
lines changed

7 files changed

+98
-73
lines changed

tests/conftest.py

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,23 @@
1717
import asyncio
1818
import os
1919
import socket
20+
import ssl
2021
from threading import Thread
21-
from typing import Any, AsyncGenerator, Generator
22+
from typing import Any, AsyncGenerator
2223

24+
from aiofiles.tempfile import TemporaryDirectory
2325
from aiohttp import web
26+
from cryptography.hazmat.primitives import serialization
2427
import pytest # noqa F401 Needed to run the tests
28+
from unit.mocks import create_ssl_context # type: ignore
2529
from unit.mocks import FakeCredentials # type: ignore
2630
from unit.mocks import FakeCSQLInstance # type: ignore
2731

2832
from google.cloud.sql.connector.client import CloudSQLClient
2933
from google.cloud.sql.connector.connection_name import ConnectionName
3034
from google.cloud.sql.connector.instance import RefreshAheadCache
3135
from google.cloud.sql.connector.utils import generate_keys
36+
from google.cloud.sql.connector.utils import write_to_file
3237

3338
SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin"]
3439

@@ -79,25 +84,60 @@ def fake_credentials() -> FakeCredentials:
7984
return FakeCredentials()
8085

8186

82-
def mock_server(server_sock: socket.socket) -> None:
83-
"""Create mock server listening on specified ip_address and port."""
87+
async def start_proxy_server(instance: FakeCSQLInstance) -> None:
88+
"""Run local proxy server capable of performing mTLS"""
8489
ip_address = "127.0.0.1"
8590
port = 3307
86-
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
87-
server_sock.bind((ip_address, port))
88-
server_sock.listen(0)
89-
server_sock.accept()
91+
# create socket
92+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
93+
# create SSL/TLS context
94+
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
95+
context.minimum_version = ssl.TLSVersion.TLSv1_3
96+
# tmpdir and its contents are automatically deleted after the CA cert
97+
# and cert chain are loaded into the SSLcontext. The values
98+
# need to be written to files in order to be loaded by the SSLContext
99+
server_key_bytes = instance.server_key.private_bytes(
100+
encoding=serialization.Encoding.PEM,
101+
format=serialization.PrivateFormat.TraditionalOpenSSL,
102+
encryption_algorithm=serialization.NoEncryption(),
103+
)
104+
async with TemporaryDirectory() as tmpdir:
105+
server_filename, _, key_filename = await write_to_file(
106+
tmpdir, instance.server_cert_pem, "", server_key_bytes
107+
)
108+
context.load_cert_chain(server_filename, key_filename)
109+
# allow socket to be re-used
110+
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
111+
# bind socket to Cloud SQL proxy server port on localhost
112+
sock.bind((ip_address, port))
113+
# listen for incoming connections
114+
sock.listen(5)
115+
116+
with context.wrap_socket(sock, server_side=True) as ssock:
117+
while True:
118+
conn, _ = ssock.accept()
119+
conn.close()
120+
121+
122+
@pytest.fixture(scope="session")
123+
def proxy_server(fake_instance: FakeCSQLInstance) -> None:
124+
"""Run local proxy server capable of performing mTLS"""
125+
thread = Thread(
126+
target=asyncio.run,
127+
args=(
128+
start_proxy_server(
129+
fake_instance,
130+
),
131+
),
132+
daemon=True,
133+
)
134+
thread.start()
135+
thread.join(1.0) # add a delay to allow the proxy server to start
90136

91137

92138
@pytest.fixture
93-
def server() -> Generator:
94-
"""Create thread with server listening on proper port"""
95-
server_sock = socket.socket()
96-
thread = Thread(target=mock_server, args=(server_sock,), daemon=True)
97-
thread.start()
98-
yield thread
99-
server_sock.close()
100-
thread.join()
139+
async def context(fake_instance: FakeCSQLInstance) -> ssl.SSLContext:
140+
return await create_ssl_context(fake_instance)
101141

102142

103143
@pytest.fixture
@@ -107,7 +147,7 @@ def kwargs() -> Any:
107147
return kwargs
108148

109149

110-
@pytest.fixture
150+
@pytest.fixture(scope="session")
111151
def fake_instance() -> FakeCSQLInstance:
112152
return FakeCSQLInstance()
113153

tests/unit/mocks.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""""
1+
"""
22
Copyright 2022 Google LLC
33
44
Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +16,8 @@
1616

1717
# file containing all mocks used for Cloud SQL Python Connector unit tests
1818

19+
from __future__ import annotations
20+
1921
import datetime
2022
import json
2123
import ssl
@@ -184,28 +186,28 @@ def client_key_signed_cert(
184186
.not_valid_after(cert_expiration) # type: ignore
185187
)
186188
return (
187-
cert.sign(priv_key, hashes.SHA256(), default_backend())
189+
cert.sign(priv_key, hashes.SHA256())
188190
.public_bytes(encoding=serialization.Encoding.PEM)
189191
.decode("UTF-8")
190192
)
191193

192194

193-
async def create_ssl_context() -> ssl.SSLContext:
195+
async def create_ssl_context(instance: FakeCSQLInstance) -> ssl.SSLContext:
194196
"""Helper method to build an ssl.SSLContext for tests"""
195-
# generate keys and certs for test
196-
cert, private_key = generate_cert("my-project", "my-instance")
197-
server_ca_cert = self_signed_cert(cert, private_key)
198197
client_private, client_bytes = await generate_keys()
199198
client_key: rsa.RSAPublicKey = serialization.load_pem_public_key(
200-
client_bytes.encode("UTF-8"), default_backend()
199+
client_bytes.encode("UTF-8"),
201200
) # type: ignore
202-
ephemeral_cert = client_key_signed_cert(cert, private_key, client_key)
203-
# build default ssl.SSLContext
204-
context = ssl.create_default_context()
201+
ephemeral_cert = client_key_signed_cert(
202+
instance.server_ca, instance.server_key, client_key
203+
)
204+
# create SSL/TLS context
205+
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
206+
context.check_hostname = False
205207
# load ssl.SSLContext with certs
206208
async with TemporaryDirectory() as tmpdir:
207209
ca_filename, cert_filename, key_filename = await write_to_file(
208-
tmpdir, server_ca_cert, ephemeral_cert, client_private
210+
tmpdir, instance.server_cert_pem, ephemeral_cert, client_private
209211
)
210212
context.load_cert_chain(cert_filename, keyfile=key_filename)
211213
context.load_verify_locations(cafile=ca_filename)
@@ -279,8 +281,8 @@ async def generate_ephemeral(self, request: Any) -> web.Response:
279281
body = await request.json()
280282
pub_key = body["public_key"]
281283
client_key: rsa.RSAPublicKey = serialization.load_pem_public_key(
282-
pub_key.encode("UTF-8"), default_backend()
283-
) # type: ignore
284+
pub_key.encode("UTF-8"),
285+
)
284286
ephemeral_cert = client_key_signed_cert(
285287
self.server_ca,
286288
self.server_key,

tests/unit/test_instance.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,17 +186,18 @@ async def test_RefreshAheadCache_close(cache: RefreshAheadCache) -> None:
186186
@pytest.mark.asyncio
187187
async def test_perform_refresh(
188188
cache: RefreshAheadCache,
189-
fake_instance: mocks.FakeCSQLInstance,
190189
) -> None:
191190
"""
192191
Test that _perform_refresh returns valid ConnectionInfo object.
193192
"""
194193
instance_metadata = await cache._perform_refresh()
195-
196194
# verify instance metadata object is returned
197195
assert isinstance(instance_metadata, ConnectionInfo)
198196
# verify instance metadata expiration
199-
assert fake_instance.server_cert.not_valid_after_utc == instance_metadata.expiration
197+
assert (
198+
cache._client.instance.cert_expiration.replace(microsecond=0)
199+
== instance_metadata.expiration
200+
)
200201

201202

202203
@pytest.mark.asyncio

tests/unit/test_monitored_cache.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414

1515
import asyncio
1616
import socket
17+
import ssl
1718

1819
import dns.message
1920
import dns.rdataclass
2021
import dns.rdatatype
2122
import dns.resolver
2223
from mock import patch
23-
from mocks import create_ssl_context
2424
import pytest
2525

2626
from google.cloud.sql.connector.client import CloudSQLClient
@@ -149,8 +149,10 @@ async def test_MonitoredCache_with_disabled_failover(
149149
assert monitored_cache.closed is True
150150

151151

152-
@pytest.mark.usefixtures("server")
153-
async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) -> None:
152+
@pytest.mark.usefixtures("proxy_server")
153+
async def test_MonitoredCache_check_domain_name(
154+
context: ssl.SSLContext, fake_client: CloudSQLClient
155+
) -> None:
154156
"""
155157
Test that MonitoredCache is closed when _check_domain_name has domain change.
156158
"""
@@ -177,11 +179,9 @@ async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) ->
177179

178180
# configure a local socket
179181
ip_addr = "127.0.0.1"
180-
context = await create_ssl_context()
181182
sock = context.wrap_socket(
182183
socket.create_connection((ip_addr, 3307)),
183184
server_hostname=ip_addr,
184-
do_handshake_on_connect=False,
185185
)
186186
# verify socket is open
187187
assert sock.fileno() != -1
@@ -198,8 +198,10 @@ async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) ->
198198
assert sock.fileno() == -1
199199

200200

201-
@pytest.mark.usefixtures("server")
202-
async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient) -> None:
201+
@pytest.mark.usefixtures("proxy_server")
202+
async def test_MonitoredCache_purge_closed_sockets(
203+
context: ssl.SSLContext, fake_client: CloudSQLClient
204+
) -> None:
203205
"""
204206
Test that MonitoredCache._purge_closed_sockets removes closed sockets from
205207
cache.
@@ -215,11 +217,9 @@ async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient)
215217
)
216218
# configure a local socket
217219
ip_addr = "127.0.0.1"
218-
context = await create_ssl_context()
219220
sock = context.wrap_socket(
220221
socket.create_connection((ip_addr, 3307)),
221222
server_hostname=ip_addr,
222-
do_handshake_on_connect=False,
223223
)
224224

225225
# set failover to 0 to disable polling

tests/unit/test_pg8000.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,22 @@
1515
"""
1616

1717
import socket
18+
import ssl
1819
from typing import Any
1920

2021
from mock import patch
21-
from mocks import create_ssl_context
2222
import pytest
2323

2424
from google.cloud.sql.connector.pg8000 import connect
2525

2626

27-
@pytest.mark.usefixtures("server")
28-
@pytest.mark.asyncio
29-
async def test_pg8000(kwargs: Any) -> None:
27+
@pytest.mark.usefixtures("proxy_server")
28+
async def test_pg8000(context: ssl.SSLContext, kwargs: Any) -> None:
3029
"""Test to verify that pg8000 gets to proper connection call."""
3130
ip_addr = "127.0.0.1"
32-
# build ssl.SSLContext
33-
context = await create_ssl_context()
3431
sock = context.wrap_socket(
3532
socket.create_connection((ip_addr, 3307)),
3633
server_hostname=ip_addr,
37-
do_handshake_on_connect=False,
3834
)
3935
with patch("pg8000.dbapi.connect") as mock_connect:
4036
mock_connect.return_value = True

tests/unit/test_pymysql.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from typing import Any
2020

2121
from mock import patch
22-
from mocks import create_ssl_context
2322
import pytest
2423

2524
from google.cloud.sql.connector.pymysql import connect as pymysql_connect
@@ -33,17 +32,14 @@ def connect(sock: ssl.SSLSocket) -> None: # type: ignore
3332
assert isinstance(sock, ssl.SSLSocket)
3433

3534

36-
@pytest.mark.usefixtures("server")
35+
@pytest.mark.usefixtures("proxy_server")
3736
@pytest.mark.asyncio
38-
async def test_pymysql(kwargs: Any) -> None:
37+
async def test_pymysql(context: ssl.SSLContext, kwargs: Any) -> None:
3938
"""Test to verify that pymysql gets to proper connection call."""
4039
ip_addr = "127.0.0.1"
41-
# build ssl.SSLContext
42-
context = await create_ssl_context()
4340
sock = context.wrap_socket(
4441
socket.create_connection((ip_addr, 3307)),
4542
server_hostname=ip_addr,
46-
do_handshake_on_connect=False,
4743
)
4844
kwargs["timeout"] = 30
4945
with patch("pymysql.Connection") as mock_connect:

0 commit comments

Comments
 (0)