2020from  functools  import  partial 
2121import  logging 
2222import  os 
23+ import  socket 
2324from  threading  import  Thread 
2425from  types  import  TracebackType 
25- from  typing  import  Any , Optional , Union 
26+ from  typing  import  Any , Callable ,  Optional , Union 
2627
2728import  google .auth 
2829from  google .auth .credentials  import  Credentials 
3536from  google .cloud .sql .connector .enums  import  RefreshStrategy 
3637from  google .cloud .sql .connector .instance  import  RefreshAheadCache 
3738from  google .cloud .sql .connector .lazy  import  LazyRefreshCache 
39+ from  google .cloud .sql .connector .monitored_cache  import  MonitoredCache 
3840import  google .cloud .sql .connector .pg8000  as  pg8000 
3941import  google .cloud .sql .connector .pymysql  as  pymysql 
4042import  google .cloud .sql .connector .pytds  as  pytds 
4648logger  =  logging .getLogger (name = __name__ )
4749
4850ASYNC_DRIVERS  =  ["asyncpg" ]
51+ SERVER_PROXY_PORT  =  3307 
4952_DEFAULT_SCHEME  =  "https://" 
5053_DEFAULT_UNIVERSE_DOMAIN  =  "googleapis.com" 
5154_SQLADMIN_HOST_TEMPLATE  =  "sqladmin.{universe_domain}" 
@@ -67,6 +70,7 @@ def __init__(
6770        universe_domain : Optional [str ] =  None ,
6871        refresh_strategy : str  |  RefreshStrategy  =  RefreshStrategy .BACKGROUND ,
6972        resolver : type [DefaultResolver ] |  type [DnsResolver ] =  DefaultResolver ,
73+         failover_period : int  =  30 ,
7074    ) ->  None :
7175        """Initializes a Connector instance. 
7276
@@ -114,6 +118,11 @@ def __init__(
114118                name. To resolve a DNS record to an instance connection name, use 
115119                DnsResolver. 
116120                Default: DefaultResolver 
121+ 
122+             failover_period (int): The time interval in seconds between each 
123+                 attempt to check if a failover has occured for a given instance. 
124+                 Must be used with `resolver=DnsResolver` to have any effect. 
125+                 Default: 30 
117126        """ 
118127        # if refresh_strategy is str, convert to RefreshStrategy enum 
119128        if  isinstance (refresh_strategy , str ):
@@ -143,9 +152,7 @@ def __init__(
143152                )
144153        # initialize dict to store caches, key is a tuple consisting of instance 
145154        # connection name string and enable_iam_auth boolean flag 
146-         self ._cache : dict [
147-             tuple [str , bool ], Union [RefreshAheadCache , LazyRefreshCache ]
148-         ] =  {}
155+         self ._cache : dict [tuple [str , bool ], MonitoredCache ] =  {}
149156        self ._client : Optional [CloudSQLClient ] =  None 
150157
151158        # initialize credentials 
@@ -167,6 +174,7 @@ def __init__(
167174        self ._enable_iam_auth  =  enable_iam_auth 
168175        self ._user_agent  =  user_agent 
169176        self ._resolver  =  resolver ()
177+         self ._failover_period  =  failover_period 
170178        # if ip_type is str, convert to IPTypes enum 
171179        if  isinstance (ip_type , str ):
172180            ip_type  =  IPTypes ._from_str (ip_type )
@@ -285,15 +293,19 @@ async def connect_async(
285293                driver = driver ,
286294            )
287295        enable_iam_auth  =  kwargs .pop ("enable_iam_auth" , self ._enable_iam_auth )
288-         if  (instance_connection_string , enable_iam_auth ) in  self ._cache :
289-             cache  =  self ._cache [(instance_connection_string , enable_iam_auth )]
296+ 
297+         conn_name  =  await  self ._resolver .resolve (instance_connection_string )
298+         # Cache entry must exist and not be closed 
299+         if  (str (conn_name ), enable_iam_auth ) in  self ._cache  and  not  self ._cache [
300+             (str (conn_name ), enable_iam_auth )
301+         ].closed :
302+             monitored_cache  =  self ._cache [(str (conn_name ), enable_iam_auth )]
290303        else :
291-             conn_name  =  await  self ._resolver .resolve (instance_connection_string )
292304            if  self ._refresh_strategy  ==  RefreshStrategy .LAZY :
293305                logger .debug (
294306                    f"['{ conn_name }  ']: Refresh strategy is set to lazy refresh" 
295307                )
296-                 cache  =  LazyRefreshCache (
308+                 cache :  Union [ LazyRefreshCache ,  RefreshAheadCache ]  =  LazyRefreshCache (
297309                    conn_name ,
298310                    self ._client ,
299311                    self ._keys ,
@@ -309,8 +321,14 @@ async def connect_async(
309321                    self ._keys ,
310322                    enable_iam_auth ,
311323                )
324+             # wrap cache as a MonitoredCache 
325+             monitored_cache  =  MonitoredCache (
326+                 cache ,
327+                 self ._failover_period ,
328+                 self ._resolver ,
329+             )
312330            logger .debug (f"['{ conn_name }  ']: Connection info added to cache" )
313-             self ._cache [(instance_connection_string , enable_iam_auth )] =  cache 
331+             self ._cache [(str ( conn_name ) , enable_iam_auth )] =  monitored_cache 
314332
315333        connect_func  =  {
316334            "pymysql" : pymysql .connect ,
@@ -321,7 +339,7 @@ async def connect_async(
321339
322340        # only accept supported database drivers 
323341        try :
324-             connector   =  connect_func [driver ]
342+             connector :  Callable   =  connect_func [driver ]   # type: ignore 
325343        except  KeyError :
326344            raise  KeyError (f"Driver '{ driver }  ' is not supported." )
327345
@@ -339,14 +357,14 @@ async def connect_async(
339357
340358        # attempt to get connection info for Cloud SQL instance 
341359        try :
342-             conn_info  =  await  cache .connect_info ()
360+             conn_info  =  await  monitored_cache .connect_info ()
343361            # validate driver matches intended database engine 
344362            DriverMapping .validate_engine (driver , conn_info .database_version )
345363            ip_address  =  conn_info .get_preferred_ip (ip_type )
346364        except  Exception :
347365            # with an error from Cloud SQL Admin API call or IP type, invalidate 
348366            # the cache and re-raise the error 
349-             await  self ._remove_cached (instance_connection_string , enable_iam_auth )
367+             await  self ._remove_cached (str ( conn_name ) , enable_iam_auth )
350368            raise 
351369        logger .debug (f"['{ conn_info .conn_name }  ']: Connecting to { ip_address }  :3307" )
352370        # format `user` param for automatic IAM database authn 
@@ -367,18 +385,28 @@ async def connect_async(
367385                    await  conn_info .create_ssl_context (enable_iam_auth ),
368386                    ** kwargs ,
369387                )
370-             # synchronous drivers are blocking and run using executor 
388+             # Create socket with SSLContext for sync drivers 
389+             ctx  =  await  conn_info .create_ssl_context (enable_iam_auth )
390+             sock  =  ctx .wrap_socket (
391+                 socket .create_connection ((ip_address , SERVER_PROXY_PORT )),
392+                 server_hostname = ip_address ,
393+             )
394+             # If this connection was opened using a domain name, then store it 
395+             # for later in case we need to forcibly close it on failover. 
396+             if  conn_info .conn_name .domain_name :
397+                 monitored_cache .sockets .append (sock )
398+             # Synchronous drivers are blocking and run using executor 
371399            connect_partial  =  partial (
372400                connector ,
373401                ip_address ,
374-                 await   conn_info . create_ssl_context ( enable_iam_auth ) ,
402+                 sock ,
375403                ** kwargs ,
376404            )
377405            return  await  self ._loop .run_in_executor (None , connect_partial )
378406
379407        except  Exception :
380408            # with any exception, we attempt a force refresh, then throw the error 
381-             await  cache .force_refresh ()
409+             await  monitored_cache .force_refresh ()
382410            raise 
383411
384412    async  def  _remove_cached (
@@ -456,6 +484,7 @@ async def create_async_connector(
456484    universe_domain : Optional [str ] =  None ,
457485    refresh_strategy : str  |  RefreshStrategy  =  RefreshStrategy .BACKGROUND ,
458486    resolver : type [DefaultResolver ] |  type [DnsResolver ] =  DefaultResolver ,
487+     failover_period : int  =  30 ,
459488) ->  Connector :
460489    """Helper function to create Connector object for asyncio connections. 
461490
@@ -507,6 +536,11 @@ async def create_async_connector(
507536            DnsResolver. 
508537            Default: DefaultResolver 
509538
539+         failover_period (int): The time interval in seconds between each 
540+             attempt to check if a failover has occured for a given instance. 
541+             Must be used with `resolver=DnsResolver` to have any effect. 
542+             Default: 30 
543+ 
510544    Returns: 
511545        A Connector instance configured with running event loop. 
512546    """ 
@@ -525,4 +559,5 @@ async def create_async_connector(
525559        universe_domain = universe_domain ,
526560        refresh_strategy = refresh_strategy ,
527561        resolver = resolver ,
562+         failover_period = failover_period ,
528563    )
0 commit comments