From ca55d369a0b3ec1cbcd46daa994169ef42f780e8 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 1 Nov 2024 15:32:44 +0000 Subject: [PATCH 1/3] refactor: add ConnectionName class --- google/cloud/sql/connector/connection_name.py | 51 +++++++++++++++++ google/cloud/sql/connector/instance.py | 50 ++++++----------- google/cloud/sql/connector/lazy.py | 21 ++++--- google/cloud/sql/connector/pg8000.py | 1 + google/cloud/sql/connector/pymysql.py | 1 + google/cloud/sql/connector/pytds.py | 1 + noxfile.py | 2 + tests/conftest.py | 1 + tests/unit/test_connection_name.py | 56 +++++++++++++++++++ tests/unit/test_instance.py | 30 ---------- tests/unit/test_pg8000.py | 1 + tests/unit/test_pymysql.py | 1 + tests/unit/test_pytds.py | 1 + tests/unit/test_rate_limiter.py | 1 + tests/unit/test_utils.py | 1 + 15 files changed, 147 insertions(+), 72 deletions(-) create mode 100644 google/cloud/sql/connector/connection_name.py create mode 100644 tests/unit/test_connection_name.py diff --git a/google/cloud/sql/connector/connection_name.py b/google/cloud/sql/connector/connection_name.py new file mode 100644 index 000000000..e85225d3c --- /dev/null +++ b/google/cloud/sql/connector/connection_name.py @@ -0,0 +1,51 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +import re + +# Instance connection name is the format :: +# Additionally, we have to support legacy "domain-scoped" projects +# (e.g. "google.com:PROJECT") +CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)")) + + +@dataclass +class ConnectionName: + """ConnectionName represents a Cloud SQL instance's "instance connection name". + + Takes the format "::". + """ + + project: str + region: str + instance_name: str + + def __str__(self): + return f"{self.project}:{self.region}:{self.instance_name}" + + +def _parse_instance_connection_name(connection_name: str) -> ConnectionName: + if CONN_NAME_REGEX.fullmatch(connection_name) is None: + raise ValueError( + "Arg `instance_connection_string` must have " + "format: PROJECT:REGION:INSTANCE, " + f"got {connection_name}." + ) + connection_name_split = CONN_NAME_REGEX.split(connection_name) + return ConnectionName( + connection_name_split[1], + connection_name_split[3], + connection_name_split[4], + ) diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 818d5eb11..f244b8cf3 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -21,12 +21,12 @@ from datetime import timedelta from datetime import timezone import logging -import re import aiohttp from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo +from google.cloud.sql.connector.connection_name import _parse_instance_connection_name from google.cloud.sql.connector.exceptions import RefreshNotValidError from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter from google.cloud.sql.connector.refresh_utils import _is_valid @@ -36,22 +36,6 @@ APPLICATION_NAME = "cloud-sql-python-connector" -# Instance connection name is the format :: -# Additionally, we have to support legacy "domain-scoped" projects -# (e.g. "google.com:PROJECT") -CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)")) - - -def _parse_instance_connection_name(connection_name: str) -> tuple[str, str, str]: - if CONN_NAME_REGEX.fullmatch(connection_name) is None: - raise ValueError( - "Arg `instance_connection_string` must have " - "format: PROJECT:REGION:INSTANCE, " - f"got {connection_name}." - ) - connection_name_split = CONN_NAME_REGEX.split(connection_name) - return connection_name_split[1], connection_name_split[3], connection_name_split[4] - class RefreshAheadCache: """Cache that refreshes connection info in the background prior to expiration. @@ -81,10 +65,13 @@ def __init__( connections. """ # validate and parse instance connection name - self._project, self._region, self._instance = _parse_instance_connection_name( - instance_connection_string + conn_name = _parse_instance_connection_name(instance_connection_string) + self._project, self._region, self._instance = ( + conn_name.project, + conn_name.region, + conn_name.instance_name, ) - self._instance_connection_string = instance_connection_string + self._conn_name = conn_name self._enable_iam_auth = enable_iam_auth self._keys = keys @@ -121,8 +108,7 @@ async def _perform_refresh(self) -> ConnectionInfo: """ self._refresh_in_progress.set() logger.debug( - f"['{self._instance_connection_string}']: Connection info refresh " - "operation started" + f"['{self._conn_name}']: Connection info refresh " "operation started" ) try: @@ -135,17 +121,16 @@ async def _perform_refresh(self) -> ConnectionInfo: self._enable_iam_auth, ) logger.debug( - f"['{self._instance_connection_string}']: Connection info " - "refresh operation complete" + f"['{self._conn_name}']: Connection info " "refresh operation complete" ) logger.debug( - f"['{self._instance_connection_string}']: Current certificate " + f"['{self._conn_name}']: Current certificate " f"expiration = {connection_info.expiration.isoformat()}" ) except aiohttp.ClientResponseError as e: logger.debug( - f"['{self._instance_connection_string}']: Connection info " + f"['{self._conn_name}']: Connection info " f"refresh operation failed: {str(e)}" ) if e.status == 403: @@ -154,7 +139,7 @@ async def _perform_refresh(self) -> ConnectionInfo: except Exception as e: logger.debug( - f"['{self._instance_connection_string}']: Connection info " + f"['{self._conn_name}']: Connection info " f"refresh operation failed: {str(e)}" ) raise @@ -188,18 +173,17 @@ async def _refresh_task(self: RefreshAheadCache, delay: int) -> ConnectionInfo: # check that refresh is valid if not await _is_valid(refresh_task): raise RefreshNotValidError( - f"['{self._instance_connection_string}']: Invalid refresh operation. Certficate appears to be expired." + f"['{self._conn_name}']: Invalid refresh operation. Certficate appears to be expired." ) except asyncio.CancelledError: logger.debug( - f"['{self._instance_connection_string}']: Scheduled refresh" - " operation cancelled" + f"['{self._conn_name}']: Scheduled refresh" " operation cancelled" ) raise # bad refresh attempt except Exception as e: logger.exception( - f"['{self._instance_connection_string}']: " + f"['{self._conn_name}']: " "An error occurred while performing refresh. " "Scheduling another refresh attempt immediately", exc_info=e, @@ -216,7 +200,7 @@ async def _refresh_task(self: RefreshAheadCache, delay: int) -> ConnectionInfo: # calculate refresh delay based on certificate expiration delay = _seconds_until_refresh(refresh_data.expiration) logger.debug( - f"['{self._instance_connection_string}']: Connection info refresh" + f"['{self._conn_name}']: Connection info refresh" " operation scheduled for " f"{(datetime.now(timezone.utc) + timedelta(seconds=delay)).isoformat(timespec='seconds')} " f"(now + {timedelta(seconds=delay)})" @@ -240,7 +224,7 @@ async def close(self) -> None: graceful exit. """ logger.debug( - f"['{self._instance_connection_string}']: Canceling connection info " + f"['{self._conn_name}']: Canceling connection info " "refresh operation tasks" ) self._current.cancel() diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py index 9b8cfa24d..cbe313750 100644 --- a/google/cloud/sql/connector/lazy.py +++ b/google/cloud/sql/connector/lazy.py @@ -21,7 +21,7 @@ from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_info import ConnectionInfo -from google.cloud.sql.connector.instance import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import _parse_instance_connection_name from google.cloud.sql.connector.refresh_utils import _refresh_buffer logger = logging.getLogger(name=__name__) @@ -56,10 +56,13 @@ def __init__( connections. """ # validate and parse instance connection name - self._project, self._region, self._instance = _parse_instance_connection_name( - instance_connection_string + conn_name = _parse_instance_connection_name(instance_connection_string) + self._project, self._region, self._instance = ( + conn_name.project, + conn_name.region, + conn_name.instance_name, ) - self._instance_connection_string = instance_connection_string + self._conn_name = conn_name self._enable_iam_auth = enable_iam_auth self._keys = keys @@ -91,12 +94,12 @@ async def connect_info(self) -> ConnectionInfo: < (self._cached.expiration - timedelta(seconds=_refresh_buffer)) ): logger.debug( - f"['{self._instance_connection_string}']: Connection info " + f"['{str(self._conn_name)}']: Connection info " "is still valid, using cached info" ) return self._cached logger.debug( - f"['{self._instance_connection_string}']: Connection info " + f"['{str(self._conn_name)}']: Connection info " "refresh operation started" ) try: @@ -109,16 +112,16 @@ async def connect_info(self) -> ConnectionInfo: ) except Exception as e: logger.debug( - f"['{self._instance_connection_string}']: Connection info " + f"['{str(self._conn_name)}']: Connection info " f"refresh operation failed: {str(e)}" ) raise logger.debug( - f"['{self._instance_connection_string}']: Connection info " + f"['{str(self._conn_name)}']: Connection info " "refresh operation completed successfully" ) logger.debug( - f"['{self._instance_connection_string}']: Current certificate " + f"['{str(self._conn_name)}']: Current certificate " f"expiration = {str(conn_info.expiration)}" ) self._cached = conn_info diff --git a/google/cloud/sql/connector/pg8000.py b/google/cloud/sql/connector/pg8000.py index 623738f85..1f66dde2a 100644 --- a/google/cloud/sql/connector/pg8000.py +++ b/google/cloud/sql/connector/pg8000.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import socket import ssl from typing import Any, TYPE_CHECKING diff --git a/google/cloud/sql/connector/pymysql.py b/google/cloud/sql/connector/pymysql.py index 8971ff9b2..a16584367 100644 --- a/google/cloud/sql/connector/pymysql.py +++ b/google/cloud/sql/connector/pymysql.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import socket import ssl from typing import Any, TYPE_CHECKING diff --git a/google/cloud/sql/connector/pytds.py b/google/cloud/sql/connector/pytds.py index 5c78fd3fc..243d90fd5 100644 --- a/google/cloud/sql/connector/pytds.py +++ b/google/cloud/sql/connector/pytds.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import platform import socket import ssl diff --git a/noxfile.py b/noxfile.py index 528642c1a..8329b2de8 100644 --- a/noxfile.py +++ b/noxfile.py @@ -51,6 +51,7 @@ def lint(session): "--check-only", "--diff", "--profile=google", + "-w=88", *LINT_PATHS, ) session.run("black", "--check", "--diff", *LINT_PATHS) @@ -85,6 +86,7 @@ def format(session): "isort", "--fss", "--profile=google", + "-w=88", *LINT_PATHS, ) session.run( diff --git a/tests/conftest.py b/tests/conftest.py index dd5c3952d..470fe19f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import asyncio import os import socket diff --git a/tests/unit/test_connection_name.py b/tests/unit/test_connection_name.py new file mode 100644 index 000000000..1e3730424 --- /dev/null +++ b/tests/unit/test_connection_name.py @@ -0,0 +1,56 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest # noqa F401 Needed to run the tests + +from google.cloud.sql.connector.connection_name import _parse_instance_connection_name +from google.cloud.sql.connector.connection_name import ConnectionName + + +def test_ConnectionName() -> None: + conn_name = ConnectionName("project", "region", "instance") + # test class attributes are set properly + assert conn_name.project == "project" + assert conn_name.region == "region" + assert conn_name.instance_name == "instance" + # test ConnectionName str() method prints instance connection name + assert str(conn_name) == "project:region:instance" + + +@pytest.mark.parametrize( + "connection_name, expected", + [ + ("project:region:instance", ConnectionName("project", "region", "instance")), + ( + "domain-prefix:project:region:instance", + ConnectionName("domain-prefix:project", "region", "instance"), + ), + ], +) +def test_parse_instance_connection_name( + connection_name: str, expected: ConnectionName +) -> None: + """ + Test that _parse_instance_connection_name works correctly on + normal instance connection names and domain-scoped projects. + """ + assert expected == _parse_instance_connection_name(connection_name) + + +def test_parse_instance_connection_name_bad_conn_name() -> None: + """ + Tests that ValueError is thrown for bad instance connection names. + """ + with pytest.raises(ValueError): + _parse_instance_connection_name("project:instance") # missing region diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 5dcf1f5aa..5b0887aa2 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -16,7 +16,6 @@ import asyncio import datetime -from typing import Tuple from aiohttp import ClientResponseError from aiohttp import RequestInfo @@ -31,7 +30,6 @@ from google.cloud.sql.connector.connection_info import ConnectionInfo from google.cloud.sql.connector.exceptions import AutoIAMAuthNotSupported from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError -from google.cloud.sql.connector.instance import _parse_instance_connection_name from google.cloud.sql.connector.instance import RefreshAheadCache from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter from google.cloud.sql.connector.refresh_utils import _is_valid @@ -43,34 +41,6 @@ def test_rate_limiter() -> AsyncRateLimiter: return AsyncRateLimiter(max_capacity=1, rate=1 / 2) -@pytest.mark.parametrize( - "connection_name, expected", - [ - ("project:region:instance", ("project", "region", "instance")), - ( - "domain-prefix:project:region:instance", - ("domain-prefix:project", "region", "instance"), - ), - ], -) -def test_parse_instance_connection_name( - connection_name: str, expected: Tuple[str, str, str] -) -> None: - """ - Test that _parse_instance_connection_name works correctly on - normal instance connection names and domain-scoped projects. - """ - assert expected == _parse_instance_connection_name(connection_name) - - -def test_parse_instance_connection_name_bad_conn_name() -> None: - """ - Tests that ValueError is thrown for bad instance connection names. - """ - with pytest.raises(ValueError): - _parse_instance_connection_name("project:instance") # missing region - - @pytest.mark.asyncio async def test_Instance_init( cache: RefreshAheadCache, diff --git a/tests/unit/test_pg8000.py b/tests/unit/test_pg8000.py index 26a4bfeab..1b2adbb65 100644 --- a/tests/unit/test_pg8000.py +++ b/tests/unit/test_pg8000.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from functools import partial from typing import Any diff --git a/tests/unit/test_pymysql.py b/tests/unit/test_pymysql.py index 3afefb2a4..69d2aba8f 100644 --- a/tests/unit/test_pymysql.py +++ b/tests/unit/test_pymysql.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from functools import partial import ssl from typing import Any diff --git a/tests/unit/test_pytds.py b/tests/unit/test_pytds.py index 5bfa62419..633aab74a 100644 --- a/tests/unit/test_pytds.py +++ b/tests/unit/test_pytds.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + from functools import partial import platform from typing import Any diff --git a/tests/unit/test_rate_limiter.py b/tests/unit/test_rate_limiter.py index 587e76809..5e187b81d 100644 --- a/tests/unit/test_rate_limiter.py +++ b/tests/unit/test_rate_limiter.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import asyncio import pytest # noqa F401 Needed to run the tests diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index fe190ceba..6545bc7a8 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ + import pytest # noqa F401 Needed to run the tests from google.cloud.sql.connector import utils From e64d9f7e2fdc3fcc7539960b80cbe63f3b443416 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 1 Nov 2024 15:42:06 +0000 Subject: [PATCH 2/3] chore: add missing return type --- google/cloud/sql/connector/connection_name.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/sql/connector/connection_name.py b/google/cloud/sql/connector/connection_name.py index e85225d3c..d240fb565 100644 --- a/google/cloud/sql/connector/connection_name.py +++ b/google/cloud/sql/connector/connection_name.py @@ -32,7 +32,7 @@ class ConnectionName: region: str instance_name: str - def __str__(self): + def __str__(self) -> str: return f"{self.project}:{self.region}:{self.instance_name}" From 1345ac2bbfa2dc61c15f7606b134492dfb54332b Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 1 Nov 2024 15:47:32 +0000 Subject: [PATCH 3/3] chore: no need to use str() --- google/cloud/sql/connector/lazy.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/google/cloud/sql/connector/lazy.py b/google/cloud/sql/connector/lazy.py index cbe313750..672f989e8 100644 --- a/google/cloud/sql/connector/lazy.py +++ b/google/cloud/sql/connector/lazy.py @@ -94,13 +94,12 @@ async def connect_info(self) -> ConnectionInfo: < (self._cached.expiration - timedelta(seconds=_refresh_buffer)) ): logger.debug( - f"['{str(self._conn_name)}']: Connection info " + f"['{self._conn_name}']: Connection info " "is still valid, using cached info" ) return self._cached logger.debug( - f"['{str(self._conn_name)}']: Connection info " - "refresh operation started" + f"['{self._conn_name}']: Connection info " "refresh operation started" ) try: conn_info = await self._client.get_connection_info( @@ -112,16 +111,16 @@ async def connect_info(self) -> ConnectionInfo: ) except Exception as e: logger.debug( - f"['{str(self._conn_name)}']: Connection info " + f"['{self._conn_name}']: Connection info " f"refresh operation failed: {str(e)}" ) raise logger.debug( - f"['{str(self._conn_name)}']: Connection info " + f"['{self._conn_name}']: Connection info " "refresh operation completed successfully" ) logger.debug( - f"['{str(self._conn_name)}']: Current certificate " + f"['{self._conn_name}']: Current certificate " f"expiration = {str(conn_info.expiration)}" ) self._cached = conn_info