Skip to content

Commit aad7ff5

Browse files
refactor: use one ClientSession per Connector (#1007)
1 parent fb58373 commit aad7ff5

File tree

12 files changed

+564
-700
lines changed

12 files changed

+564
-700
lines changed
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import datetime
18+
import logging
19+
from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
20+
21+
import aiohttp
22+
from cryptography.hazmat.backends import default_backend
23+
from cryptography.x509 import load_pem_x509_certificate
24+
import google.auth.transport.requests
25+
26+
from google.cloud.sql.connector.refresh_utils import _downscope_credentials
27+
from google.cloud.sql.connector.version import __version__ as version
28+
29+
if TYPE_CHECKING:
30+
from google.auth.credentials import Credentials
31+
32+
USER_AGENT: str = f"cloud-sql-python-connector/{version}"
33+
API_VERSION: str = "v1beta4"
34+
35+
logger = logging.getLogger(name=__name__)
36+
37+
38+
def _format_user_agent(driver: Optional[str], custom: Optional[str]) -> str:
39+
agent = f"{USER_AGENT}+{driver}" if driver else USER_AGENT
40+
if custom:
41+
agent = f"{agent} {custom}"
42+
return agent
43+
44+
45+
class CloudSQLClient:
46+
def __init__(
47+
self,
48+
sqladmin_api_endpoint: str,
49+
quota_project: Optional[str],
50+
credentials: Credentials,
51+
client: Optional[aiohttp.ClientSession] = None,
52+
driver: Optional[str] = None,
53+
user_agent: Optional[str] = None,
54+
) -> None:
55+
"""
56+
Establish the client to be used for Cloud SQL Admin API requests.
57+
58+
Args:
59+
sqladmin_api_endpoint (str): Base URL to use when calling
60+
the Cloud SQL Admin API endpoints.
61+
quota_project (str): The Project ID for an existing Google Cloud
62+
project. The project specified is used for quota and
63+
billing purposes.
64+
credentials (google.auth.credentials.Credentials):
65+
A credentials object created from the google-auth Python library.
66+
Must have the Cloud SQL Admin scopes. For more info check out
67+
https://google-auth.readthedocs.io/en/latest/.
68+
client (aiohttp.ClientSession): Async client used to make requests to
69+
Cloud SQL Admin APIs.
70+
Optional, defaults to None and creates new client.
71+
driver (str): Database driver to be used by the client.
72+
"""
73+
user_agent = _format_user_agent(driver, user_agent)
74+
headers = {
75+
"x-goog-api-client": user_agent,
76+
"User-Agent": user_agent,
77+
"Content-Type": "application/json",
78+
}
79+
if quota_project:
80+
headers["x-goog-user-project"] = quota_project
81+
82+
self._client = client if client else aiohttp.ClientSession(headers=headers)
83+
self._credentials = credentials
84+
self._sqladmin_api_endpoint = sqladmin_api_endpoint
85+
self._user_agent = user_agent
86+
87+
async def _get_metadata(
88+
self,
89+
project: str,
90+
region: str,
91+
instance: str,
92+
) -> Dict[str, Any]:
93+
"""Requests metadata from the Cloud SQL Instance
94+
and returns a dictionary containing the IP addresses and certificate
95+
authority of the Cloud SQL Instance.
96+
97+
:type project: str
98+
:param project:
99+
A string representing the name of the project.
100+
101+
:type region: str
102+
:param region : A string representing the name of the region.
103+
104+
:type instance: str
105+
:param instance: A string representing the name of the instance.
106+
107+
:rtype: Dict[str: Union[Dict, str]]
108+
:returns: Returns a dictionary containing a dictionary of all IP
109+
addresses and their type and a string representing the
110+
certificate authority.
111+
"""
112+
if not self._credentials.valid:
113+
request = google.auth.transport.requests.Request()
114+
self._credentials.refresh(request)
115+
116+
headers = {
117+
"Authorization": f"Bearer {self._credentials.token}",
118+
}
119+
120+
url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/projects/{project}/instances/{instance}/connectSettings"
121+
122+
logger.debug(f"['{instance}']: Requesting metadata")
123+
124+
resp = await self._client.get(url, headers=headers, raise_for_status=True)
125+
ret_dict = await resp.json()
126+
127+
if ret_dict["region"] != region:
128+
raise ValueError(
129+
f'[{project}:{region}:{instance}]: Provided region was mismatched - got region {region}, expected {ret_dict["region"]}.'
130+
)
131+
132+
ip_addresses = (
133+
{ip["type"]: ip["ipAddress"] for ip in ret_dict["ipAddresses"]}
134+
if "ipAddresses" in ret_dict
135+
else {}
136+
)
137+
if "dnsName" in ret_dict:
138+
ip_addresses["PSC"] = ret_dict["dnsName"]
139+
140+
return {
141+
"ip_addresses": ip_addresses,
142+
"server_ca_cert": ret_dict["serverCaCert"]["cert"],
143+
"database_version": ret_dict["databaseVersion"],
144+
}
145+
146+
async def _get_ephemeral(
147+
self,
148+
project: str,
149+
instance: str,
150+
pub_key: str,
151+
enable_iam_auth: bool = False,
152+
) -> Tuple[str, datetime.datetime]:
153+
"""Asynchronously requests an ephemeral certificate from the Cloud SQL Instance.
154+
155+
:type project: str
156+
:param project : A string representing the name of the project.
157+
158+
:type instance: str
159+
:param instance: A string representing the name of the instance.
160+
161+
:type pub_key:
162+
:param str: A string representing PEM-encoded RSA public key.
163+
164+
:type enable_iam_auth: bool
165+
:param enable_iam_auth
166+
Enables automatic IAM database authentication for Postgres or MySQL
167+
instances.
168+
169+
:rtype: str
170+
:returns: An ephemeral certificate from the Cloud SQL instance that allows
171+
authorized connections to the instance.
172+
"""
173+
174+
logger.debug(f"['{instance}']: Requesting ephemeral certificate")
175+
176+
if not self._credentials.valid:
177+
request = google.auth.transport.requests.Request()
178+
self._credentials.refresh(request)
179+
180+
headers = {
181+
"Authorization": f"Bearer {self._credentials.token}",
182+
}
183+
184+
url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/projects/{project}/instances/{instance}:generateEphemeralCert"
185+
186+
data = {"public_key": pub_key}
187+
188+
if enable_iam_auth:
189+
# down-scope credentials with only IAM login scope (refreshes them too)
190+
login_creds = _downscope_credentials(self._credentials)
191+
data["access_token"] = login_creds.token
192+
193+
resp = await self._client.post(
194+
url, headers=headers, json=data, raise_for_status=True
195+
)
196+
197+
ret_dict = await resp.json()
198+
199+
ephemeral_cert: str = ret_dict["ephemeralCert"]["cert"]
200+
201+
# decode cert to read expiration
202+
x509 = load_pem_x509_certificate(
203+
ephemeral_cert.encode("UTF-8"), default_backend()
204+
)
205+
expiration = x509.not_valid_after_utc
206+
# for IAM authentication OAuth2 token is embedded in cert so it
207+
# must still be valid for successful connection
208+
if enable_iam_auth:
209+
token_expiration: datetime.datetime = login_creds.expiry
210+
# google.auth library strips timezone info for backwards compatibality
211+
# reasons with Python 2. Add it back to allow timezone aware datetimes.
212+
# Ref: https://github.com/googleapis/google-auth-library-python/blob/49a5ff7411a2ae4d32a7d11700f9f961c55406a9/google/auth/_helpers.py#L93-L99
213+
token_expiration = token_expiration.replace(tzinfo=datetime.timezone.utc)
214+
215+
if expiration > token_expiration:
216+
expiration = token_expiration
217+
return ephemeral_cert, expiration
218+
219+
async def close(self) -> None:
220+
"""Close CloudSQLClient gracefully."""
221+
await self._client.close()

google/cloud/sql/connector/connector.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
from __future__ import annotations
1718

1819
import asyncio
@@ -28,6 +29,7 @@
2829
from google.auth.credentials import with_scopes_if_required
2930

3031
import google.cloud.sql.connector.asyncpg as asyncpg
32+
from google.cloud.sql.connector.client import CloudSQLClient
3133
from google.cloud.sql.connector.exceptions import ConnectorLoopError
3234
from google.cloud.sql.connector.exceptions import DnsNameResolutionError
3335
from google.cloud.sql.connector.instance import Instance
@@ -109,11 +111,12 @@ def __init__(
109111
loop=self._loop,
110112
)
111113
self._instances: Dict[str, Instance] = {}
114+
self._client: Optional[CloudSQLClient] = None
112115

113116
# initialize credentials
114117
scopes = ["https://www.googleapis.com/auth/sqlservice.admin"]
115118
if credentials:
116-
# verfiy custom credentials are proper type
119+
# verify custom credentials are proper type
117120
# and atleast base class of google.auth.credentials
118121
if not isinstance(credentials, Credentials):
119122
raise TypeError(
@@ -207,27 +210,32 @@ async def connect_async(
207210
# Use the Instance to establish an SSL Connection.
208211
#
209212
# Return a DBAPI connection
213+
if self._client is None:
214+
# lazy init client as it has to be initialized in async context
215+
self._client = CloudSQLClient(
216+
self._sqladmin_api_endpoint,
217+
self._quota_project,
218+
self._credentials,
219+
user_agent=self._user_agent,
220+
driver=driver,
221+
)
210222
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
211223
if instance_connection_string in self._instances:
212224
instance = self._instances[instance_connection_string]
213225
if enable_iam_auth != instance._enable_iam_auth:
214226
raise ValueError(
215-
f"connect() called with `enable_iam_auth={enable_iam_auth}`, "
216-
f"but previously used enable_iam_auth={instance._enable_iam_auth}`. "
227+
f"connect() called with 'enable_iam_auth={enable_iam_auth}', "
228+
f"but previously used 'enable_iam_auth={instance._enable_iam_auth}'. "
217229
"If you require both for your use case, please use a new "
218230
"connector.Connector object."
219231
)
220232
else:
221233
instance = Instance(
222234
instance_connection_string,
223-
driver,
235+
self._client,
224236
self._keys,
225237
self._loop,
226-
self._credentials,
227238
enable_iam_auth,
228-
self._quota_project,
229-
self._sqladmin_api_endpoint,
230-
user_agent=self._user_agent,
231239
)
232240
self._instances[instance_connection_string] = instance
233241

@@ -329,8 +337,8 @@ def close(self) -> None:
329337
close_future = asyncio.run_coroutine_threadsafe(
330338
self.close_async(), loop=self._loop
331339
)
332-
# Will attempt to safely shut down tasks for 5s
333-
close_future.result(timeout=5)
340+
# Will attempt to safely shut down tasks for 3s
341+
close_future.result(timeout=3)
334342
# if background thread exists for Connector, clean it up
335343
if self._thread:
336344
if self._loop.is_running():
@@ -345,6 +353,8 @@ async def close_async(self) -> None:
345353
await asyncio.gather(
346354
*[instance.close() for instance in self._instances.values()]
347355
)
356+
if self._client:
357+
await self._client.close()
348358

349359
def __del__(self) -> None:
350360
"""Close Connector as part of garbage collection"""

0 commit comments

Comments
 (0)