2020from functools import partial
2121import logging
2222import os
23+ import socket
2324from threading import Thread
2425from types import TracebackType
2526from typing import Any , Optional , Union
4748logger = logging .getLogger (name = __name__ )
4849
4950ASYNC_DRIVERS = ["asyncpg" ]
51+ SERVER_PROXY_PORT = 3307
5052_DEFAULT_SCHEME = "https://"
5153_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
5254_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"
@@ -291,10 +293,11 @@ async def connect_async(
291293 driver = driver ,
292294 )
293295 enable_iam_auth = kwargs .pop ("enable_iam_auth" , self ._enable_iam_auth )
294- if (instance_connection_string , enable_iam_auth ) in self ._cache :
295- monitored_cache = self ._cache [(instance_connection_string , enable_iam_auth )]
296+
297+ conn_name = await self ._resolver .resolve (instance_connection_string )
298+ if (str (conn_name ), enable_iam_auth ) in self ._cache :
299+ monitored_cache = self ._cache [(str (conn_name ), enable_iam_auth )]
296300 else :
297- conn_name = await self ._resolver .resolve (instance_connection_string )
298301 if self ._refresh_strategy == RefreshStrategy .LAZY :
299302 logger .debug (
300303 f"['{ conn_name } ']: Refresh strategy is set to lazy refresh"
@@ -322,7 +325,7 @@ async def connect_async(
322325 self ._resolver ,
323326 )
324327 logger .debug (f"['{ conn_name } ']: Connection info added to cache" )
325- self ._cache [(instance_connection_string , enable_iam_auth )] = monitored_cache
328+ self ._cache [(str ( conn_name ) , enable_iam_auth )] = monitored_cache
326329
327330 connect_func = {
328331 "pymysql" : pymysql .connect ,
@@ -358,7 +361,7 @@ async def connect_async(
358361 except Exception :
359362 # with an error from Cloud SQL Admin API call or IP type, invalidate
360363 # the cache and re-raise the error
361- await self ._remove_cached (instance_connection_string , enable_iam_auth )
364+ await self ._remove_cached (str ( conn_name ) , enable_iam_auth )
362365 raise
363366 logger .debug (f"['{ conn_info .conn_name } ']: Connecting to { ip_address } :3307" )
364367 # format `user` param for automatic IAM database authn
@@ -379,11 +382,21 @@ async def connect_async(
379382 await conn_info .create_ssl_context (enable_iam_auth ),
380383 ** kwargs ,
381384 )
382- # synchronous drivers are blocking and run using executor
385+ # Create socket with SSLContext for sync drivers
386+ ctx = await conn_info .create_ssl_context (enable_iam_auth )
387+ sock = ctx .wrap_socket (
388+ socket .create_connection ((ip_address , SERVER_PROXY_PORT )),
389+ server_hostname = ip_address ,
390+ )
391+ # If this connection was opened using a domain name, then store it
392+ # for later in case we need to forcibly close it on failover.
393+ if conn_info .conn_name .domain_name :
394+ monitored_cache .sockets .append (sock )
395+ # Synchronous drivers are blocking and run using executor
383396 connect_partial = partial (
384397 connector ,
385398 ip_address ,
386- await conn_info . create_ssl_context ( enable_iam_auth ) ,
399+ sock ,
387400 ** kwargs ,
388401 )
389402 return await self ._loop .run_in_executor (None , connect_partial )
@@ -468,6 +481,7 @@ async def create_async_connector(
468481 universe_domain : Optional [str ] = None ,
469482 refresh_strategy : str | RefreshStrategy = RefreshStrategy .BACKGROUND ,
470483 resolver : type [DefaultResolver ] | type [DnsResolver ] = DefaultResolver ,
484+ failover_period : int = 30 ,
471485) -> Connector :
472486 """Helper function to create Connector object for asyncio connections.
473487
@@ -519,6 +533,11 @@ async def create_async_connector(
519533 DnsResolver.
520534 Default: DefaultResolver
521535
536+ failover_period (int): The time interval in seconds between each
537+ attempt to check if a failover has occured for a given instance.
538+ Must be used with `resolver=DnsResolver` to have any effect.
539+ Default: 30
540+
522541 Returns:
523542 A Connector instance configured with running event loop.
524543 """
@@ -537,4 +556,5 @@ async def create_async_connector(
537556 universe_domain = universe_domain ,
538557 refresh_strategy = refresh_strategy ,
539558 resolver = resolver ,
559+ failover_period = failover_period ,
540560 )
0 commit comments