Merged
Changes from 12 commits
2 changes: 1 addition & 1 deletion google/cloud/sql/connector/__init__.py
@@ -1,4 +1,4 @@
""""
"""
Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
22 changes: 22 additions & 0 deletions google/cloud/sql/connector/connection_info.py
@@ -14,6 +14,7 @@

from __future__ import annotations

import abc
from dataclasses import dataclass
import logging
import ssl
@@ -34,6 +35,27 @@
logger = logging.getLogger(name=__name__)


class ConnectionInfoCache(abc.ABC):
"""Abstract class for Connector connection info caches."""

@abc.abstractmethod
async def connect_info(self) -> ConnectionInfo:
pass

@abc.abstractmethod
async def force_refresh(self) -> None:
pass

@abc.abstractmethod
async def close(self) -> None:
pass

@property
@abc.abstractmethod
def closed(self) -> bool:
pass


@dataclass
class ConnectionInfo:
"""Contains all necessary information to connect securely to the
4 changes: 4 additions & 0 deletions google/cloud/sql/connector/connection_name.py
@@ -42,6 +42,10 @@ def __str__(self) -> str:
return f"{self.domain_name} -> {self.project}:{self.region}:{self.instance_name}"
return f"{self.project}:{self.region}:{self.instance_name}"

def get_connection_string(self) -> str:
"""Get the instance connection string for the Cloud SQL instance."""
return f"{self.project}:{self.region}:{self.instance_name}"

Comment on lines +45 to +48
Collaborator Author

This allows for a reliable way to get the instance connection name for a Cloud SQL instance, whether the connector is connecting via a domain name or an instance connection name.
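For illustration only, a minimal sketch of the difference between the two representations; the ConnectionName constructor used here (field order and the domain_name keyword) is an assumption, not taken from this diff:

```python
from google.cloud.sql.connector.connection_name import ConnectionName

# Hypothetical values; the constructor signature is assumed for illustration.
conn_name = ConnectionName(
    "my-project", "my-region", "my-instance", domain_name="db.example.com"
)

# __str__ includes the domain mapping when a domain name was used.
print(str(conn_name))
# db.example.com -> my-project:my-region:my-instance

# get_connection_string() always returns only the instance connection name.
print(conn_name.get_connection_string())
# my-project:my-region:my-instance
```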


def _is_valid_domain(domain_name: str) -> bool:
if DOMAIN_NAME_REGEX.fullmatch(domain_name) is None:
62 changes: 47 additions & 15 deletions google/cloud/sql/connector/connector.py
@@ -20,9 +20,10 @@
from functools import partial
import logging
import os
import socket
from threading import Thread
from types import TracebackType
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import google.auth
from google.auth.credentials import Credentials
@@ -35,6 +36,7 @@
from google.cloud.sql.connector.enums import RefreshStrategy
from google.cloud.sql.connector.instance import RefreshAheadCache
from google.cloud.sql.connector.lazy import LazyRefreshCache
from google.cloud.sql.connector.monitored_cache import MonitoredCache
import google.cloud.sql.connector.pg8000 as pg8000
import google.cloud.sql.connector.pymysql as pymysql
import google.cloud.sql.connector.pytds as pytds
@@ -46,6 +48,7 @@
logger = logging.getLogger(name=__name__)

ASYNC_DRIVERS = ["asyncpg"]
SERVER_PROXY_PORT = 3307
_DEFAULT_SCHEME = "https://"
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"
@@ -67,6 +70,7 @@ def __init__(
universe_domain: Optional[str] = None,
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver,
failover_period: int = 30,
) -> None:
"""Initializes a Connector instance.

@@ -114,6 +118,11 @@ def __init__(
name. To resolve a DNS record to an instance connection name, use
DnsResolver.
Default: DefaultResolver

failover_period (int): The time interval in seconds between each
attempt to check if a failover has occurred for a given instance.
Must be used with `resolver=DnsResolver` to have any effect.
Default: 30
"""
# if refresh_strategy is str, convert to RefreshStrategy enum
if isinstance(refresh_strategy, str):
@@ -143,9 +152,7 @@ def __init__(
)
# initialize dict to store caches, key is a tuple consisting of instance
# connection name string and enable_iam_auth boolean flag
self._cache: dict[
tuple[str, bool], Union[RefreshAheadCache, LazyRefreshCache]
] = {}
self._cache: dict[tuple[str, bool], MonitoredCache] = {}
self._client: Optional[CloudSQLClient] = None

# initialize credentials
@@ -167,6 +174,7 @@ def __init__(
self._enable_iam_auth = enable_iam_auth
self._user_agent = user_agent
self._resolver = resolver()
self._failover_period = failover_period
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes._from_str(ip_type)
@@ -285,15 +293,16 @@ async def connect_async(
driver=driver,
)
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
if (instance_connection_string, enable_iam_auth) in self._cache:
cache = self._cache[(instance_connection_string, enable_iam_auth)]

conn_name = await self._resolver.resolve(instance_connection_string)
if (str(conn_name), enable_iam_auth) in self._cache:
monitored_cache = self._cache[(str(conn_name), enable_iam_auth)]
Collaborator Author

Update the cache key to be str(conn_name), which results in the cache key being either:

domain name -> instance connection name, or just the instance connection name.
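Concretely, the resulting keys look roughly like this (hypothetical values, following the tuple layout in the diff above):

```python
# Connecting by instance connection name: the key is the plain name.
key_by_name = ("my-project:us-central1:my-instance", True)

# Connecting by domain name via DnsResolver: str(conn_name) embeds the mapping,
# so each domain gets its own MonitoredCache entry.
key_by_domain = ("db.example.com -> my-project:us-central1:my-instance", False)
```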

else:
conn_name = await self._resolver.resolve(instance_connection_string)
if self._refresh_strategy == RefreshStrategy.LAZY:
logger.debug(
f"['{conn_name}']: Refresh strategy is set to lazy refresh"
)
cache = LazyRefreshCache(
cache: Union[LazyRefreshCache, RefreshAheadCache] = LazyRefreshCache(
conn_name,
self._client,
self._keys,
@@ -309,8 +318,14 @@
self._keys,
enable_iam_auth,
)
# wrap cache as a MonitoredCache
monitored_cache = MonitoredCache(
cache,
self._failover_period,
self._resolver,
)
logger.debug(f"['{conn_name}']: Connection info added to cache")
self._cache[(instance_connection_string, enable_iam_auth)] = cache
self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache

connect_func = {
"pymysql": pymysql.connect,
@@ -321,7 +336,7 @@

# only accept supported database drivers
try:
connector = connect_func[driver]
connector: Callable = connect_func[driver] # type: ignore
except KeyError:
raise KeyError(f"Driver '{driver}' is not supported.")

@@ -339,14 +354,14 @@

# attempt to get connection info for Cloud SQL instance
try:
conn_info = await cache.connect_info()
conn_info = await monitored_cache.connect_info()
# validate driver matches intended database engine
DriverMapping.validate_engine(driver, conn_info.database_version)
ip_address = conn_info.get_preferred_ip(ip_type)
except Exception:
# with an error from Cloud SQL Admin API call or IP type, invalidate
# the cache and re-raise the error
await self._remove_cached(instance_connection_string, enable_iam_auth)
await self._remove_cached(str(conn_name), enable_iam_auth)
raise
logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307")
# format `user` param for automatic IAM database authn
@@ -367,18 +382,28 @@
await conn_info.create_ssl_context(enable_iam_auth),
**kwargs,
)
# synchronous drivers are blocking and run using executor
# Create socket with SSLContext for sync drivers
ctx = await conn_info.create_ssl_context(enable_iam_auth)
sock = ctx.wrap_socket(
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
server_hostname=ip_address,
)
# If this connection was opened using a domain name, then store it
# for later in case we need to forcibly close it on failover.
if conn_info.conn_name.domain_name:
monitored_cache.sockets.append(sock)
# Synchronous drivers are blocking and run using executor
connect_partial = partial(
connector,
ip_address,
await conn_info.create_ssl_context(enable_iam_auth),
sock,
**kwargs,
)
return await self._loop.run_in_executor(None, connect_partial)

except Exception:
# with any exception, we attempt a force refresh, then throw the error
await cache.force_refresh()
await monitored_cache.force_refresh()
raise

async def _remove_cached(
Expand Down Expand Up @@ -456,6 +481,7 @@ async def create_async_connector(
universe_domain: Optional[str] = None,
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver,
failover_period: int = 30,
) -> Connector:
"""Helper function to create Connector object for asyncio connections.

@@ -507,6 +533,11 @@
DnsResolver.
Default: DefaultResolver

failover_period (int): The time interval in seconds between each
attempt to check if a failover has occurred for a given instance.
Must be used with `resolver=DnsResolver` to have any effect.
Default: 30

Returns:
A Connector instance configured with running event loop.
"""
@@ -525,4 +556,5 @@
universe_domain=universe_domain,
refresh_strategy=refresh_strategy,
resolver=resolver,
failover_period=failover_period,
)
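Taken together, the new parameter would be used roughly as follows. This is a sketch based on the docstrings above; the DnsResolver import path and the connection details are assumptions rather than part of this diff:

```python
from google.cloud.sql.connector import Connector
from google.cloud.sql.connector.resolver import DnsResolver

# failover_period only has an effect when paired with DnsResolver, per the
# docstring above. All connection details below are placeholders.
with Connector(resolver=DnsResolver, failover_period=30) as connector:
    conn = connector.connect(
        "db.example.com",  # domain name that resolves to a Cloud SQL instance
        "pg8000",
        user="my-user",
        password="my-password",
        db="my-db",
    )
```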
13 changes: 12 additions & 1 deletion google/cloud/sql/connector/instance.py
@@ -24,6 +24,7 @@

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_info import ConnectionInfoCache
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import RefreshNotValidError
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
@@ -35,7 +36,7 @@
APPLICATION_NAME = "cloud-sql-python-connector"


class RefreshAheadCache:
class RefreshAheadCache(ConnectionInfoCache):
"""Cache that refreshes connection info in the background prior to expiration.

Background tasks are used to schedule refresh attempts to get a new
Expand Down Expand Up @@ -74,6 +75,15 @@ def __init__(
self._refresh_in_progress = asyncio.locks.Event()
self._current: asyncio.Task = self._schedule_refresh(0)
self._next: asyncio.Task = self._current
self._closed = False

@property
def conn_name(self) -> ConnectionName:
return self._conn_name

@property
def closed(self) -> bool:
return self._closed

async def force_refresh(self) -> None:
"""
Expand Down Expand Up @@ -212,3 +222,4 @@ async def close(self) -> None:
# gracefully wait for tasks to cancel
tasks = asyncio.gather(self._current, self._next, return_exceptions=True)
await asyncio.wait_for(tasks, timeout=2.0)
self._closed = True
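Because both RefreshAheadCache and LazyRefreshCache (below) now implement ConnectionInfoCache, callers can be written against the interface alone. A minimal sketch; the helper itself is invented for illustration and only uses the methods and property defined on the ABC above:

```python
from google.cloud.sql.connector.connection_info import (
    ConnectionInfo,
    ConnectionInfoCache,
)

async def get_info_or_refresh(cache: ConnectionInfoCache) -> ConnectionInfo:
    """Illustrative helper that works with either cache implementation."""
    if cache.closed:
        raise RuntimeError("connection info cache has been closed")
    try:
        return await cache.connect_info()
    except Exception:
        # Mirror the connector's pattern: force a refresh before re-raising
        # so the next attempt starts from fresh connection info.
        await cache.force_refresh()
        raise
```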
15 changes: 13 additions & 2 deletions google/cloud/sql/connector/lazy.py
@@ -21,13 +21,14 @@

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_info import ConnectionInfoCache
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.refresh_utils import _refresh_buffer

logger = logging.getLogger(name=__name__)


class LazyRefreshCache:
class LazyRefreshCache(ConnectionInfoCache):
"""Cache that refreshes connection info when a caller requests a connection.

Only refreshes the cache when a new connection is requested and the current
Expand Down Expand Up @@ -62,6 +63,15 @@ def __init__(
self._lock = asyncio.Lock()
self._cached: Optional[ConnectionInfo] = None
self._needs_refresh = False
self._closed = False

@property
def conn_name(self) -> ConnectionName:
return self._conn_name

@property
def closed(self) -> bool:
return self._closed

async def force_refresh(self) -> None:
"""
Expand Down Expand Up @@ -121,4 +131,5 @@ async def close(self) -> None:
"""Close is a no-op and provided purely for a consistent interface with
other cache types.
"""
pass
self._closed = True
return