Skip to content

Commit 7231c57

Browse files
refactor: add ConnectionName to ConnectionInfo class (#1212)
Adding ConnectionName to ConnectionInfo class Benefits: - Cleaner interface, passing connection name around instead of individual args (project, region, instance) - Consistent debug logs (give access to ConnectionName throughout) - Allows us to tell if ConnectionInfo is using DNS (if ConnectionName.domain_name is set)
1 parent 5af7582 commit 7231c57

File tree

8 files changed

+32
-48
lines changed

8 files changed

+32
-48
lines changed

google/cloud/sql/connector/client.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from google.auth.transport import requests
2727

2828
from google.cloud.sql.connector.connection_info import ConnectionInfo
29+
from google.cloud.sql.connector.connection_name import ConnectionName
2930
from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported
3031
from google.cloud.sql.connector.refresh_utils import _downscope_credentials
3132
from google.cloud.sql.connector.refresh_utils import retry_50x
@@ -245,20 +246,16 @@ async def _get_ephemeral(
245246

246247
async def get_connection_info(
247248
self,
248-
project: str,
249-
region: str,
250-
instance: str,
249+
conn_name: ConnectionName,
251250
keys: asyncio.Future,
252251
enable_iam_auth: bool,
253252
) -> ConnectionInfo:
254253
"""Immediately performs a full refresh operation using the Cloud SQL
255254
Admin API.
256255
257256
Args:
258-
project (str): The name of the project the Cloud SQL instance is
259-
located in.
260-
region (str): The region the Cloud SQL instance is located in.
261-
instance (str): Name of the Cloud SQL instance.
257+
conn_name (ConnectionName): The Cloud SQL instance's
258+
connection name.
262259
keys (asyncio.Future): A future to the client's public-private key
263260
pair.
264261
enable_iam_auth (bool): Whether an automatic IAM database
@@ -278,16 +275,16 @@ async def get_connection_info(
278275

279276
metadata_task = asyncio.create_task(
280277
self._get_metadata(
281-
project,
282-
region,
283-
instance,
278+
conn_name.project,
279+
conn_name.region,
280+
conn_name.instance_name,
284281
)
285282
)
286283

287284
ephemeral_task = asyncio.create_task(
288285
self._get_ephemeral(
289-
project,
290-
instance,
286+
conn_name.project,
287+
conn_name.instance_name,
291288
pub_key,
292289
enable_iam_auth,
293290
)
@@ -311,6 +308,7 @@ async def get_connection_info(
311308
ephemeral_cert, expiration = await ephemeral_task
312309

313310
return ConnectionInfo(
311+
conn_name,
314312
ephemeral_cert,
315313
metadata["server_ca_cert"],
316314
priv_key,

google/cloud/sql/connector/connection_info.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from aiofiles.tempfile import TemporaryDirectory
2323

24+
from google.cloud.sql.connector.connection_name import ConnectionName
2425
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
2526
from google.cloud.sql.connector.exceptions import TLSVersionError
2627
from google.cloud.sql.connector.utils import write_to_file
@@ -38,6 +39,7 @@ class ConnectionInfo:
3839
"""Contains all necessary information to connect securely to the
3940
server-side Proxy running on a Cloud SQL instance."""
4041

42+
conn_name: ConnectionName
4143
client_cert: str
4244
server_ca_cert: str
4345
private_key: bytes

google/cloud/sql/connector/connector.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ def __init__(
113113
name. To resolve a DNS record to an instance connection name, use
114114
DnsResolver.
115115
Default: DefaultResolver
116-
117116
"""
118117
# if refresh_strategy is str, convert to RefreshStrategy enum
119118
if isinstance(refresh_strategy, str):
@@ -283,8 +282,7 @@ async def connect_async(
283282
conn_name = await self._resolver.resolve(instance_connection_string)
284283
if self._refresh_strategy == RefreshStrategy.LAZY:
285284
logger.debug(
286-
f"['{instance_connection_string}']: Refresh strategy is set"
287-
" to lazy refresh"
285+
f"['{conn_name}']: Refresh strategy is set to lazy refresh"
288286
)
289287
cache = LazyRefreshCache(
290288
conn_name,
@@ -294,18 +292,15 @@ async def connect_async(
294292
)
295293
else:
296294
logger.debug(
297-
f"['{instance_connection_string}']: Refresh strategy is set"
298-
" to backgound refresh"
295+
f"['{conn_name}']: Refresh strategy is set to backgound refresh"
299296
)
300297
cache = RefreshAheadCache(
301298
conn_name,
302299
self._client,
303300
self._keys,
304301
enable_iam_auth,
305302
)
306-
logger.debug(
307-
f"['{instance_connection_string}']: Connection info added to cache"
308-
)
303+
logger.debug(f"['{conn_name}']: Connection info added to cache")
309304
self._cache[(instance_connection_string, enable_iam_auth)] = cache
310305

311306
connect_func = {
@@ -344,9 +339,7 @@ async def connect_async(
344339
# the cache and re-raise the error
345340
await self._remove_cached(instance_connection_string, enable_iam_auth)
346341
raise
347-
logger.debug(
348-
f"['{instance_connection_string}']: Connecting to {ip_address}:3307"
349-
)
342+
logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307")
350343
# format `user` param for automatic IAM database authn
351344
if enable_iam_auth:
352345
formatted_user = format_database_user(

google/cloud/sql/connector/instance.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,6 @@ def __init__(
6262
(Postgres and MySQL) as the default authentication method for all
6363
connections.
6464
"""
65-
self._project, self._region, self._instance = (
66-
conn_name.project,
67-
conn_name.region,
68-
conn_name.instance_name,
69-
)
7065
self._conn_name = conn_name
7166

7267
self._enable_iam_auth = enable_iam_auth
@@ -104,20 +99,18 @@ async def _perform_refresh(self) -> ConnectionInfo:
10499
"""
105100
self._refresh_in_progress.set()
106101
logger.debug(
107-
f"['{self._conn_name}']: Connection info refresh " "operation started"
102+
f"['{self._conn_name}']: Connection info refresh operation started"
108103
)
109104

110105
try:
111106
await self._refresh_rate_limiter.acquire()
112107
connection_info = await self._client.get_connection_info(
113-
self._project,
114-
self._region,
115-
self._instance,
108+
self._conn_name,
116109
self._keys,
117110
self._enable_iam_auth,
118111
)
119112
logger.debug(
120-
f"['{self._conn_name}']: Connection info " "refresh operation complete"
113+
f"['{self._conn_name}']: Connection info refresh operation complete"
121114
)
122115
logger.debug(
123116
f"['{self._conn_name}']: Current certificate "

google/cloud/sql/connector/lazy.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,7 @@ def __init__(
5555
(Postgres and MySQL) as the default authentication method for all
5656
connections.
5757
"""
58-
self._project, self._region, self._instance = (
59-
conn_name.project,
60-
conn_name.region,
61-
conn_name.instance_name,
62-
)
6358
self._conn_name = conn_name
64-
6559
self._enable_iam_auth = enable_iam_auth
6660
self._keys = keys
6761
self._client = client
@@ -101,9 +95,7 @@ async def connect_info(self) -> ConnectionInfo:
10195
)
10296
try:
10397
conn_info = await self._client.get_connection_info(
104-
self._project,
105-
self._region,
106-
self._instance,
98+
self._conn_name,
10799
self._keys,
108100
self._enable_iam_auth,
109101
)

google/cloud/sql/connector/resolver.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
import dns.asyncresolver
1616

17+
from google.cloud.sql.connector.connection_name import (
18+
_parse_connection_name_with_domain_name,
19+
)
1720
from google.cloud.sql.connector.connection_name import _parse_connection_name
1821
from google.cloud.sql.connector.connection_name import ConnectionName
1922
from google.cloud.sql.connector.exceptions import DnsResolutionError
@@ -52,7 +55,7 @@ async def query_dns(self, dns: str) -> ConnectionName:
5255
# Attempt to parse records, returning the first valid record.
5356
for record in rdata:
5457
try:
55-
conn_name = _parse_connection_name(record)
58+
conn_name = _parse_connection_name_with_domain_name(record, dns)
5659
return conn_name
5760
except Exception:
5861
continue

tests/unit/test_instance.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ async def test_Instance_init(
4747
can tell if the connection string that's passed in is formatted correctly.
4848
"""
4949
assert (
50-
cache._project == "test-project"
51-
and cache._region == "test-region"
52-
and cache._instance == "test-instance"
50+
cache._conn_name.project == "test-project"
51+
and cache._conn_name.region == "test-region"
52+
and cache._conn_name.instance_name == "test-instance"
5353
)
5454
assert cache._enable_iam_auth is False
5555

@@ -283,7 +283,7 @@ async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None
283283

284284
async def test_ConnectionInfo_caches_sslcontext() -> None:
285285
info = ConnectionInfo(
286-
"cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now()
286+
"", "cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now()
287287
)
288288
# context should default to None
289289
assert info.context is None

tests/unit/test_resolver.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626

2727
conn_str = "my-project:my-region:my-instance"
2828
conn_name = ConnectionName("my-project", "my-region", "my-instance")
29+
conn_name_with_domain = ConnectionName(
30+
"my-project", "my-region", "my-instance", "db.example.com"
31+
)
2932

3033

3134
async def test_DefaultResolver() -> None:
@@ -74,7 +77,7 @@ async def test_DnsResolver_with_dns_name() -> None:
7477
resolver.port = 5053
7578
# Resolution should return first value sorted alphabetically
7679
result = await resolver.resolve("db.example.com")
77-
assert result == conn_name
80+
assert result == conn_name_with_domain
7881

7982

8083
query_text_malformed = """id 1234

0 commit comments

Comments
 (0)