Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ jobs:
POSTGRES_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CAS_PASS
POSTGRES_CUSTOMER_CAS_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_CONNECTION_NAME
POSTGRES_CUSTOMER_CAS_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_PASS
POSTGRES_CUSTOMER_CAS_DOMAIN_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/POSTGRES_CUSTOMER_CAS_DOMAIN_NAME
SQLSERVER_CONNECTION_NAME:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_CONNECTION_NAME
SQLSERVER_USER:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_USER
SQLSERVER_PASS:${{ vars.GOOGLE_CLOUD_PROJECT }}/SQLSERVER_PASS
Expand All @@ -102,6 +103,7 @@ jobs:
POSTGRES_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CAS_PASS }}"
POSTGRES_CUSTOMER_CAS_CONNECTION_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_CONNECTION_NAME }}"
POSTGRES_CUSTOMER_CAS_PASS: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_PASS }}"
POSTGRES_CUSTOMER_CAS_DOMAIN_NAME: "${{ steps.secrets.outputs.POSTGRES_CUSTOMER_CAS_DOMAIN_NAME }}"
SQLSERVER_CONNECTION_NAME: "${{ steps.secrets.outputs.SQLSERVER_CONNECTION_NAME }}"
SQLSERVER_USER: "${{ steps.secrets.outputs.SQLSERVER_USER }}"
SQLSERVER_PASS: "${{ steps.secrets.outputs.SQLSERVER_PASS }}"
Expand Down
22 changes: 22 additions & 0 deletions google/cloud/sql/connector/connection_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import abc
from dataclasses import dataclass
import logging
import ssl
Expand All @@ -34,6 +35,27 @@
logger = logging.getLogger(name=__name__)


class ConnectionInfoCache(abc.ABC):
    """Abstract class for Connector connection info caches."""

    @abc.abstractmethod
    async def connect_info(self) -> ConnectionInfo:
        """Retrieve the connection info for the Cloud SQL instance.

        Returns:
            ConnectionInfo: The info needed to connect to the instance.
        """
        pass

    @abc.abstractmethod
    async def force_refresh(self) -> None:
        """Invalidate the current cached info and trigger a refresh."""
        pass

    @abc.abstractmethod
    async def close(self) -> None:
        """Shut down the cache and release any background resources."""
        pass

    @property
    @abc.abstractmethod
    def closed(self) -> bool:
        """Whether the cache has been closed."""
        pass


@dataclass
class ConnectionInfo:
"""Contains all necessary information to connect securely to the
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/sql/connector/connection_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def __str__(self) -> str:
return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}"
return f"{self.project}:{self.region}:{self.instance_name}"

def get_connection_string(self) -> str:
    """Get the instance connection string for the Cloud SQL instance.

    Returns the canonical "project:region:instance" form regardless of
    whether the instance was specified by connection name or domain name.
    """
    return ":".join((self.project, self.region, self.instance_name))

Comment on lines +45 to +48
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows for a reliable way to get the instance connection name for a Cloud SQL instance whether the connector is connecting via domain name or instance connection name.


def _parse_connection_name(connection_name: str) -> ConnectionName:
return _parse_connection_name_with_domain_name(connection_name, "")
Expand Down
26 changes: 19 additions & 7 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from google.cloud.sql.connector.enums import RefreshStrategy
from google.cloud.sql.connector.instance import RefreshAheadCache
from google.cloud.sql.connector.lazy import LazyRefreshCache
from google.cloud.sql.connector.monitored_cache import MonitoredCache
import google.cloud.sql.connector.pg8000 as pg8000
import google.cloud.sql.connector.pymysql as pymysql
import google.cloud.sql.connector.pytds as pytds
Expand Down Expand Up @@ -67,6 +68,7 @@ def __init__(
universe_domain: Optional[str] = None,
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver,
failover_period: int = 30,
) -> None:
"""Initializes a Connector instance.

Expand Down Expand Up @@ -114,6 +116,11 @@ def __init__(
name. To resolve a DNS record to an instance connection name, use
DnsResolver.
Default: DefaultResolver

failover_period (int): The time interval in seconds between each
attempt to check if a failover has occurred for a given instance.
Must be used with `resolver=DnsResolver` to have any effect.
Default: 30
"""
# if refresh_strategy is str, convert to RefreshStrategy enum
if isinstance(refresh_strategy, str):
Expand Down Expand Up @@ -143,9 +150,7 @@ def __init__(
)
# initialize dict to store caches, key is a tuple consisting of instance
# connection name string and enable_iam_auth boolean flag
self._cache: dict[
tuple[str, bool], Union[RefreshAheadCache, LazyRefreshCache]
] = {}
self._cache: dict[tuple[str, bool], MonitoredCache] = {}
self._client: Optional[CloudSQLClient] = None

# initialize credentials
Expand All @@ -167,6 +172,7 @@ def __init__(
self._enable_iam_auth = enable_iam_auth
self._user_agent = user_agent
self._resolver = resolver()
self._failover_period = failover_period
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes._from_str(ip_type)
Expand Down Expand Up @@ -286,14 +292,14 @@ async def connect_async(
)
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
if (instance_connection_string, enable_iam_auth) in self._cache:
cache = self._cache[(instance_connection_string, enable_iam_auth)]
monitored_cache = self._cache[(instance_connection_string, enable_iam_auth)]
else:
conn_name = await self._resolver.resolve(instance_connection_string)
if self._refresh_strategy == RefreshStrategy.LAZY:
logger.debug(
f"['{conn_name}']: Refresh strategy is set to lazy refresh"
)
cache = LazyRefreshCache(
cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache(
conn_name,
self._client,
self._keys,
Expand All @@ -309,8 +315,14 @@ async def connect_async(
self._keys,
enable_iam_auth,
)
# wrap cache as a MonitoredCache
monitored_cache = MonitoredCache(
cache,
self._failover_period,
self._resolver,
)
logger.debug(f"['{conn_name}']: Connection info added to cache")
self._cache[(instance_connection_string, enable_iam_auth)] = cache
self._cache[(instance_connection_string, enable_iam_auth)] = monitored_cache

connect_func = {
"pymysql": pymysql.connect,
Expand Down Expand Up @@ -339,7 +351,7 @@ async def connect_async(

# attempt to get connection info for Cloud SQL instance
try:
conn_info = await cache.connect_info()
conn_info = await monitored_cache.connect_info()
# validate driver matches intended database engine
DriverMapping.validate_engine(driver, conn_info.database_version)
ip_address = conn_info.get_preferred_ip(ip_type)
Expand Down
13 changes: 12 additions & 1 deletion google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_info import ConnectionInfoCache
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import RefreshNotValidError
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
Expand All @@ -35,7 +36,7 @@
APPLICATION_NAME = "cloud-sql-python-connector"


class RefreshAheadCache:
class RefreshAheadCache(ConnectionInfoCache):
"""Cache that refreshes connection info in the background prior to expiration.

Background tasks are used to schedule refresh attempts to get a new
Expand Down Expand Up @@ -74,6 +75,15 @@ def __init__(
self._refresh_in_progress = asyncio.locks.Event()
self._current: asyncio.Task = self._schedule_refresh(0)
self._next: asyncio.Task = self._current
self._closed = False

@property
def conn_name(self) -> ConnectionName:
    """The ConnectionName of the Cloud SQL instance for this cache."""
    return self._conn_name

@property
def closed(self) -> bool:
    """Whether this cache has been closed via close()."""
    return self._closed

async def force_refresh(self) -> None:
"""
Expand Down Expand Up @@ -212,3 +222,4 @@ async def close(self) -> None:
# gracefully wait for tasks to cancel
tasks = asyncio.gather(self._current, self._next, return_exceptions=True)
await asyncio.wait_for(tasks, timeout=2.0)
self._closed = True
15 changes: 13 additions & 2 deletions google/cloud/sql/connector/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_info import ConnectionInfoCache
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.refresh_utils import _refresh_buffer

logger = logging.getLogger(name=__name__)


class LazyRefreshCache:
class LazyRefreshCache(ConnectionInfoCache):
"""Cache that refreshes connection info when a caller requests a connection.

Only refreshes the cache when a new connection is requested and the current
Expand Down Expand Up @@ -62,6 +63,15 @@ def __init__(
self._lock = asyncio.Lock()
self._cached: Optional[ConnectionInfo] = None
self._needs_refresh = False
self._closed = False

@property
def conn_name(self) -> ConnectionName:
    """The ConnectionName of the Cloud SQL instance for this cache."""
    return self._conn_name

@property
def closed(self) -> bool:
    """Whether this cache has been closed via close()."""
    return self._closed

async def force_refresh(self) -> None:
"""
Expand Down Expand Up @@ -121,4 +131,5 @@ async def close(self) -> None:
"""Close is a no-op and provided purely for a consistent interface with
other cache types.
"""
pass
self._closed = True
return
109 changes: 109 additions & 0 deletions google/cloud/sql/connector/monitored_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
from typing import Any, Callable, Optional, Union

from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_info import ConnectionInfoCache
from google.cloud.sql.connector.instance import RefreshAheadCache
from google.cloud.sql.connector.lazy import LazyRefreshCache
from google.cloud.sql.connector.resolver import DefaultResolver
from google.cloud.sql.connector.resolver import DnsResolver

logger = logging.getLogger(name=__name__)


class MonitoredCache(ConnectionInfoCache):
    """Delegating cache wrapper that monitors a domain name for failover.

    Wraps a RefreshAheadCache or LazyRefreshCache. When the instance was
    configured via a domain name, the wrapper polls the resolver on a fixed
    interval; if the domain name starts resolving to a different Cloud SQL
    instance, the cache is closed so stale connection info is not reused.
    """

    def __init__(
        self,
        cache: Union[RefreshAheadCache, LazyRefreshCache],
        failover_period: int,
        resolver: Union[DefaultResolver, DnsResolver],
    ) -> None:
        """Initializes a MonitoredCache instance.

        Args:
            cache: The underlying connection info cache to delegate to.
            failover_period: Seconds between domain name checks.
            resolver: Resolver used to re-resolve the instance's domain name.
        """
        self.resolver = resolver
        self.cache = cache
        self.domain_name_ticker: Optional[asyncio.Task] = None
        self.open_conns_count: int = 0

        # Polling only applies when the instance was specified via a
        # domain name; plain connection names have nothing to re-resolve.
        if self.cache.conn_name.domain_name:
            self.domain_name_ticker = asyncio.create_task(
                ticker(failover_period, self._check_domain_name)
            )
            logger.debug(
                f"['{self.cache.conn_name}']: Configured polling of domain "
                f"name with failover period of {failover_period} seconds."
            )

    @property
    def closed(self) -> bool:
        """Whether the underlying cache has been closed."""
        return self.cache.closed

    async def _check_domain_name(self) -> None:
        """Re-resolve the domain name and close the cache if it changed."""
        try:
            # Resolve domain name and see if Cloud SQL instance connection name
            # has changed. If it has, close all connections.
            resolved = await self.resolver.resolve(
                self.cache.conn_name.domain_name
            )
            if resolved == self.cache.conn_name:
                return
            logger.debug(
                f"['{self.cache.conn_name}']: Cloud SQL instance changed "
                f"from {self.cache.conn_name.get_connection_string()} to "
                f"{resolved.get_connection_string()}, closing all "
                "connections!"
            )
            await self.close()
        except Exception as e:
            # Domain name checks should not be fatal, log error and continue.
            logger.debug(
                f"['{self.cache.conn_name}']: Unable to check domain name, "
                f"domain name {self.cache.conn_name.domain_name} did not "
                f"resolve: {e}"
            )

    async def connect_info(self) -> ConnectionInfo:
        """Delegates to the wrapped cache's connect_info."""
        return await self.cache.connect_info()

    async def force_refresh(self) -> None:
        """Delegates to the wrapped cache's force_refresh."""
        return await self.cache.force_refresh()

    async def close(self) -> None:
        """Stops domain name polling and closes the wrapped cache."""
        # Cancel domain name ticker task.
        if self.domain_name_ticker is not None:
            self.domain_name_ticker.cancel()
            try:
                await self.domain_name_ticker
            except asyncio.CancelledError:
                logger.debug(
                    f"['{self.cache.conn_name}']: Cancelled domain name polling task."
                )
        # If cache is already closed, no further work.
        if not self.cache.closed:
            await self.cache.close()


async def ticker(interval: int, function: Callable, *args: Any, **kwargs: Any) -> None:
    """
    Ticker function to sleep for specified interval and then schedule call
    to given function.

    Args:
        interval (int): Seconds to sleep between scheduling calls.
        function (Callable): Coroutine function to schedule on each tick.
        *args: Positional arguments passed through to `function`.
        **kwargs: Keyword arguments passed through to `function`.
    """
    # The event loop keeps only weak references to tasks, so a
    # fire-and-forget create_task() result can be garbage collected before
    # the scheduled call completes. Hold strong references until each
    # scheduled task is done.
    background_tasks: set = set()
    while True:
        # Sleep for interval and then schedule task
        await asyncio.sleep(interval)
        task = asyncio.create_task(function(*args, **kwargs))
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)
30 changes: 29 additions & 1 deletion tests/system/test_pg8000_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
import os

# [START cloud_sql_connector_postgres_pg8000]
from typing import Union

import pg8000
import sqlalchemy

from google.cloud.sql.connector import Connector
from google.cloud.sql.connector import DefaultResolver
from google.cloud.sql.connector import DnsResolver


def create_sqlalchemy_engine(
Expand All @@ -30,6 +34,7 @@ def create_sqlalchemy_engine(
password: str,
db: str,
refresh_strategy: str = "background",
resolver: Union[DefaultResolver, DnsResolver] = DefaultResolver,
) -> tuple[sqlalchemy.engine.Engine, Connector]:
"""Creates a connection pool for a Cloud SQL instance and returns the pool
and the connector. Callers are responsible for closing the pool and the
Expand Down Expand Up @@ -64,8 +69,13 @@ def create_sqlalchemy_engine(
Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
or "background". For serverless environments use "lazy" to avoid
errors resulting from CPU being throttled.
resolver (Optional[google.cloud.sql.connector.DefaultResolver | google.cloud.sql.connector.DnsResolver])
Resolver class for the Cloud SQL Connector. Can be one of
DefaultResolver (default) or DnsResolver. The resolver tells the
connector whether to resolve the 'instance_connection_name' as a
Cloud SQL instance connection name or as a domain name.
"""
connector = Connector(refresh_strategy=refresh_strategy)
connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver)

def getconn() -> pg8000.dbapi.Connection:
conn: pg8000.dbapi.Connection = connector.connect(
Expand Down Expand Up @@ -153,3 +163,21 @@ def test_customer_managed_CAS_pg8000_connection() -> None:
curr_time = time[0]
assert type(curr_time) is datetime
connector.close()


def test_domain_name_pg8000_connection() -> None:
    """Basic test to get time from database using domain name to connect."""
    # Connection parameters are supplied by the test environment.
    domain_name = os.environ["POSTGRES_CUSTOMER_CAS_DOMAIN_NAME"]
    db_user = os.environ["POSTGRES_USER"]
    db_password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"]
    db_name = os.environ["POSTGRES_DB"]

    # DnsResolver tells the connector to treat the first argument as a
    # domain name rather than an instance connection name.
    engine, connector = create_sqlalchemy_engine(
        domain_name, db_user, db_password, db_name, "lazy", DnsResolver
    )
    with engine.connect() as conn:
        row = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()
        conn.commit()
        curr_time = row[0]
        assert type(curr_time) is datetime
    connector.close()
Loading
Loading