38 changes: 38 additions & 0 deletions README.md
@@ -428,6 +428,44 @@ with Connector(resolver=DnsResolver) as connector:
# ... use SQLAlchemy engine normally
```

### Automatic failover using DNS domain names

> [!NOTE]
>
> The `asyncpg` driver does not currently support automatic failover.

When the connector is configured using a domain name, the connector
periodically checks whether the DNS record for the instance has changed. When
the connector detects that the domain name now refers to a different instance,
it closes all open connections to the old instance. Subsequent connection
attempts are directed to the new instance.

For example: suppose an application is configured to connect using the
domain name `prod-db.mycompany.example.com`. Initially the private DNS
zone has a TXT record with the value `my-project:region:my-instance`. The
application establishes connections to the `my-project:region:my-instance`
Cloud SQL instance.

Then, to reconfigure the application to use a different database
instance, change the value of the `prod-db.mycompany.example.com` DNS record
from `my-project:region:my-instance` to `my-project:other-region:my-instance-2`.

The connector inside the application detects the change to this
DNS record. Now, when the application connects to its database using the
domain name `prod-db.mycompany.example.com`, it will connect to the
`my-project:other-region:my-instance-2` Cloud SQL instance.

The connector will automatically close all existing connections to
`my-project:region:my-instance`. This will force the connection pools to
establish new connections. It may also cause in-progress database queries
to fail.

The connector polls for changes to the DNS record every 30 seconds by default.
You may configure the polling frequency using the Connector's
`failover_period` argument (e.g. `Connector(failover_period=60)`). When this is
set to 0, the connector disables polling and only checks whether the DNS record
has changed when creating a new connection.
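
A minimal sketch of this configuration (illustrative only; the `pg8000` driver,
credentials, and domain name are placeholders — adapt them to your setup):

```python
import sqlalchemy

from google.cloud.sql.connector import Connector, DnsResolver

# Resolve the instance through its DNS name and poll the record every
# 60 seconds so a failover is detected without restarting the application.
connector = Connector(resolver=DnsResolver, failover_period=60)

def getconn():
    # Connect using the domain name instead of an instance connection name.
    return connector.connect(
        "prod-db.mycompany.example.com",
        "pg8000",
        user="my-user",
        password="my-password",
        db="my-db",
    )

pool = sqlalchemy.create_engine("postgresql+pg8000://", creator=getconn)

with pool.connect() as db_conn:
    db_conn.execute(sqlalchemy.text("SELECT NOW()"))
```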

### Using the Python Connector with Python Web Frameworks

The Python Connector can be used alongside popular Python web frameworks such
2 changes: 1 addition & 1 deletion google/cloud/sql/connector/__init__.py
@@ -1,4 +1,4 @@
""""
"""
Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
22 changes: 22 additions & 0 deletions google/cloud/sql/connector/connection_info.py
@@ -14,6 +14,7 @@

from __future__ import annotations

import abc
from dataclasses import dataclass
import logging
import ssl
@@ -34,6 +35,27 @@
logger = logging.getLogger(name=__name__)


class ConnectionInfoCache(abc.ABC):
"""Abstract class for Connector connection info caches."""

@abc.abstractmethod
async def connect_info(self) -> ConnectionInfo:
pass

@abc.abstractmethod
async def force_refresh(self) -> None:
pass

@abc.abstractmethod
async def close(self) -> None:
pass

@property
@abc.abstractmethod
def closed(self) -> bool:
pass


@dataclass
class ConnectionInfo:
"""Contains all necessary information to connect securely to the
4 changes: 4 additions & 0 deletions google/cloud/sql/connector/connection_name.py
@@ -42,6 +42,10 @@ def __str__(self) -> str:
return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}"
return f"{self.project}:{self.region}:{self.instance_name}"

def get_connection_string(self) -> str:
"""Get the instance connection string for the Cloud SQL instance."""
return f"{self.project}:{self.region}:{self.instance_name}"

Comment on lines +45 to +48
This allows for a reliable way to get the instance connection name for a Cloud SQL instance whether the connector is connecting via domain name or instance connection name.
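
As an illustration of that distinction (a sketch only; it assumes `ConnectionName`
can be constructed from the `project`, `region`, `instance_name`, and `domain_name`
fields referenced in this diff):

```python
from google.cloud.sql.connector.connection_name import ConnectionName

# A ConnectionName resolved from a DNS record carries the domain name, so
# str() includes it, while get_connection_string() always returns the bare
# instance connection name.
conn_name = ConnectionName(
    project="my-project",
    region="my-region",
    instance_name="my-instance",
    domain_name="prod-db.mycompany.example.com",
)
print(str(conn_name))
# prod-db.mycompany.example.com -> my-project:my-region:my-instance
print(conn_name.get_connection_string())
# my-project:my-region:my-instance
```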


def _is_valid_domain(domain_name: str) -> bool:
if DOMAIN_NAME_REGEX.fullmatch(domain_name) is None:
65 changes: 50 additions & 15 deletions google/cloud/sql/connector/connector.py
@@ -20,9 +20,10 @@
from functools import partial
import logging
import os
import socket
from threading import Thread
from types import TracebackType
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import google.auth
from google.auth.credentials import Credentials
@@ -35,6 +36,7 @@
from google.cloud.sql.connector.enums import RefreshStrategy
from google.cloud.sql.connector.instance import RefreshAheadCache
from google.cloud.sql.connector.lazy import LazyRefreshCache
from google.cloud.sql.connector.monitored_cache import MonitoredCache
import google.cloud.sql.connector.pg8000 as pg8000
import google.cloud.sql.connector.pymysql as pymysql
import google.cloud.sql.connector.pytds as pytds
@@ -46,6 +48,7 @@
logger = logging.getLogger(name=__name__)

ASYNC_DRIVERS = ["asyncpg"]
SERVER_PROXY_PORT = 3307
_DEFAULT_SCHEME = "https://"
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"
@@ -67,6 +70,7 @@ def __init__(
universe_domain: Optional[str] = None,
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver,
failover_period: int = 30,
) -> None:
"""Initializes a Connector instance.

@@ -114,6 +118,11 @@ def __init__(
name. To resolve a DNS record to an instance connection name, use
DnsResolver.
Default: DefaultResolver

failover_period (int): The time interval in seconds between each
attempt to check if a failover has occurred for a given instance.
Must be used with `resolver=DnsResolver` to have any effect.
Default: 30
"""
# if refresh_strategy is str, convert to RefreshStrategy enum
if isinstance(refresh_strategy, str):
@@ -143,9 +152,7 @@
)
# initialize dict to store caches, key is a tuple consisting of instance
# connection name string and enable_iam_auth boolean flag
self._cache: dict[
tuple[str, bool], Union[RefreshAheadCache, LazyRefreshCache]
] = {}
self._cache: dict[tuple[str, bool], MonitoredCache] = {}
self._client: Optional[CloudSQLClient] = None

# initialize credentials
@@ -167,6 +174,7 @@
self._enable_iam_auth = enable_iam_auth
self._user_agent = user_agent
self._resolver = resolver()
self._failover_period = failover_period
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes._from_str(ip_type)
@@ -285,15 +293,19 @@ async def connect_async(
driver=driver,
)
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
if (instance_connection_string, enable_iam_auth) in self._cache:
cache = self._cache[(instance_connection_string, enable_iam_auth)]

conn_name = await self._resolver.resolve(instance_connection_string)
# Cache entry must exist and not be closed
if (str(conn_name), enable_iam_auth) in self._cache and not self._cache[
(str(conn_name), enable_iam_auth)
].closed:
monitored_cache = self._cache[(str(conn_name), enable_iam_auth)]
Comment:

Update the cache key to be `str(conn_name)`, which results in the cache key being
either `domain name -> instance connection name` or just the instance connection name.

else:
conn_name = await self._resolver.resolve(instance_connection_string)
if self._refresh_strategy == RefreshStrategy.LAZY:
logger.debug(
f"['{conn_name}']: Refresh strategy is set to lazy refresh"
)
cache = LazyRefreshCache(
cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache(
conn_name,
self._client,
self._keys,
@@ -309,8 +321,14 @@
self._keys,
enable_iam_auth,
)
# wrap cache as a MonitoredCache
monitored_cache = MonitoredCache(
cache,
self._failover_period,
self._resolver,
)
logger.debug(f"['{conn_name}']: Connection info added to cache")
self._cache[(instance_connection_string, enable_iam_auth)] = cache
self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache

connect_func = {
"pymysql": pymysql.connect,
@@ -321,7 +339,7 @@

# only accept supported database drivers
try:
connector = connect_func[driver]
connector: Callable = connect_func[driver] # type: ignore
except KeyError:
raise KeyError(f"Driver '{driver}' is not supported.")

@@ -339,14 +357,14 @@

# attempt to get connection info for Cloud SQL instance
try:
conn_info = await cache.connect_info()
conn_info = await monitored_cache.connect_info()
# validate driver matches intended database engine
DriverMapping.validate_engine(driver, conn_info.database_version)
ip_address = conn_info.get_preferred_ip(ip_type)
except Exception:
# with an error from Cloud SQL Admin API call or IP type, invalidate
# the cache and re-raise the error
await self._remove_cached(instance_connection_string, enable_iam_auth)
await self._remove_cached(str(conn_name), enable_iam_auth)
raise
logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307")
# format `user` param for automatic IAM database authn
@@ -367,18 +385,28 @@
await conn_info.create_ssl_context(enable_iam_auth),
**kwargs,
)
# synchronous drivers are blocking and run using executor
# Create socket with SSLContext for sync drivers
ctx = await conn_info.create_ssl_context(enable_iam_auth)
sock = ctx.wrap_socket(
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
server_hostname=ip_address,
)
# If this connection was opened using a domain name, then store it
# for later in case we need to forcibly close it on failover.
if conn_info.conn_name.domain_name:
monitored_cache.sockets.append(sock)
# Synchronous drivers are blocking and run using executor
connect_partial = partial(
connector,
ip_address,
await conn_info.create_ssl_context(enable_iam_auth),
sock,
**kwargs,
)
return await self._loop.run_in_executor(None, connect_partial)

except Exception:
# with any exception, we attempt a force refresh, then throw the error
await cache.force_refresh()
await monitored_cache.force_refresh()
raise

async def _remove_cached(
Expand Down Expand Up @@ -456,6 +484,7 @@ async def create_async_connector(
universe_domain: Optional[str] = None,
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver,
failover_period: int = 30,
) -> Connector:
"""Helper function to create Connector object for asyncio connections.

Expand Down Expand Up @@ -507,6 +536,11 @@ async def create_async_connector(
DnsResolver.
Default: DefaultResolver

failover_period (int): The time interval in seconds between each
attempt to check if a failover has occurred for a given instance.
Must be used with `resolver=DnsResolver` to have any effect.
Default: 30

Returns:
A Connector instance configured with running event loop.
"""
@@ -525,4 +559,5 @@
universe_domain=universe_domain,
refresh_strategy=refresh_strategy,
resolver=resolver,
failover_period=failover_period,
)
7 changes: 7 additions & 0 deletions google/cloud/sql/connector/exceptions.py
@@ -77,3 +77,10 @@ class DnsResolutionError(Exception):
Exception to be raised when an instance connection name can not be resolved
from a DNS record.
"""


class CacheClosedError(Exception):
"""
Exception to be raised when a ConnectionInfoCache can not be accessed after
it is closed.
"""
13 changes: 12 additions & 1 deletion google/cloud/sql/connector/instance.py
@@ -24,6 +24,7 @@

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_info import ConnectionInfoCache
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import RefreshNotValidError
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
@@ -35,7 +36,7 @@
APPLICATION_NAME = "cloud-sql-python-connector"


class RefreshAheadCache:
class RefreshAheadCache(ConnectionInfoCache):
"""Cache that refreshes connection info in the background prior to expiration.

Background tasks are used to schedule refresh attempts to get a new
@@ -74,6 +75,15 @@ def __init__(
self._refresh_in_progress = asyncio.locks.Event()
self._current: asyncio.Task = self._schedule_refresh(0)
self._next: asyncio.Task = self._current
self._closed = False

@property
def conn_name(self) -> ConnectionName:
return self._conn_name

@property
def closed(self) -> bool:
return self._closed

async def force_refresh(self) -> None:
"""
@@ -212,3 +222,4 @@ async def close(self) -> None:
# gracefully wait for tasks to cancel
tasks = asyncio.gather(self._current, self._next, return_exceptions=True)
await asyncio.wait_for(tasks, timeout=2.0)
self._closed = True
15 changes: 13 additions & 2 deletions google/cloud/sql/connector/lazy.py
@@ -21,13 +21,14 @@

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_info import ConnectionInfoCache
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.refresh_utils import _refresh_buffer

logger = logging.getLogger(name=__name__)


class LazyRefreshCache:
class LazyRefreshCache(ConnectionInfoCache):
"""Cache that refreshes connection info when a caller requests a connection.

Only refreshes the cache when a new connection is requested and the current
@@ -62,6 +63,15 @@ def __init__(
self._lock = asyncio.Lock()
self._cached: Optional[ConnectionInfo] = None
self._needs_refresh = False
self._closed = False

@property
def conn_name(self) -> ConnectionName:
return self._conn_name

@property
def closed(self) -> bool:
return self._closed

async def force_refresh(self) -> None:
"""
@@ -121,4 +131,5 @@ async def close(self) -> None:
"""Close is a no-op and provided purely for a consistent interface with
other cache types.
"""
pass
self._closed = True
return