1717from __future__ import annotations
1818
1919import asyncio
20+ from dataclasses import dataclass
2021from enum import Enum
2122import logging
2223import re
2930from google .cloud .sql .connector .client import CloudSQLClient
3031from google .cloud .sql .connector .exceptions import AutoIAMAuthNotSupported
3132from google .cloud .sql .connector .exceptions import CloudSQLIPTypeError
33+ from google .cloud .sql .connector .exceptions import RefreshNotValidError
3234from google .cloud .sql .connector .exceptions import TLSVersionError
3335from google .cloud .sql .connector .rate_limiter import AsyncRateLimiter
3436from google .cloud .sql .connector .refresh_utils import _is_valid
@@ -79,33 +81,39 @@ def _from_str(cls, ip_type_str: str) -> IPTypes:
7981 return cls (ip_type_str .upper ())
8082
8183
84+ @dataclass
8285class ConnectionInfo :
86+ """Contains all necessary information to connect securely to the
87+ server-side Proxy running on a Cloud SQL instance."""
88+
89+ client_cert : str
90+ server_ca_cert : str
91+ private_key : bytes
8392 ip_addrs : Dict [str , Any ]
84- context : ssl .SSLContext
8593 database_version : str
8694 expiration : datetime .datetime
95+ context : ssl .SSLContext | None = None
8796
88- def __init__ (
89- self ,
90- ephemeral_cert : str ,
91- database_version : str ,
92- ip_addrs : Dict [str , Any ],
93- private_key : bytes ,
94- server_ca_cert : str ,
95- expiration : datetime .datetime ,
96- enable_iam_auth : bool ,
97- ) -> None :
98- self .ip_addrs = ip_addrs
99- self .database_version = database_version
100- self .context = ssl .SSLContext (ssl .PROTOCOL_TLS_CLIENT )
97+ def create_ssl_context (self , enable_iam_auth : bool = False ) -> ssl .SSLContext :
98+ """Constructs a SSL/TLS context for the given connection info.
99+
100+ Cache the SSL context to ensure we don't read from disk repeatedly when
101+ configuring a secure connection.
102+ """
103+ # if SSL context is cached, use it
104+ if self .context is not None :
105+ return self .context
106+ context = ssl .SSLContext (ssl .PROTOCOL_TLS_CLIENT )
101107
102108 # update ssl.PROTOCOL_TLS_CLIENT default
103- self . context .check_hostname = False
109+ context .check_hostname = False
104110
111+ # TODO: remove if/else when Python 3.10 is min version. PEP 644 has been
112+ # implemented. The ssl module requires OpenSSL 1.1.1 or newer.
105113 # verify OpenSSL version supports TLSv1.3
106114 if ssl .HAS_TLSv1_3 :
107115 # force TLSv1.3 if supported by client
108- self . context .minimum_version = ssl .TLSVersion .TLSv1_3
116+ context .minimum_version = ssl .TLSVersion .TLSv1_3
109117 # fallback to TLSv1.2 for older versions of OpenSSL
110118 else :
111119 if enable_iam_auth :
@@ -119,18 +127,20 @@ def __init__(
119127 f"({ ssl .OPENSSL_VERSION } ), falling back to TLSv1.2\n "
120128 "Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support."
121129 )
122- self .context .minimum_version = ssl .TLSVersion .TLSv1_2
123- self .expiration = expiration
130+ context .minimum_version = ssl .TLSVersion .TLSv1_2
124131
125132 # tmpdir and its contents are automatically deleted after the CA cert
126133 # and ephemeral cert are loaded into the SSLcontext. The values
127134 # need to be written to files in order to be loaded by the SSLContext
128135 with TemporaryDirectory () as tmpdir :
129136 ca_filename , cert_filename , key_filename = write_to_file (
130- tmpdir , server_ca_cert , ephemeral_cert , private_key
137+ tmpdir , self . server_ca_cert , self . client_cert , self . private_key
131138 )
132- self .context .load_cert_chain (cert_filename , keyfile = key_filename )
133- self .context .load_verify_locations (cafile = ca_filename )
139+ context .load_cert_chain (cert_filename , keyfile = key_filename )
140+ context .load_verify_locations (cafile = ca_filename )
141+ # set class attribute to cache context for subsequent calls
142+ self .context = context
143+ return context
134144
135145 def get_preferred_ip (self , ip_type : IPTypes ) -> str :
136146 """Returns the first IP address for the instance, according to the preference
@@ -272,12 +282,11 @@ async def _perform_refresh(self) -> ConnectionInfo:
272282
273283 return ConnectionInfo (
274284 ephemeral_cert ,
275- metadata ["database_version" ],
276- metadata ["ip_addresses" ],
277- priv_key ,
278285 metadata ["server_ca_cert" ],
286+ priv_key ,
287+ metadata ["ip_addresses" ],
288+ metadata ["database_version" ],
279289 expiration ,
280- self ._enable_iam_auth ,
281290 )
282291
283292 def _schedule_refresh (self , delay : int ) -> asyncio .Task :
@@ -303,6 +312,11 @@ async def _refresh_task(self: RefreshAheadCache, delay: int) -> ConnectionInfo:
303312 await asyncio .sleep (delay )
304313 refresh_task = asyncio .create_task (self ._perform_refresh ())
305314 refresh_data = await refresh_task
315+ # check that refresh is valid
316+ if not await _is_valid (refresh_task ):
317+ raise RefreshNotValidError (
318+ f"['{ self ._instance_connection_string } ']: Invalid refresh operation. Certficate appears to be expired."
319+ )
306320 except asyncio .CancelledError :
307321 logger .debug (
308322 f"['{ self ._instance_connection_string } ']: Schedule refresh task cancelled."
0 commit comments