diff --git a/tests/conftest.py b/tests/conftest.py index c75de48cb..83d7a78f3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,11 +17,15 @@ import asyncio import os import socket +import ssl from threading import Thread -from typing import Any, AsyncGenerator, Generator +from typing import Any, AsyncGenerator +from aiofiles.tempfile import TemporaryDirectory from aiohttp import web +from cryptography.hazmat.primitives import serialization import pytest # noqa F401 Needed to run the tests +from unit.mocks import create_ssl_context # type: ignore from unit.mocks import FakeCredentials # type: ignore from unit.mocks import FakeCSQLInstance # type: ignore @@ -29,6 +33,7 @@ from google.cloud.sql.connector.connection_name import ConnectionName from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.utils import generate_keys +from google.cloud.sql.connector.utils import write_to_file SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin"] @@ -79,25 +84,60 @@ def fake_credentials() -> FakeCredentials: return FakeCredentials() -def mock_server(server_sock: socket.socket) -> None: - """Create mock server listening on specified ip_address and port.""" +async def start_proxy_server(instance: FakeCSQLInstance) -> None: + """Run local proxy server capable of performing mTLS""" ip_address = "127.0.0.1" port = 3307 - server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server_sock.bind((ip_address, port)) - server_sock.listen(0) - server_sock.accept() + # create socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + # create SSL/TLS context + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.minimum_version = ssl.TLSVersion.TLSv1_3 + # tmpdir and its contents are automatically deleted after the CA cert + # and cert chain are loaded into the SSLcontext. The values + # need to be written to files in order to be loaded by the SSLContext + server_key_bytes = instance.server_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=serialization.NoEncryption(), + ) + async with TemporaryDirectory() as tmpdir: + server_filename, _, key_filename = await write_to_file( + tmpdir, instance.server_cert_pem, "", server_key_bytes + ) + context.load_cert_chain(server_filename, key_filename) + # allow socket to be re-used + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + # bind socket to Cloud SQL proxy server port on localhost + sock.bind((ip_address, port)) + # listen for incoming connections + sock.listen(5) + + with context.wrap_socket(sock, server_side=True) as ssock: + while True: + conn, _ = ssock.accept() + conn.close() + + +@pytest.fixture(scope="session") +def proxy_server(fake_instance: FakeCSQLInstance) -> None: + """Run local proxy server capable of performing mTLS""" + thread = Thread( + target=asyncio.run, + args=( + start_proxy_server( + fake_instance, + ), + ), + daemon=True, + ) + thread.start() + thread.join(1.0) # add a delay to allow the proxy server to start @pytest.fixture -def server() -> Generator: - """Create thread with server listening on proper port""" - server_sock = socket.socket() - thread = Thread(target=mock_server, args=(server_sock,), daemon=True) - thread.start() - yield thread - server_sock.close() - thread.join() +async def context(fake_instance: FakeCSQLInstance) -> ssl.SSLContext: + return await create_ssl_context(fake_instance) @pytest.fixture @@ -107,7 +147,7 @@ def kwargs() -> Any: return kwargs -@pytest.fixture +@pytest.fixture(scope="session") def fake_instance() -> FakeCSQLInstance: return FakeCSQLInstance() diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index cd3299b7f..66bf64a32 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2022 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,6 +16,8 @@ # file containing all mocks used for Cloud SQL Python Connector unit tests +from __future__ import annotations + import datetime import json import ssl @@ -184,28 +186,28 @@ def client_key_signed_cert( .not_valid_after(cert_expiration) # type: ignore ) return ( - cert.sign(priv_key, hashes.SHA256(), default_backend()) + cert.sign(priv_key, hashes.SHA256()) .public_bytes(encoding=serialization.Encoding.PEM) .decode("UTF-8") ) -async def create_ssl_context() -> ssl.SSLContext: +async def create_ssl_context(instance: FakeCSQLInstance) -> ssl.SSLContext: """Helper method to build an ssl.SSLContext for tests""" - # generate keys and certs for test - cert, private_key = generate_cert("my-project", "my-instance") - server_ca_cert = self_signed_cert(cert, private_key) client_private, client_bytes = await generate_keys() client_key: rsa.RSAPublicKey = serialization.load_pem_public_key( - client_bytes.encode("UTF-8"), default_backend() + client_bytes.encode("UTF-8"), ) # type: ignore - ephemeral_cert = client_key_signed_cert(cert, private_key, client_key) - # build default ssl.SSLContext - context = ssl.create_default_context() + ephemeral_cert = client_key_signed_cert( + instance.server_ca, instance.server_key, client_key + ) + # create SSL/TLS context + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False # load ssl.SSLContext with certs async with TemporaryDirectory() as tmpdir: ca_filename, cert_filename, key_filename = await write_to_file( - tmpdir, server_ca_cert, ephemeral_cert, client_private + tmpdir, instance.server_cert_pem, ephemeral_cert, client_private ) context.load_cert_chain(cert_filename, keyfile=key_filename) context.load_verify_locations(cafile=ca_filename) @@ -279,8 +281,8 @@ async def generate_ephemeral(self, request: Any) -> web.Response: body = await request.json() pub_key = body["public_key"] client_key: rsa.RSAPublicKey = serialization.load_pem_public_key( - pub_key.encode("UTF-8"), default_backend() - ) # type: ignore + pub_key.encode("UTF-8"), + ) ephemeral_cert = client_key_signed_cert( self.server_ca, self.server_key, diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 1a3d60917..3699ddc2d 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -186,17 +186,18 @@ async def test_RefreshAheadCache_close(cache: RefreshAheadCache) -> None: @pytest.mark.asyncio async def test_perform_refresh( cache: RefreshAheadCache, - fake_instance: mocks.FakeCSQLInstance, ) -> None: """ Test that _perform_refresh returns valid ConnectionInfo object. """ instance_metadata = await cache._perform_refresh() - # verify instance metadata object is returned assert isinstance(instance_metadata, ConnectionInfo) # verify instance metadata expiration - assert fake_instance.server_cert.not_valid_after_utc == instance_metadata.expiration + assert ( + cache._client.instance.cert_expiration.replace(microsecond=0) + == instance_metadata.expiration + ) @pytest.mark.asyncio diff --git a/tests/unit/test_monitored_cache.py b/tests/unit/test_monitored_cache.py index 1eea4eb46..1c1f1df86 100644 --- a/tests/unit/test_monitored_cache.py +++ b/tests/unit/test_monitored_cache.py @@ -14,13 +14,13 @@ import asyncio import socket +import ssl import dns.message import dns.rdataclass import dns.rdatatype import dns.resolver from mock import patch -from mocks import create_ssl_context import pytest from google.cloud.sql.connector.client import CloudSQLClient @@ -149,8 +149,10 @@ async def test_MonitoredCache_with_disabled_failover( assert monitored_cache.closed is True -@pytest.mark.usefixtures("server") -async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) -> None: +@pytest.mark.usefixtures("proxy_server") +async def test_MonitoredCache_check_domain_name( + context: ssl.SSLContext, fake_client: CloudSQLClient +) -> None: """ Test that MonitoredCache is closed when _check_domain_name has domain change. """ @@ -177,11 +179,9 @@ async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) -> # configure a local socket ip_addr = "127.0.0.1" - context = await create_ssl_context() sock = context.wrap_socket( socket.create_connection((ip_addr, 3307)), server_hostname=ip_addr, - do_handshake_on_connect=False, ) # verify socket is open assert sock.fileno() != -1 @@ -198,8 +198,10 @@ async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) -> assert sock.fileno() == -1 -@pytest.mark.usefixtures("server") -async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient) -> None: +@pytest.mark.usefixtures("proxy_server") +async def test_MonitoredCache_purge_closed_sockets( + context: ssl.SSLContext, fake_client: CloudSQLClient +) -> None: """ Test that MonitoredCache._purge_closed_sockets removes closed sockets from cache. @@ -215,11 +217,9 @@ async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient) ) # configure a local socket ip_addr = "127.0.0.1" - context = await create_ssl_context() sock = context.wrap_socket( socket.create_connection((ip_addr, 3307)), server_hostname=ip_addr, - do_handshake_on_connect=False, ) # set failover to 0 to disable polling diff --git a/tests/unit/test_pg8000.py b/tests/unit/test_pg8000.py index e01a53445..2c003b8a9 100644 --- a/tests/unit/test_pg8000.py +++ b/tests/unit/test_pg8000.py @@ -15,26 +15,22 @@ """ import socket +import ssl from typing import Any from mock import patch -from mocks import create_ssl_context import pytest from google.cloud.sql.connector.pg8000 import connect -@pytest.mark.usefixtures("server") -@pytest.mark.asyncio -async def test_pg8000(kwargs: Any) -> None: +@pytest.mark.usefixtures("proxy_server") +async def test_pg8000(context: ssl.SSLContext, kwargs: Any) -> None: """Test to verify that pg8000 gets to proper connection call.""" ip_addr = "127.0.0.1" - # build ssl.SSLContext - context = await create_ssl_context() sock = context.wrap_socket( socket.create_connection((ip_addr, 3307)), server_hostname=ip_addr, - do_handshake_on_connect=False, ) with patch("pg8000.dbapi.connect") as mock_connect: mock_connect.return_value = True diff --git a/tests/unit/test_pymysql.py b/tests/unit/test_pymysql.py index 66b1f22a3..13cd8e98a 100644 --- a/tests/unit/test_pymysql.py +++ b/tests/unit/test_pymysql.py @@ -19,7 +19,6 @@ from typing import Any from mock import patch -from mocks import create_ssl_context import pytest from google.cloud.sql.connector.pymysql import connect as pymysql_connect @@ -33,17 +32,14 @@ def connect(sock: ssl.SSLSocket) -> None: # type: ignore assert isinstance(sock, ssl.SSLSocket) -@pytest.mark.usefixtures("server") +@pytest.mark.usefixtures("proxy_server") @pytest.mark.asyncio -async def test_pymysql(kwargs: Any) -> None: +async def test_pymysql(context: ssl.SSLContext, kwargs: Any) -> None: """Test to verify that pymysql gets to proper connection call.""" ip_addr = "127.0.0.1" - # build ssl.SSLContext - context = await create_ssl_context() sock = context.wrap_socket( socket.create_connection((ip_addr, 3307)), server_hostname=ip_addr, - do_handshake_on_connect=False, ) kwargs["timeout"] = 30 with patch("pymysql.Connection") as mock_connect: diff --git a/tests/unit/test_pytds.py b/tests/unit/test_pytds.py index 9efe00ee5..faa20ad8c 100644 --- a/tests/unit/test_pytds.py +++ b/tests/unit/test_pytds.py @@ -16,10 +16,10 @@ import platform import socket +import ssl from typing import Any from mock import patch -from mocks import create_ssl_context import pytest from google.cloud.sql.connector.exceptions import PlatformNotSupportedError @@ -36,17 +36,13 @@ def stub_platform_windows() -> str: return "Windows" -@pytest.mark.usefixtures("server") -@pytest.mark.asyncio -async def test_pytds(kwargs: Any) -> None: +@pytest.mark.usefixtures("proxy_server") +async def test_pytds(context: ssl.SSLContext, kwargs: Any) -> None: """Test to verify that pytds gets to proper connection call.""" ip_addr = "127.0.0.1" - # build ssl.SSLContext - context = await create_ssl_context() sock = context.wrap_socket( socket.create_connection((ip_addr, 3307)), server_hostname=ip_addr, - do_handshake_on_connect=False, ) with patch("pytds.connect") as mock_connect: @@ -57,20 +53,16 @@ async def test_pytds(kwargs: Any) -> None: assert mock_connect.assert_called_once -@pytest.mark.usefixtures("server") -@pytest.mark.asyncio -async def test_pytds_platform_error(kwargs: Any) -> None: +@pytest.mark.usefixtures("proxy_server") +async def test_pytds_platform_error(context: ssl.SSLContext, kwargs: Any) -> None: """Test to verify that pytds.connect throws proper PlatformNotSupportedError.""" ip_addr = "127.0.0.1" # stub operating system to Linux setattr(platform, "system", stub_platform_linux) assert platform.system() == "Linux" - # build ssl.SSLContext - context = await create_ssl_context() sock = context.wrap_socket( socket.create_connection((ip_addr, 3307)), server_hostname=ip_addr, - do_handshake_on_connect=False, ) # add active_directory_auth to kwargs kwargs["active_directory_auth"] = True @@ -79,9 +71,10 @@ async def test_pytds_platform_error(kwargs: Any) -> None: connect(ip_addr, sock, **kwargs) -@pytest.mark.usefixtures("server") -@pytest.mark.asyncio -async def test_pytds_windows_active_directory_auth(kwargs: Any) -> None: +@pytest.mark.usefixtures("proxy_server") +async def test_pytds_windows_active_directory_auth( + context: ssl.SSLContext, kwargs: Any +) -> None: """ Test to verify that pytds gets to connection call on Windows with active_directory_auth arg set. @@ -90,12 +83,9 @@ async def test_pytds_windows_active_directory_auth(kwargs: Any) -> None: # stub operating system to Windows setattr(platform, "system", stub_platform_windows) assert platform.system() == "Windows" - # build ssl.SSLContext - context = await create_ssl_context() sock = context.wrap_socket( socket.create_connection((ip_addr, 3307)), server_hostname=ip_addr, - do_handshake_on_connect=False, ) # add active_directory_auth and server_name to kwargs kwargs["active_directory_auth"] = True