Skip to content

Commit ef7d8fe

Browse files
refactor: add ConnectionName class (#1186)
This PR refactors all instance connection name related code into its own file connection_name.py It introduces the ConnectionName class which will make tracking if a DNS name was given to the Connector easier in the future.
1 parent 3b24c10 commit ef7d8fe

File tree

15 files changed

+147
-73
lines changed

15 files changed

+147
-73
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from dataclasses import dataclass
16+
import re
17+
18+
# Instance connection name is the format <PROJECT>:<REGION>:<INSTANCE_NAME>
19+
# Additionally, we have to support legacy "domain-scoped" projects
20+
# (e.g. "google.com:PROJECT")
21+
CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)"))
22+
23+
24+
@dataclass
25+
class ConnectionName:
26+
"""ConnectionName represents a Cloud SQL instance's "instance connection name".
27+
28+
Takes the format "<PROJECT>:<REGION>:<INSTANCE_NAME>".
29+
"""
30+
31+
project: str
32+
region: str
33+
instance_name: str
34+
35+
def __str__(self) -> str:
36+
return f"{self.project}:{self.region}:{self.instance_name}"
37+
38+
39+
def _parse_instance_connection_name(connection_name: str) -> ConnectionName:
40+
if CONN_NAME_REGEX.fullmatch(connection_name) is None:
41+
raise ValueError(
42+
"Arg `instance_connection_string` must have "
43+
"format: PROJECT:REGION:INSTANCE, "
44+
f"got {connection_name}."
45+
)
46+
connection_name_split = CONN_NAME_REGEX.split(connection_name)
47+
return ConnectionName(
48+
connection_name_split[1],
49+
connection_name_split[3],
50+
connection_name_split[4],
51+
)

google/cloud/sql/connector/instance.py

Lines changed: 17 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@
2121
from datetime import timedelta
2222
from datetime import timezone
2323
import logging
24-
import re
2524

2625
import aiohttp
2726

2827
from google.cloud.sql.connector.client import CloudSQLClient
2928
from google.cloud.sql.connector.connection_info import ConnectionInfo
29+
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
3030
from google.cloud.sql.connector.exceptions import RefreshNotValidError
3131
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
3232
from google.cloud.sql.connector.refresh_utils import _is_valid
@@ -36,22 +36,6 @@
3636

3737
APPLICATION_NAME = "cloud-sql-python-connector"
3838

39-
# Instance connection name is the format <PROJECT>:<REGION>:<INSTANCE>
40-
# Additionally, we have to support legacy "domain-scoped" projects
41-
# (e.g. "google.com:PROJECT")
42-
CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)"))
43-
44-
45-
def _parse_instance_connection_name(connection_name: str) -> tuple[str, str, str]:
46-
if CONN_NAME_REGEX.fullmatch(connection_name) is None:
47-
raise ValueError(
48-
"Arg `instance_connection_string` must have "
49-
"format: PROJECT:REGION:INSTANCE, "
50-
f"got {connection_name}."
51-
)
52-
connection_name_split = CONN_NAME_REGEX.split(connection_name)
53-
return connection_name_split[1], connection_name_split[3], connection_name_split[4]
54-
5539

5640
class RefreshAheadCache:
5741
"""Cache that refreshes connection info in the background prior to expiration.
@@ -81,10 +65,13 @@ def __init__(
8165
connections.
8266
"""
8367
# validate and parse instance connection name
84-
self._project, self._region, self._instance = _parse_instance_connection_name(
85-
instance_connection_string
68+
conn_name = _parse_instance_connection_name(instance_connection_string)
69+
self._project, self._region, self._instance = (
70+
conn_name.project,
71+
conn_name.region,
72+
conn_name.instance_name,
8673
)
87-
self._instance_connection_string = instance_connection_string
74+
self._conn_name = conn_name
8875

8976
self._enable_iam_auth = enable_iam_auth
9077
self._keys = keys
@@ -121,8 +108,7 @@ async def _perform_refresh(self) -> ConnectionInfo:
121108
"""
122109
self._refresh_in_progress.set()
123110
logger.debug(
124-
f"['{self._instance_connection_string}']: Connection info refresh "
125-
"operation started"
111+
f"['{self._conn_name}']: Connection info refresh " "operation started"
126112
)
127113

128114
try:
@@ -135,17 +121,16 @@ async def _perform_refresh(self) -> ConnectionInfo:
135121
self._enable_iam_auth,
136122
)
137123
logger.debug(
138-
f"['{self._instance_connection_string}']: Connection info "
139-
"refresh operation complete"
124+
f"['{self._conn_name}']: Connection info " "refresh operation complete"
140125
)
141126
logger.debug(
142-
f"['{self._instance_connection_string}']: Current certificate "
127+
f"['{self._conn_name}']: Current certificate "
143128
f"expiration = {connection_info.expiration.isoformat()}"
144129
)
145130

146131
except aiohttp.ClientResponseError as e:
147132
logger.debug(
148-
f"['{self._instance_connection_string}']: Connection info "
133+
f"['{self._conn_name}']: Connection info "
149134
f"refresh operation failed: {str(e)}"
150135
)
151136
if e.status == 403:
@@ -154,7 +139,7 @@ async def _perform_refresh(self) -> ConnectionInfo:
154139

155140
except Exception as e:
156141
logger.debug(
157-
f"['{self._instance_connection_string}']: Connection info "
142+
f"['{self._conn_name}']: Connection info "
158143
f"refresh operation failed: {str(e)}"
159144
)
160145
raise
@@ -188,18 +173,17 @@ async def _refresh_task(self: RefreshAheadCache, delay: int) -> ConnectionInfo:
188173
# check that refresh is valid
189174
if not await _is_valid(refresh_task):
190175
raise RefreshNotValidError(
191-
f"['{self._instance_connection_string}']: Invalid refresh operation. Certficate appears to be expired."
176+
f"['{self._conn_name}']: Invalid refresh operation. Certficate appears to be expired."
192177
)
193178
except asyncio.CancelledError:
194179
logger.debug(
195-
f"['{self._instance_connection_string}']: Scheduled refresh"
196-
" operation cancelled"
180+
f"['{self._conn_name}']: Scheduled refresh" " operation cancelled"
197181
)
198182
raise
199183
# bad refresh attempt
200184
except Exception as e:
201185
logger.exception(
202-
f"['{self._instance_connection_string}']: "
186+
f"['{self._conn_name}']: "
203187
"An error occurred while performing refresh. "
204188
"Scheduling another refresh attempt immediately",
205189
exc_info=e,
@@ -216,7 +200,7 @@ async def _refresh_task(self: RefreshAheadCache, delay: int) -> ConnectionInfo:
216200
# calculate refresh delay based on certificate expiration
217201
delay = _seconds_until_refresh(refresh_data.expiration)
218202
logger.debug(
219-
f"['{self._instance_connection_string}']: Connection info refresh"
203+
f"['{self._conn_name}']: Connection info refresh"
220204
" operation scheduled for "
221205
f"{(datetime.now(timezone.utc) + timedelta(seconds=delay)).isoformat(timespec='seconds')} "
222206
f"(now + {timedelta(seconds=delay)})"
@@ -240,7 +224,7 @@ async def close(self) -> None:
240224
graceful exit.
241225
"""
242226
logger.debug(
243-
f"['{self._instance_connection_string}']: Canceling connection info "
227+
f"['{self._conn_name}']: Canceling connection info "
244228
"refresh operation tasks"
245229
)
246230
self._current.cancel()

google/cloud/sql/connector/lazy.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from google.cloud.sql.connector.client import CloudSQLClient
2323
from google.cloud.sql.connector.connection_info import ConnectionInfo
24-
from google.cloud.sql.connector.instance import _parse_instance_connection_name
24+
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
2525
from google.cloud.sql.connector.refresh_utils import _refresh_buffer
2626

2727
logger = logging.getLogger(name=__name__)
@@ -56,10 +56,13 @@ def __init__(
5656
connections.
5757
"""
5858
# validate and parse instance connection name
59-
self._project, self._region, self._instance = _parse_instance_connection_name(
60-
instance_connection_string
59+
conn_name = _parse_instance_connection_name(instance_connection_string)
60+
self._project, self._region, self._instance = (
61+
conn_name.project,
62+
conn_name.region,
63+
conn_name.instance_name,
6164
)
62-
self._instance_connection_string = instance_connection_string
65+
self._conn_name = conn_name
6366

6467
self._enable_iam_auth = enable_iam_auth
6568
self._keys = keys
@@ -91,13 +94,12 @@ async def connect_info(self) -> ConnectionInfo:
9194
< (self._cached.expiration - timedelta(seconds=_refresh_buffer))
9295
):
9396
logger.debug(
94-
f"['{self._instance_connection_string}']: Connection info "
97+
f"['{self._conn_name}']: Connection info "
9598
"is still valid, using cached info"
9699
)
97100
return self._cached
98101
logger.debug(
99-
f"['{self._instance_connection_string}']: Connection info "
100-
"refresh operation started"
102+
f"['{self._conn_name}']: Connection info " "refresh operation started"
101103
)
102104
try:
103105
conn_info = await self._client.get_connection_info(
@@ -109,16 +111,16 @@ async def connect_info(self) -> ConnectionInfo:
109111
)
110112
except Exception as e:
111113
logger.debug(
112-
f"['{self._instance_connection_string}']: Connection info "
114+
f"['{self._conn_name}']: Connection info "
113115
f"refresh operation failed: {str(e)}"
114116
)
115117
raise
116118
logger.debug(
117-
f"['{self._instance_connection_string}']: Connection info "
119+
f"['{self._conn_name}']: Connection info "
118120
"refresh operation completed successfully"
119121
)
120122
logger.debug(
121-
f"['{self._instance_connection_string}']: Current certificate "
123+
f"['{self._conn_name}']: Current certificate "
122124
f"expiration = {str(conn_info.expiration)}"
123125
)
124126
self._cached = conn_info

google/cloud/sql/connector/pg8000.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import socket
1718
import ssl
1819
from typing import Any, TYPE_CHECKING

google/cloud/sql/connector/pymysql.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import socket
1718
import ssl
1819
from typing import Any, TYPE_CHECKING

google/cloud/sql/connector/pytds.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import platform
1718
import socket
1819
import ssl

noxfile.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def lint(session):
5151
"--check-only",
5252
"--diff",
5353
"--profile=google",
54+
"-w=88",
5455
*LINT_PATHS,
5556
)
5657
session.run("black", "--check", "--diff", *LINT_PATHS)
@@ -85,6 +86,7 @@ def format(session):
8586
"isort",
8687
"--fss",
8788
"--profile=google",
89+
"-w=88",
8890
*LINT_PATHS,
8991
)
9092
session.run(

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import asyncio
1718
import os
1819
import socket

tests/unit/test_connection_name.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest # noqa F401 Needed to run the tests
16+
17+
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
18+
from google.cloud.sql.connector.connection_name import ConnectionName
19+
20+
21+
def test_ConnectionName() -> None:
22+
conn_name = ConnectionName("project", "region", "instance")
23+
# test class attributes are set properly
24+
assert conn_name.project == "project"
25+
assert conn_name.region == "region"
26+
assert conn_name.instance_name == "instance"
27+
# test ConnectionName str() method prints instance connection name
28+
assert str(conn_name) == "project:region:instance"
29+
30+
31+
@pytest.mark.parametrize(
32+
"connection_name, expected",
33+
[
34+
("project:region:instance", ConnectionName("project", "region", "instance")),
35+
(
36+
"domain-prefix:project:region:instance",
37+
ConnectionName("domain-prefix:project", "region", "instance"),
38+
),
39+
],
40+
)
41+
def test_parse_instance_connection_name(
42+
connection_name: str, expected: ConnectionName
43+
) -> None:
44+
"""
45+
Test that _parse_instance_connection_name works correctly on
46+
normal instance connection names and domain-scoped projects.
47+
"""
48+
assert expected == _parse_instance_connection_name(connection_name)
49+
50+
51+
def test_parse_instance_connection_name_bad_conn_name() -> None:
52+
"""
53+
Tests that ValueError is thrown for bad instance connection names.
54+
"""
55+
with pytest.raises(ValueError):
56+
_parse_instance_connection_name("project:instance") # missing region

0 commit comments

Comments
 (0)