1717import asyncio
1818import os
1919import socket
20+ import ssl
2021from threading import Thread
21- from typing import Any , AsyncGenerator , Generator
22+ from typing import Any , AsyncGenerator
2223
24+ from aiofiles .tempfile import TemporaryDirectory
2325from aiohttp import web
26+ from cryptography .hazmat .primitives import serialization
2427import pytest # noqa F401 Needed to run the tests
28+ from unit .mocks import create_ssl_context # type: ignore
2529from unit .mocks import FakeCredentials # type: ignore
2630from unit .mocks import FakeCSQLInstance # type: ignore
2731
2832from google .cloud .sql .connector .client import CloudSQLClient
2933from google .cloud .sql .connector .connection_name import ConnectionName
3034from google .cloud .sql .connector .instance import RefreshAheadCache
3135from google .cloud .sql .connector .utils import generate_keys
36+ from google .cloud .sql .connector .utils import write_to_file
3237
3338SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin" ]
3439
@@ -79,25 +84,60 @@ def fake_credentials() -> FakeCredentials:
7984 return FakeCredentials ()
8085
8186
82- def mock_server ( server_sock : socket . socket ) -> None :
83- """Create mock server listening on specified ip_address and port. """
87+ async def start_proxy_server ( instance : FakeCSQLInstance ) -> None :
88+ """Run local proxy server capable of performing mTLS """
8489 ip_address = "127.0.0.1"
8590 port = 3307
86- server_sock .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
87- server_sock .bind ((ip_address , port ))
88- server_sock .listen (0 )
89- server_sock .accept ()
91+ # create socket
92+ with socket .socket (socket .AF_INET , socket .SOCK_STREAM ) as sock :
93+ # create SSL/TLS context
94+ context = ssl .SSLContext (ssl .PROTOCOL_TLS_SERVER )
95+ context .minimum_version = ssl .TLSVersion .TLSv1_3
96+ # tmpdir and its contents are automatically deleted after the CA cert
97+ # and cert chain are loaded into the SSLcontext. The values
98+ # need to be written to files in order to be loaded by the SSLContext
99+ server_key_bytes = instance .server_key .private_bytes (
100+ encoding = serialization .Encoding .PEM ,
101+ format = serialization .PrivateFormat .TraditionalOpenSSL ,
102+ encryption_algorithm = serialization .NoEncryption (),
103+ )
104+ async with TemporaryDirectory () as tmpdir :
105+ server_filename , _ , key_filename = await write_to_file (
106+ tmpdir , instance .server_cert_pem , "" , server_key_bytes
107+ )
108+ context .load_cert_chain (server_filename , key_filename )
109+ # allow socket to be re-used
110+ sock .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
111+ # bind socket to Cloud SQL proxy server port on localhost
112+ sock .bind ((ip_address , port ))
113+ # listen for incoming connections
114+ sock .listen (5 )
115+
116+ with context .wrap_socket (sock , server_side = True ) as ssock :
117+ while True :
118+ conn , _ = ssock .accept ()
119+ conn .close ()
120+
121+
122+ @pytest .fixture (scope = "session" )
123+ def proxy_server (fake_instance : FakeCSQLInstance ) -> None :
124+ """Run local proxy server capable of performing mTLS"""
125+ thread = Thread (
126+ target = asyncio .run ,
127+ args = (
128+ start_proxy_server (
129+ fake_instance ,
130+ ),
131+ ),
132+ daemon = True ,
133+ )
134+ thread .start ()
135+ thread .join (1.0 ) # add a delay to allow the proxy server to start
90136
91137
92138@pytest .fixture
93- def server () -> Generator :
94- """Create thread with server listening on proper port"""
95- server_sock = socket .socket ()
96- thread = Thread (target = mock_server , args = (server_sock ,), daemon = True )
97- thread .start ()
98- yield thread
99- server_sock .close ()
100- thread .join ()
139+ async def context (fake_instance : FakeCSQLInstance ) -> ssl .SSLContext :
140+ return await create_ssl_context (fake_instance )
101141
102142
103143@pytest .fixture
@@ -107,7 +147,7 @@ def kwargs() -> Any:
107147 return kwargs
108148
109149
110- @pytest .fixture
150+ @pytest .fixture ( scope = "session" )
111151def fake_instance () -> FakeCSQLInstance :
112152 return FakeCSQLInstance ()
113153
0 commit comments