Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions google/cloud/sql/connector/connection_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
import re

# Instance connection name is the format <PROJECT>:<REGION>:<INSTANCE_NAME>
# Additionally, we have to support legacy "domain-scoped" projects
# (e.g. "google.com:PROJECT")
CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)"))


@dataclass
class ConnectionName:
"""ConnectionName represents a Cloud SQL instance's "instance connection name".

Takes the format "<PROJECT>:<REGION>:<INSTANCE_NAME>".
"""

project: str
region: str
instance_name: str

def __str__(self):
return f"{self.project}:{self.region}:{self.instance_name}"


def _parse_instance_connection_name(connection_name: str) -> ConnectionName:
if CONN_NAME_REGEX.fullmatch(connection_name) is None:
raise ValueError(
"Arg `instance_connection_string` must have "
"format: PROJECT:REGION:INSTANCE, "
f"got {connection_name}."
)
connection_name_split = CONN_NAME_REGEX.split(connection_name)
return ConnectionName(
connection_name_split[1],
connection_name_split[3],
connection_name_split[4],
)
50 changes: 17 additions & 33 deletions google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
from datetime import timedelta
from datetime import timezone
import logging
import re

import aiohttp

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
from google.cloud.sql.connector.exceptions import RefreshNotValidError
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
from google.cloud.sql.connector.refresh_utils import _is_valid
Expand All @@ -36,22 +36,6 @@

APPLICATION_NAME = "cloud-sql-python-connector"

# Instance connection name is the format <PROJECT>:<REGION>:<INSTANCE>
# Additionally, we have to support legacy "domain-scoped" projects
# (e.g. "google.com:PROJECT")
CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)"))


def _parse_instance_connection_name(connection_name: str) -> tuple[str, str, str]:
if CONN_NAME_REGEX.fullmatch(connection_name) is None:
raise ValueError(
"Arg `instance_connection_string` must have "
"format: PROJECT:REGION:INSTANCE, "
f"got {connection_name}."
)
connection_name_split = CONN_NAME_REGEX.split(connection_name)
return connection_name_split[1], connection_name_split[3], connection_name_split[4]


class RefreshAheadCache:
"""Cache that refreshes connection info in the background prior to expiration.
Expand Down Expand Up @@ -81,10 +65,13 @@ def __init__(
connections.
"""
# validate and parse instance connection name
self._project, self._region, self._instance = _parse_instance_connection_name(
instance_connection_string
conn_name = _parse_instance_connection_name(instance_connection_string)
self._project, self._region, self._instance = (
conn_name.project,
conn_name.region,
conn_name.instance_name,
)
self._instance_connection_string = instance_connection_string
self._conn_name = conn_name

self._enable_iam_auth = enable_iam_auth
self._keys = keys
Expand Down Expand Up @@ -121,8 +108,7 @@ async def _perform_refresh(self) -> ConnectionInfo:
"""
self._refresh_in_progress.set()
logger.debug(
f"['{self._instance_connection_string}']: Connection info refresh "
"operation started"
f"['{self._conn_name}']: Connection info refresh " "operation started"
)

try:
Expand All @@ -135,17 +121,16 @@ async def _perform_refresh(self) -> ConnectionInfo:
self._enable_iam_auth,
)
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
"refresh operation complete"
f"['{self._conn_name}']: Connection info " "refresh operation complete"
)
logger.debug(
f"['{self._instance_connection_string}']: Current certificate "
f"['{self._conn_name}']: Current certificate "
f"expiration = {connection_info.expiration.isoformat()}"
)

except aiohttp.ClientResponseError as e:
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"['{self._conn_name}']: Connection info "
f"refresh operation failed: {str(e)}"
)
if e.status == 403:
Expand All @@ -154,7 +139,7 @@ async def _perform_refresh(self) -> ConnectionInfo:

except Exception as e:
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"['{self._conn_name}']: Connection info "
f"refresh operation failed: {str(e)}"
)
raise
Expand Down Expand Up @@ -188,18 +173,17 @@ async def _refresh_task(self: RefreshAheadCache, delay: int) -> ConnectionInfo:
# check that refresh is valid
if not await _is_valid(refresh_task):
raise RefreshNotValidError(
f"['{self._instance_connection_string}']: Invalid refresh operation. Certficate appears to be expired."
f"['{self._conn_name}']: Invalid refresh operation. Certficate appears to be expired."
)
except asyncio.CancelledError:
logger.debug(
f"['{self._instance_connection_string}']: Scheduled refresh"
" operation cancelled"
f"['{self._conn_name}']: Scheduled refresh" " operation cancelled"
)
raise
# bad refresh attempt
except Exception as e:
logger.exception(
f"['{self._instance_connection_string}']: "
f"['{self._conn_name}']: "
"An error occurred while performing refresh. "
"Scheduling another refresh attempt immediately",
exc_info=e,
Expand All @@ -216,7 +200,7 @@ async def _refresh_task(self: RefreshAheadCache, delay: int) -> ConnectionInfo:
# calculate refresh delay based on certificate expiration
delay = _seconds_until_refresh(refresh_data.expiration)
logger.debug(
f"['{self._instance_connection_string}']: Connection info refresh"
f"['{self._conn_name}']: Connection info refresh"
" operation scheduled for "
f"{(datetime.now(timezone.utc) + timedelta(seconds=delay)).isoformat(timespec='seconds')} "
f"(now + {timedelta(seconds=delay)})"
Expand All @@ -240,7 +224,7 @@ async def close(self) -> None:
graceful exit.
"""
logger.debug(
f"['{self._instance_connection_string}']: Canceling connection info "
f"['{self._conn_name}']: Canceling connection info "
"refresh operation tasks"
)
self._current.cancel()
Expand Down
21 changes: 12 additions & 9 deletions google/cloud/sql/connector/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.instance import _parse_instance_connection_name
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
from google.cloud.sql.connector.refresh_utils import _refresh_buffer

logger = logging.getLogger(name=__name__)
Expand Down Expand Up @@ -56,10 +56,13 @@ def __init__(
connections.
"""
# validate and parse instance connection name
self._project, self._region, self._instance = _parse_instance_connection_name(
instance_connection_string
conn_name = _parse_instance_connection_name(instance_connection_string)
self._project, self._region, self._instance = (
conn_name.project,
conn_name.region,
conn_name.instance_name,
)
self._instance_connection_string = instance_connection_string
self._conn_name = conn_name

self._enable_iam_auth = enable_iam_auth
self._keys = keys
Expand Down Expand Up @@ -91,12 +94,12 @@ async def connect_info(self) -> ConnectionInfo:
< (self._cached.expiration - timedelta(seconds=_refresh_buffer))
):
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"['{str(self._conn_name)}']: Connection info "
"is still valid, using cached info"
)
return self._cached
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"['{str(self._conn_name)}']: Connection info "
"refresh operation started"
)
try:
Expand All @@ -109,16 +112,16 @@ async def connect_info(self) -> ConnectionInfo:
)
except Exception as e:
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"['{str(self._conn_name)}']: Connection info "
f"refresh operation failed: {str(e)}"
)
raise
logger.debug(
f"['{self._instance_connection_string}']: Connection info "
f"['{str(self._conn_name)}']: Connection info "
"refresh operation completed successfully"
)
logger.debug(
f"['{self._instance_connection_string}']: Current certificate "
f"['{str(self._conn_name)}']: Current certificate "
f"expiration = {str(conn_info.expiration)}"
)
self._cached = conn_info
Expand Down
1 change: 1 addition & 0 deletions google/cloud/sql/connector/pg8000.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import socket
import ssl
from typing import Any, TYPE_CHECKING
Expand Down
1 change: 1 addition & 0 deletions google/cloud/sql/connector/pymysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import socket
import ssl
from typing import Any, TYPE_CHECKING
Expand Down
1 change: 1 addition & 0 deletions google/cloud/sql/connector/pytds.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import platform
import socket
import ssl
Expand Down
2 changes: 2 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def lint(session):
"--check-only",
"--diff",
"--profile=google",
"-w=88",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

black formatter uses line-length of 88 while isort was using 85, setting isort to 88 as well.

Found this as the new import for _parse_instance_connection_name was 87 characters and causing isort and black to argue with one another.

*LINT_PATHS,
)
session.run("black", "--check", "--diff", *LINT_PATHS)
Expand Down Expand Up @@ -85,6 +86,7 @@ def format(session):
"isort",
"--fss",
"--profile=google",
"-w=88",
*LINT_PATHS,
)
session.run(
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import os
import socket
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/test_connection_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest # noqa F401 Needed to run the tests

from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
from google.cloud.sql.connector.connection_name import ConnectionName


def test_ConnectionName() -> None:
conn_name = ConnectionName("project", "region", "instance")
# test class attributes are set properly
assert conn_name.project == "project"
assert conn_name.region == "region"
assert conn_name.instance_name == "instance"
# test ConnectionName str() method prints instance connection name
assert str(conn_name) == "project:region:instance"


@pytest.mark.parametrize(
"connection_name, expected",
[
("project:region:instance", ConnectionName("project", "region", "instance")),
(
"domain-prefix:project:region:instance",
ConnectionName("domain-prefix:project", "region", "instance"),
),
],
)
def test_parse_instance_connection_name(
connection_name: str, expected: ConnectionName
) -> None:
"""
Test that _parse_instance_connection_name works correctly on
normal instance connection names and domain-scoped projects.
"""
assert expected == _parse_instance_connection_name(connection_name)


def test_parse_instance_connection_name_bad_conn_name() -> None:
"""
Tests that ValueError is thrown for bad instance connection names.
"""
with pytest.raises(ValueError):
_parse_instance_connection_name("project:instance") # missing region
Loading
Loading