Skip to content

Commit e52c2d9

Browse files
refactor: move ConnectionInfo to own file and add get_connection_info (#1090)
Refactor revolving around ConnectionInfo in preparation for adding a LazyRefreshCache. Moving ConnectionInfo into its own file, connection_info.py as it is to be shared by both RefreshAheadCache and LazyRefreshCache. Adding a get_connection_info method to the CloudSQLClient. This is the equivalent of the Go Connector's refresher.ConnectInfo. This will allow the lazy refresh to just check expiration and then call get_connection_info and not need to duplicate code from refresh ahead cache.
1 parent 9fbe87a commit e52c2d9

File tree

4 files changed

+196
-134
lines changed

4 files changed

+196
-134
lines changed

google/cloud/sql/connector/client.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,19 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
import datetime
1819
import logging
1920
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
2021

2122
import aiohttp
2223
from cryptography.hazmat.backends import default_backend
2324
from cryptography.x509 import load_pem_x509_certificate
25+
from google.auth.credentials import TokenState
26+
from google.auth.transport import requests
2427

28+
from google.cloud.sql.connector.connection_info import ConnectionInfo
29+
from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported
2530
from google.cloud.sql.connector.refresh_utils import _downscope_credentials
2631
from google.cloud.sql.connector.version import __version__ as version
2732

@@ -212,6 +217,82 @@ async def _get_ephemeral(
212217
expiration = token_expiration
213218
return ephemeral_cert, expiration
214219

220+
async def get_connection_info(
221+
self,
222+
project: str,
223+
region: str,
224+
instance: str,
225+
keys: asyncio.Future,
226+
enable_iam_auth: bool,
227+
) -> ConnectionInfo:
228+
"""Immediately performs a full refresh operation using the Cloud SQL
229+
Admin API.
230+
231+
Args:
232+
project (str): The name of the project the Cloud SQL instance is
233+
located in.
234+
region (str): The region the Cloud SQL instance is located in.
235+
instance (str): Name of the Cloud SQL instance.
236+
keys (asyncio.Future): A future to the client's public-private key
237+
pair.
238+
enable_iam_auth (bool): Whether an automatic IAM database
239+
authentication connection is being requested (Postgres and MySQL).
240+
241+
Returns:
242+
ConnectionInfo: All the information required to connect securely to
243+
the Cloud SQL instance.
244+
Raises:
245+
AutoIAMAuthNotSupported: Database engine does not support automatic
246+
IAM authentication.
247+
"""
248+
priv_key, pub_key = await keys
249+
# before making Cloud SQL Admin API calls, refresh creds if required
250+
if not self._credentials.token_state == TokenState.FRESH:
251+
self._credentials.refresh(requests.Request())
252+
253+
metadata_task = asyncio.create_task(
254+
self._get_metadata(
255+
project,
256+
region,
257+
instance,
258+
)
259+
)
260+
261+
ephemeral_task = asyncio.create_task(
262+
self._get_ephemeral(
263+
project,
264+
instance,
265+
pub_key,
266+
enable_iam_auth,
267+
)
268+
)
269+
try:
270+
metadata = await metadata_task
271+
# check if automatic IAM database authn is supported for database engine
272+
if enable_iam_auth and not metadata["database_version"].startswith(
273+
("POSTGRES", "MYSQL")
274+
):
275+
raise AutoIAMAuthNotSupported(
276+
f"'{metadata['database_version']}' does not support "
277+
"automatic IAM authentication. It is only supported with "
278+
"Cloud SQL Postgres or MySQL instances."
279+
)
280+
except Exception:
281+
# cancel ephemeral cert task if exception occurs before it is awaited
282+
ephemeral_task.cancel()
283+
raise
284+
285+
ephemeral_cert, expiration = await ephemeral_task
286+
287+
return ConnectionInfo(
288+
ephemeral_cert,
289+
metadata["server_ca_cert"],
290+
priv_key,
291+
metadata["ip_addresses"],
292+
metadata["database_version"],
293+
expiration,
294+
)
295+
215296
async def close(self) -> None:
216297
"""Close CloudSQLClient gracefully."""
217298
await self._client.close()
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from dataclasses import dataclass
18+
import logging
19+
import ssl
20+
from tempfile import TemporaryDirectory
21+
from typing import Any, Dict, Optional, TYPE_CHECKING
22+
23+
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
24+
from google.cloud.sql.connector.exceptions import TLSVersionError
25+
from google.cloud.sql.connector.utils import write_to_file
26+
27+
if TYPE_CHECKING:
28+
import datetime
29+
30+
from google.cloud.sql.connector.instance import IPTypes
31+
32+
logger = logging.getLogger(name=__name__)
33+
34+
35+
@dataclass
36+
class ConnectionInfo:
37+
"""Contains all necessary information to connect securely to the
38+
server-side Proxy running on a Cloud SQL instance."""
39+
40+
client_cert: str
41+
server_ca_cert: str
42+
private_key: bytes
43+
ip_addrs: Dict[str, Any]
44+
database_version: str
45+
expiration: datetime.datetime
46+
context: Optional[ssl.SSLContext] = None
47+
48+
def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext:
49+
"""Constructs a SSL/TLS context for the given connection info.
50+
51+
Cache the SSL context to ensure we don't read from disk repeatedly when
52+
configuring a secure connection.
53+
"""
54+
# if SSL context is cached, use it
55+
if self.context is not None:
56+
return self.context
57+
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
58+
59+
# update ssl.PROTOCOL_TLS_CLIENT default
60+
context.check_hostname = False
61+
62+
# TODO: remove if/else when Python 3.10 is min version. PEP 644 has been
63+
# implemented. The ssl module requires OpenSSL 1.1.1 or newer.
64+
# verify OpenSSL version supports TLSv1.3
65+
if ssl.HAS_TLSv1_3:
66+
# force TLSv1.3 if supported by client
67+
context.minimum_version = ssl.TLSVersion.TLSv1_3
68+
# fallback to TLSv1.2 for older versions of OpenSSL
69+
else:
70+
if enable_iam_auth:
71+
raise TLSVersionError(
72+
f"Your current version of OpenSSL ({ssl.OPENSSL_VERSION}) does not "
73+
"support TLSv1.3, which is required to use IAM Authentication.\n"
74+
"Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support."
75+
)
76+
logger.warning(
77+
"TLSv1.3 is not supported with your version of OpenSSL "
78+
f"({ssl.OPENSSL_VERSION}), falling back to TLSv1.2\n"
79+
"Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support."
80+
)
81+
context.minimum_version = ssl.TLSVersion.TLSv1_2
82+
83+
# tmpdir and its contents are automatically deleted after the CA cert
84+
# and ephemeral cert are loaded into the SSLcontext. The values
85+
# need to be written to files in order to be loaded by the SSLContext
86+
with TemporaryDirectory() as tmpdir:
87+
ca_filename, cert_filename, key_filename = write_to_file(
88+
tmpdir, self.server_ca_cert, self.client_cert, self.private_key
89+
)
90+
context.load_cert_chain(cert_filename, keyfile=key_filename)
91+
context.load_verify_locations(cafile=ca_filename)
92+
# set class attribute to cache context for subsequent calls
93+
self.context = context
94+
return context
95+
96+
def get_preferred_ip(self, ip_type: IPTypes) -> str:
97+
"""Returns the first IP address for the instance, according to the preference
98+
supplied by ip_type. If no IP addressess with the given preference are found,
99+
an error is raised."""
100+
if ip_type.value in self.ip_addrs:
101+
return self.ip_addrs[ip_type.value]
102+
raise CloudSQLIPTypeError(
103+
"Cloud SQL instance does not have any IP addresses matching "
104+
f"preference: {ip_type.value})"
105+
)

google/cloud/sql/connector/instance.py

Lines changed: 9 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,19 @@
1717
from __future__ import annotations
1818

1919
import asyncio
20-
from dataclasses import dataclass
2120
from enum import Enum
2221
import logging
2322
import re
24-
import ssl
25-
from tempfile import TemporaryDirectory
26-
from typing import Any, Dict, Tuple, TYPE_CHECKING
23+
from typing import Tuple
2724

2825
import aiohttp
29-
from google.auth.credentials import TokenState
30-
from google.auth.transport import requests
3126

3227
from google.cloud.sql.connector.client import CloudSQLClient
33-
from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported
34-
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
28+
from google.cloud.sql.connector.connection_info import ConnectionInfo
3529
from google.cloud.sql.connector.exceptions import RefreshNotValidError
36-
from google.cloud.sql.connector.exceptions import TLSVersionError
3730
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
3831
from google.cloud.sql.connector.refresh_utils import _is_valid
3932
from google.cloud.sql.connector.refresh_utils import _seconds_until_refresh
40-
from google.cloud.sql.connector.utils import write_to_file
41-
42-
if TYPE_CHECKING:
43-
import datetime
4433

4534
logger = logging.getLogger(name=__name__)
4635

@@ -83,79 +72,6 @@ def _from_str(cls, ip_type_str: str) -> IPTypes:
8372
return cls(ip_type_str.upper())
8473

8574

86-
@dataclass
87-
class ConnectionInfo:
88-
"""Contains all necessary information to connect securely to the
89-
server-side Proxy running on a Cloud SQL instance."""
90-
91-
client_cert: str
92-
server_ca_cert: str
93-
private_key: bytes
94-
ip_addrs: Dict[str, Any]
95-
database_version: str
96-
expiration: datetime.datetime
97-
context: ssl.SSLContext | None = None
98-
99-
def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext:
100-
"""Constructs a SSL/TLS context for the given connection info.
101-
102-
Cache the SSL context to ensure we don't read from disk repeatedly when
103-
configuring a secure connection.
104-
"""
105-
# if SSL context is cached, use it
106-
if self.context is not None:
107-
return self.context
108-
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
109-
110-
# update ssl.PROTOCOL_TLS_CLIENT default
111-
context.check_hostname = False
112-
113-
# TODO: remove if/else when Python 3.10 is min version. PEP 644 has been
114-
# implemented. The ssl module requires OpenSSL 1.1.1 or newer.
115-
# verify OpenSSL version supports TLSv1.3
116-
if ssl.HAS_TLSv1_3:
117-
# force TLSv1.3 if supported by client
118-
context.minimum_version = ssl.TLSVersion.TLSv1_3
119-
# fallback to TLSv1.2 for older versions of OpenSSL
120-
else:
121-
if enable_iam_auth:
122-
raise TLSVersionError(
123-
f"Your current version of OpenSSL ({ssl.OPENSSL_VERSION}) does not "
124-
"support TLSv1.3, which is required to use IAM Authentication.\n"
125-
"Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support."
126-
)
127-
logger.warning(
128-
"TLSv1.3 is not supported with your version of OpenSSL "
129-
f"({ssl.OPENSSL_VERSION}), falling back to TLSv1.2\n"
130-
"Upgrade your OpenSSL version to 1.1.1 for TLSv1.3 support."
131-
)
132-
context.minimum_version = ssl.TLSVersion.TLSv1_2
133-
134-
# tmpdir and its contents are automatically deleted after the CA cert
135-
# and ephemeral cert are loaded into the SSLcontext. The values
136-
# need to be written to files in order to be loaded by the SSLContext
137-
with TemporaryDirectory() as tmpdir:
138-
ca_filename, cert_filename, key_filename = write_to_file(
139-
tmpdir, self.server_ca_cert, self.client_cert, self.private_key
140-
)
141-
context.load_cert_chain(cert_filename, keyfile=key_filename)
142-
context.load_verify_locations(cafile=ca_filename)
143-
# set class attribute to cache context for subsequent calls
144-
self.context = context
145-
return context
146-
147-
def get_preferred_ip(self, ip_type: IPTypes) -> str:
148-
"""Returns the first IP address for the instance, according to the preference
149-
supplied by ip_type. If no IP addressess with the given preference are found,
150-
an error is raised."""
151-
if ip_type.value in self.ip_addrs:
152-
return self.ip_addrs[ip_type.value]
153-
raise CloudSQLIPTypeError(
154-
"Cloud SQL instance does not have any IP addresses matching "
155-
f"preference: {ip_type.value})"
156-
)
157-
158-
15975
class RefreshAheadCache:
16076
"""Cache that refreshes connection info in the background prior to expiration.
16177
@@ -229,45 +145,13 @@ async def _perform_refresh(self) -> ConnectionInfo:
229145

230146
try:
231147
await self._refresh_rate_limiter.acquire()
232-
priv_key, pub_key = await self._keys
233-
234-
logger.debug(f"['{self._instance_connection_string}']: Creating context")
235-
236-
# before making Cloud SQL Admin API calls, refresh creds
237-
if not self._client._credentials.token_state == TokenState.FRESH:
238-
self._client._credentials.refresh(requests.Request())
239-
240-
metadata_task = asyncio.create_task(
241-
self._client._get_metadata(
242-
self._project,
243-
self._region,
244-
self._instance,
245-
)
246-
)
247-
248-
ephemeral_task = asyncio.create_task(
249-
self._client._get_ephemeral(
250-
self._project,
251-
self._instance,
252-
pub_key,
253-
self._enable_iam_auth,
254-
)
148+
connection_info = await self._client.get_connection_info(
149+
self._project,
150+
self._region,
151+
self._instance,
152+
self._keys,
153+
self._enable_iam_auth,
255154
)
256-
try:
257-
metadata = await metadata_task
258-
# check if automatic IAM database authn is supported for database engine
259-
if self._enable_iam_auth and not metadata[
260-
"database_version"
261-
].startswith(("POSTGRES", "MYSQL")):
262-
raise AutoIAMAuthNotSupported(
263-
f"'{metadata['database_version']}' does not support automatic IAM authentication. It is only supported with Cloud SQL Postgres or MySQL instances."
264-
)
265-
except Exception:
266-
# cancel ephemeral cert task if exception occurs before it is awaited
267-
ephemeral_task.cancel()
268-
raise
269-
270-
ephemeral_cert, expiration = await ephemeral_task
271155

272156
except aiohttp.ClientResponseError as e:
273157
logger.debug(
@@ -285,15 +169,7 @@ async def _perform_refresh(self) -> ConnectionInfo:
285169

286170
finally:
287171
self._refresh_in_progress.clear()
288-
289-
return ConnectionInfo(
290-
ephemeral_cert,
291-
metadata["server_ca_cert"],
292-
priv_key,
293-
metadata["ip_addresses"],
294-
metadata["database_version"],
295-
expiration,
296-
)
172+
return connection_info
297173

298174
def _schedule_refresh(self, delay: int) -> asyncio.Task:
299175
"""

0 commit comments

Comments
 (0)