Skip to content

Commit 043c611

Browse files
refactor: move SSL/TLS context creation to ConnectionInfo (#1079)
1 parent 0d9b731 commit 043c611

File tree

7 files changed

+163
-130
lines changed

7 files changed

+163
-130
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,10 +330,17 @@ async def connect_async(
330330

331331
# async drivers are unblocking and can be awaited directly
332332
if driver in ASYNC_DRIVERS:
333-
return await connector(ip_address, instance_data.context, **kwargs)
333+
return await connector(
334+
ip_address,
335+
instance_data.create_ssl_context(enable_iam_auth),
336+
**kwargs,
337+
)
334338
# synchronous drivers are blocking and run using executor
335339
connect_partial = partial(
336-
connector, ip_address, instance_data.context, **kwargs
340+
connector,
341+
ip_address,
342+
instance_data.create_ssl_context(enable_iam_auth),
343+
**kwargs,
337344
)
338345
return await self._loop.run_in_executor(None, connect_partial)
339346

google/cloud/sql/connector/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,3 +62,11 @@ class DnsNameResolutionError(Exception):
6262
Exception to be raised when the DnsName of a PSC connection to a
6363
Cloud SQL instance can not be resolved to a proper IP address.
6464
"""
65+
66+
67+
class RefreshNotValidError(Exception):
68+
"""
69+
Exception to be raised when the task returned from refresh is not valid.
70+
"""
71+
72+
pass

google/cloud/sql/connector/instance.py

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from __future__ import annotations
1818

1919
import asyncio
20+
from dataclasses import dataclass
2021
from enum import Enum
2122
import logging
2223
import re
@@ -29,6 +30,7 @@
2930
from google.cloud.sql.connector.client import CloudSQLClient
3031
from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported
3132
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
33+
from google.cloud.sql.connector.exceptions import RefreshNotValidError
3234
from google.cloud.sql.connector.exceptions import TLSVersionError
3335
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
3436
from google.cloud.sql.connector.refresh_utils import _is_valid
@@ -79,33 +81,39 @@ def _from_str(cls, ip_type_str: str) -> IPTypes:
7981
return cls(ip_type_str.upper())
8082

8183

84+
@dataclass
8285
class ConnectionInfo:
86+
"""Contains all necessary information to connect securely to the
87+
server-side Proxy running on a Cloud SQL instance."""
88+
89+
client_cert: str
90+
server_ca_cert: str
91+
private_key: bytes
8392
ip_addrs: Dict[str, Any]
84-
context: ssl.SSLContext
8593
database_version: str
8694
expiration: datetime.datetime
95+
context: ssl.SSLContext | None = None
8796

88-
def __init__(
89-
self,
90-
ephemeral_cert: str,
91-
database_version: str,
92-
ip_addrs: Dict[str, Any],
93-
private_key: bytes,
94-
server_ca_cert: str,
95-
expiration: datetime.datetime,
96-
enable_iam_auth: bool,
97-
) -> None:
98-
self.ip_addrs = ip_addrs
99-
self.database_version = database_version
100-
self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
97+
def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext:
98+
"""Constructs a SSL/TLS context for the given connection info.
99+
100+
Cache the SSL context to ensure we don't read from disk repeatedly when
101+
configuring a secure connection.
102+
"""
103+
# if SSL context is cached, use it
104+
if self.context is not None:
105+
return self.context
106+
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
101107

102108
# update ssl.PROTOCOL_TLS_CLIENT default
103-
self.context.check_hostname = False
109+
context.check_hostname = False
104110

111+
# TODO: remove if/else when Python 3.10 is min version. PEP 644 has been
112+
# implemented. The ssl module requires OpenSSL 1.1.1 or newer.
105113
# verify OpenSSL version supports TLSv1.3
106114
if ssl.HAS_TLSv1_3:
107115
# force TLSv1.3 if supported by client
108-
self.context.minimum_version = ssl.TLSVersion.TLSv1_3
116+
context.minimum_version = ssl.TLSVersion.TLSv1_3
109117
# fallback to TLSv1.2 for older versions of OpenSSL
110118
else:
111119
if enable_iam_auth:
@@ -119,18 +127,20 @@ def __init__(
119127
f"({ssl.OPENSSL_VERSION}), falling back to TLSv1.2\n"
120128
"Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support."
121129
)
122-
self.context.minimum_version = ssl.TLSVersion.TLSv1_2
123-
self.expiration = expiration
130+
context.minimum_version = ssl.TLSVersion.TLSv1_2
124131

125132
# tmpdir and its contents are automatically deleted after the CA cert
126133
# and ephemeral cert are loaded into the SSLcontext. The values
127134
# need to be written to files in order to be loaded by the SSLContext
128135
with TemporaryDirectory() as tmpdir:
129136
ca_filename, cert_filename, key_filename = write_to_file(
130-
tmpdir, server_ca_cert, ephemeral_cert, private_key
137+
tmpdir, self.server_ca_cert, self.client_cert, self.private_key
131138
)
132-
self.context.load_cert_chain(cert_filename, keyfile=key_filename)
133-
self.context.load_verify_locations(cafile=ca_filename)
139+
context.load_cert_chain(cert_filename, keyfile=key_filename)
140+
context.load_verify_locations(cafile=ca_filename)
141+
# set class attribute to cache context for subsequent calls
142+
self.context = context
143+
return context
134144

135145
def get_preferred_ip(self, ip_type: IPTypes) -> str:
136146
"""Returns the first IP address for the instance, according to the preference
@@ -272,12 +282,11 @@ async def _perform_refresh(self) -> ConnectionInfo:
272282

273283
return ConnectionInfo(
274284
ephemeral_cert,
275-
metadata["database_version"],
276-
metadata["ip_addresses"],
277-
priv_key,
278285
metadata["server_ca_cert"],
286+
priv_key,
287+
metadata["ip_addresses"],
288+
metadata["database_version"],
279289
expiration,
280-
self._enable_iam_auth,
281290
)
282291

283292
def _schedule_refresh(self, delay: int) -> asyncio.Task:
@@ -303,6 +312,11 @@ async def _refresh_task(self: RefreshAheadCache, delay: int) -> ConnectionInfo:
303312
await asyncio.sleep(delay)
304313
refresh_task = asyncio.create_task(self._perform_refresh())
305314
refresh_data = await refresh_task
315+
# check that refresh is valid
316+
if not await _is_valid(refresh_task):
317+
raise RefreshNotValidError(
318+
f"['{self._instance_connection_string}']: Invalid refresh operation. Certficate appears to be expired."
319+
)
306320
except asyncio.CancelledError:
307321
logger.debug(
308322
f"['{self._instance_connection_string}']: Schedule refresh task cancelled."

tests/conftest.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def kwargs() -> Any:
105105
return kwargs
106106

107107

108-
@pytest.fixture(scope="session")
108+
@pytest.fixture
109109
def fake_instance() -> FakeCSQLInstance:
110110
return FakeCSQLInstance()
111111

@@ -133,7 +133,10 @@ async def fake_client(
133133
sqlserver_client_cert_uri, sqlserver_instance.generate_ephemeral
134134
)
135135
client_session = await aiohttp_client(app)
136-
return CloudSQLClient("", "", fake_credentials, client=client_session)
136+
client = CloudSQLClient("", "", fake_credentials, client=client_session)
137+
# add instance to client to control cert expiration etc.
138+
client.instance = fake_instance
139+
return client
137140

138141

139142
@pytest.fixture

tests/unit/mocks.py

Lines changed: 15 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
from google.auth.credentials import Credentials
3232

3333
from google.cloud.sql.connector.connector import _DEFAULT_UNIVERSE_DOMAIN
34-
from google.cloud.sql.connector.instance import ConnectionInfo
3534
from google.cloud.sql.connector.utils import generate_keys
3635
from google.cloud.sql.connector.utils import write_to_file
3736

@@ -85,36 +84,6 @@ def valid(self) -> bool:
8584
return self.token is not None and not self.expired
8685

8786

88-
class BadRefresh(Exception):
89-
pass
90-
91-
92-
class MockMetadata(ConnectionInfo):
93-
"""Mock class for ConnectionInfo"""
94-
95-
def __init__(
96-
self, expiration: datetime.datetime, ip_addrs: Dict = {"PRIMARY": "0.0.0.0"}
97-
) -> None:
98-
self.expiration = expiration
99-
self.ip_addrs = ip_addrs
100-
101-
102-
async def instance_metadata_success(*args: Any, **kwargs: Any) -> MockMetadata:
103-
return MockMetadata(
104-
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(minutes=10)
105-
)
106-
107-
108-
async def instance_metadata_expired(*args: Any, **kwargs: Any) -> MockMetadata:
109-
return MockMetadata(
110-
datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(minutes=10)
111-
)
112-
113-
114-
async def instance_metadata_error(*args: Any, **kwargs: Any) -> None:
115-
raise BadRefresh("something went wrong...")
116-
117-
11887
def generate_cert(
11988
project: str,
12089
name: str,
@@ -166,6 +135,9 @@ def client_key_signed_cert(
166135
cert: x509.CertificateBuilder,
167136
priv_key: rsa.RSAPrivateKey,
168137
client_key: rsa.RSAPublicKey,
138+
cert_before: datetime.datetime = datetime.datetime.now(datetime.timezone.utc),
139+
cert_expiration: datetime.datetime = datetime.datetime.now(datetime.timezone.utc)
140+
+ datetime.timedelta(hours=1),
169141
) -> str:
170142
"""
171143
Create a PEM encoded certificate that is signed by given public key.
@@ -185,8 +157,8 @@ def client_key_signed_cert(
185157
.issuer_name(issuer)
186158
.public_key(client_key)
187159
.serial_number(x509.random_serial_number())
188-
.not_valid_before(cert._not_valid_before)
189-
.not_valid_after(cert._not_valid_after) # type: ignore
160+
.not_valid_before(cert_before)
161+
.not_valid_after(cert_expiration) # type: ignore
190162
)
191163
return (
192164
cert.sign(priv_key, hashes.SHA256(), default_backend())
@@ -231,7 +203,7 @@ def __init__(
231203
"PRIVATE": "10.0.0.1",
232204
},
233205
cert_before: datetime = datetime.datetime.now(datetime.timezone.utc),
234-
cert_expiry: datetime = datetime.datetime.now(datetime.timezone.utc)
206+
cert_expiration: datetime = datetime.datetime.now(datetime.timezone.utc)
235207
+ datetime.timedelta(hours=1),
236208
) -> None:
237209
self.project = project
@@ -240,10 +212,10 @@ def __init__(
240212
self.db_version = db_version
241213
self.ip_addrs = ip_addrs
242214
self.cert_before = cert_before
243-
self.cert_expiry = cert_expiry
215+
self.cert_expiration = cert_expiration
244216
# create self signed CA cert
245217
self.server_ca, self.server_key = generate_cert(
246-
self.project, self.name, cert_before, cert_expiry
218+
self.project, self.name, cert_before, cert_expiration
247219
)
248220
self.server_cert = self.server_ca.sign(self.server_key, hashes.SHA256())
249221
self.server_cert_pem = self.server_cert.public_bytes(
@@ -257,7 +229,7 @@ async def connect_settings(self, request: Any) -> web.Response:
257229
"serverCaCert": {
258230
"cert": self.server_cert_pem,
259231
"instance": self.name,
260-
"expirationTime": str(self.server_cert.not_valid_after_utc),
232+
"expirationTime": str(self.cert_expiration),
261233
},
262234
"dnsName": "abcde.12345.us-central1.sql.goog",
263235
"ipAddresses": ip_addrs,
@@ -273,13 +245,17 @@ async def generate_ephemeral(self, request: Any) -> web.Response:
273245
pub_key.encode("UTF-8"), default_backend()
274246
) # type: ignore
275247
ephemeral_cert = client_key_signed_cert(
276-
self.server_ca, self.server_key, client_key
248+
self.server_ca,
249+
self.server_key,
250+
client_key,
251+
self.cert_before,
252+
self.cert_expiration,
277253
)
278254
response = {
279255
"ephemeralCert": {
280256
"kind": "sql#sslCert",
281257
"cert": ephemeral_cert,
282-
"expirationTime": str(self.server_cert.not_valid_after_utc),
258+
"expirationTime": str(self.cert_expiration),
283259
}
284260
}
285261
return web.Response(content_type="application/json", body=json.dumps(response))

0 commit comments

Comments
 (0)