Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 56 additions & 16 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,23 @@
import asyncio
import os
import socket
import ssl
from threading import Thread
from typing import Any, AsyncGenerator, Generator
from typing import Any, AsyncGenerator

from aiofiles.tempfile import TemporaryDirectory
from aiohttp import web
from cryptography.hazmat.primitives import serialization
import pytest # noqa F401 Needed to run the tests
from unit.mocks import create_ssl_context # type: ignore
from unit.mocks import FakeCredentials # type: ignore
from unit.mocks import FakeCSQLInstance # type: ignore

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.instance import RefreshAheadCache
from google.cloud.sql.connector.utils import generate_keys
from google.cloud.sql.connector.utils import write_to_file

SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin"]

Expand Down Expand Up @@ -79,25 +84,60 @@ def fake_credentials() -> FakeCredentials:
return FakeCredentials()


def mock_server(server_sock: socket.socket) -> None:
"""Create mock server listening on specified ip_address and port."""
async def start_proxy_server(instance: FakeCSQLInstance) -> None:
"""Run local proxy server capable of performing mTLS"""
ip_address = "127.0.0.1"
port = 3307
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_sock.bind((ip_address, port))
server_sock.listen(0)
server_sock.accept()
# create socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
# create SSL/TLS context
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
context.minimum_version = ssl.TLSVersion.TLSv1_3
# tmpdir and its contents are automatically deleted after the CA cert
# and cert chain are loaded into the SSLcontext. The values
# need to be written to files in order to be loaded by the SSLContext
server_key_bytes = instance.server_key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption(),
)
async with TemporaryDirectory() as tmpdir:
server_filename, _, key_filename = await write_to_file(
tmpdir, instance.server_cert_pem, "", server_key_bytes
)
context.load_cert_chain(server_filename, key_filename)
# allow socket to be re-used
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# bind socket to Cloud SQL proxy server port on localhost
sock.bind((ip_address, port))
# listen for incoming connections
sock.listen(5)

with context.wrap_socket(sock, server_side=True) as ssock:
while True:
conn, _ = ssock.accept()
conn.close()


@pytest.fixture(scope="session")
def proxy_server(fake_instance: FakeCSQLInstance) -> None:
"""Run local proxy server capable of performing mTLS"""
thread = Thread(
target=asyncio.run,
args=(
start_proxy_server(
fake_instance,
),
),
daemon=True,
)
thread.start()
thread.join(1.0) # add a delay to allow the proxy server to start


@pytest.fixture
def server() -> Generator:
"""Create thread with server listening on proper port"""
server_sock = socket.socket()
thread = Thread(target=mock_server, args=(server_sock,), daemon=True)
thread.start()
yield thread
server_sock.close()
thread.join()
async def context(fake_instance: FakeCSQLInstance) -> ssl.SSLContext:
return await create_ssl_context(fake_instance)


@pytest.fixture
Expand All @@ -107,7 +147,7 @@ def kwargs() -> Any:
return kwargs


@pytest.fixture
@pytest.fixture(scope="session")
def fake_instance() -> FakeCSQLInstance:
return FakeCSQLInstance()

Expand Down
28 changes: 15 additions & 13 deletions tests/unit/mocks.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
""""
"""
Copyright 2022 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -16,6 +16,8 @@

# file containing all mocks used for Cloud SQL Python Connector unit tests

from __future__ import annotations

import datetime
import json
import ssl
Expand Down Expand Up @@ -184,28 +186,28 @@ def client_key_signed_cert(
.not_valid_after(cert_expiration) # type: ignore
)
return (
cert.sign(priv_key, hashes.SHA256(), default_backend())
cert.sign(priv_key, hashes.SHA256())
.public_bytes(encoding=serialization.Encoding.PEM)
.decode("UTF-8")
)


async def create_ssl_context() -> ssl.SSLContext:
async def create_ssl_context(instance: FakeCSQLInstance) -> ssl.SSLContext:
"""Helper method to build an ssl.SSLContext for tests"""
# generate keys and certs for test
cert, private_key = generate_cert("my-project", "my-instance")
server_ca_cert = self_signed_cert(cert, private_key)
client_private, client_bytes = await generate_keys()
client_key: rsa.RSAPublicKey = serialization.load_pem_public_key(
client_bytes.encode("UTF-8"), default_backend()
client_bytes.encode("UTF-8"),
) # type: ignore
ephemeral_cert = client_key_signed_cert(cert, private_key, client_key)
# build default ssl.SSLContext
context = ssl.create_default_context()
ephemeral_cert = client_key_signed_cert(
instance.server_ca, instance.server_key, client_key
)
# create SSL/TLS context
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
context.check_hostname = False
# load ssl.SSLContext with certs
async with TemporaryDirectory() as tmpdir:
ca_filename, cert_filename, key_filename = await write_to_file(
tmpdir, server_ca_cert, ephemeral_cert, client_private
tmpdir, instance.server_cert_pem, ephemeral_cert, client_private
)
context.load_cert_chain(cert_filename, keyfile=key_filename)
context.load_verify_locations(cafile=ca_filename)
Expand Down Expand Up @@ -279,8 +281,8 @@ async def generate_ephemeral(self, request: Any) -> web.Response:
body = await request.json()
pub_key = body["public_key"]
client_key: rsa.RSAPublicKey = serialization.load_pem_public_key(
pub_key.encode("UTF-8"), default_backend()
) # type: ignore
pub_key.encode("UTF-8"),
)
ephemeral_cert = client_key_signed_cert(
self.server_ca,
self.server_key,
Expand Down
7 changes: 4 additions & 3 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,17 +186,18 @@ async def test_RefreshAheadCache_close(cache: RefreshAheadCache) -> None:
@pytest.mark.asyncio
async def test_perform_refresh(
cache: RefreshAheadCache,
fake_instance: mocks.FakeCSQLInstance,
) -> None:
"""
Test that _perform_refresh returns valid ConnectionInfo object.
"""
instance_metadata = await cache._perform_refresh()

# verify instance metadata object is returned
assert isinstance(instance_metadata, ConnectionInfo)
# verify instance metadata expiration
assert fake_instance.server_cert.not_valid_after_utc == instance_metadata.expiration
assert (
cache._client.instance.cert_expiration.replace(microsecond=0)
== instance_metadata.expiration
)


@pytest.mark.asyncio
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/test_monitored_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

import asyncio
import socket
import ssl

import dns.message
import dns.rdataclass
import dns.rdatatype
import dns.resolver
from mock import patch
from mocks import create_ssl_context
import pytest

from google.cloud.sql.connector.client import CloudSQLClient
Expand Down Expand Up @@ -149,8 +149,10 @@ async def test_MonitoredCache_with_disabled_failover(
assert monitored_cache.closed is True


@pytest.mark.usefixtures("server")
async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) -> None:
@pytest.mark.usefixtures("proxy_server")
async def test_MonitoredCache_check_domain_name(
context: ssl.SSLContext, fake_client: CloudSQLClient
) -> None:
"""
Test that MonitoredCache is closed when _check_domain_name has domain change.
"""
Expand All @@ -177,11 +179,9 @@ async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) ->

# configure a local socket
ip_addr = "127.0.0.1"
context = await create_ssl_context()
sock = context.wrap_socket(
socket.create_connection((ip_addr, 3307)),
server_hostname=ip_addr,
do_handshake_on_connect=False,
)
# verify socket is open
assert sock.fileno() != -1
Expand All @@ -198,8 +198,10 @@ async def test_MonitoredCache_check_domain_name(fake_client: CloudSQLClient) ->
assert sock.fileno() == -1


@pytest.mark.usefixtures("server")
async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient) -> None:
@pytest.mark.usefixtures("proxy_server")
async def test_MonitoredCache_purge_closed_sockets(
context: ssl.SSLContext, fake_client: CloudSQLClient
) -> None:
"""
Test that MonitoredCache._purge_closed_sockets removes closed sockets from
cache.
Expand All @@ -215,11 +217,9 @@ async def test_MonitoredCache_purge_closed_sockets(fake_client: CloudSQLClient)
)
# configure a local socket
ip_addr = "127.0.0.1"
context = await create_ssl_context()
sock = context.wrap_socket(
socket.create_connection((ip_addr, 3307)),
server_hostname=ip_addr,
do_handshake_on_connect=False,
)

# set failover to 0 to disable polling
Expand Down
10 changes: 3 additions & 7 deletions tests/unit/test_pg8000.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,22 @@
"""

import socket
import ssl
from typing import Any

from mock import patch
from mocks import create_ssl_context
import pytest

from google.cloud.sql.connector.pg8000 import connect


@pytest.mark.usefixtures("server")
@pytest.mark.asyncio
async def test_pg8000(kwargs: Any) -> None:
@pytest.mark.usefixtures("proxy_server")
async def test_pg8000(context: ssl.SSLContext, kwargs: Any) -> None:
"""Test to verify that pg8000 gets to proper connection call."""
ip_addr = "127.0.0.1"
# build ssl.SSLContext
context = await create_ssl_context()
sock = context.wrap_socket(
socket.create_connection((ip_addr, 3307)),
server_hostname=ip_addr,
do_handshake_on_connect=False,
)
with patch("pg8000.dbapi.connect") as mock_connect:
mock_connect.return_value = True
Expand Down
8 changes: 2 additions & 6 deletions tests/unit/test_pymysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from typing import Any

from mock import patch
from mocks import create_ssl_context
import pytest

from google.cloud.sql.connector.pymysql import connect as pymysql_connect
Expand All @@ -33,17 +32,14 @@ def connect(sock: ssl.SSLSocket) -> None: # type: ignore
assert isinstance(sock, ssl.SSLSocket)


@pytest.mark.usefixtures("server")
@pytest.mark.usefixtures("proxy_server")
@pytest.mark.asyncio
async def test_pymysql(kwargs: Any) -> None:
async def test_pymysql(context: ssl.SSLContext, kwargs: Any) -> None:
"""Test to verify that pymysql gets to proper connection call."""
ip_addr = "127.0.0.1"
# build ssl.SSLContext
context = await create_ssl_context()
sock = context.wrap_socket(
socket.create_connection((ip_addr, 3307)),
server_hostname=ip_addr,
do_handshake_on_connect=False,
)
kwargs["timeout"] = 30
with patch("pymysql.Connection") as mock_connect:
Expand Down
Loading
Loading