3535from google .auth .credentials import Credentials , with_scopes_if_required
3636import google .auth .transport .requests
3737import OpenSSL
38- import platform
3938import ssl
40- import socket
4139from tempfile import TemporaryDirectory
4240from typing import (
4341 Any ,
4442 Awaitable ,
4543 Dict ,
4644 Optional ,
47- TYPE_CHECKING ,
45+ Tuple ,
4846)
49-
50- from functools import partial
5147import logging
5248
53- if TYPE_CHECKING :
54- import pymysql
55- import pg8000
56- import pytds
5749logger = logging .getLogger (name = __name__ )
5850
5951APPLICATION_NAME = "cloud-sql-python-connector"
60- SERVER_PROXY_PORT = 3307
6152
6253
6354class IPTypes (Enum ):
@@ -320,13 +311,15 @@ async def _perform_refresh(self) -> InstanceMetadata:
320311 and a string representing a PEM-encoded certificate authority.
321312 """
322313 self ._refresh_in_progress .set ()
323- logger .debug ("Entered _perform_refresh" )
314+ logger .debug (
315+ f"['{ self ._instance_connection_string } ']: Entered _perform_refresh"
316+ )
324317
325318 try :
326319 await self ._refresh_rate_limiter .acquire ()
327320 priv_key , pub_key = await self ._keys
328321
329- logger .debug (" Creating context" )
322+ logger .debug (f"[' { self . _instance_connection_string } ']: Creating context" )
330323
331324 metadata_task = self ._loop .create_task (
332325 _get_metadata (
@@ -366,7 +359,9 @@ async def _perform_refresh(self) -> InstanceMetadata:
366359 expiration = token_expiration
367360
368361 except Exception as e :
369- logger .debug ("Error occurred during _perform_refresh." )
362+ logger .debug (
363+ f"['{ self ._instance_connection_string } ']: Error occurred during _perform_refresh."
364+ )
370365 raise e
371366
372367 finally :
@@ -402,17 +397,20 @@ async def _refresh_task(
402397 """
403398 refresh_task : asyncio .Task
404399 try :
405- logger .debug (" Entering sleep" )
400+ logger .debug (f"[' { self . _instance_connection_string } ']: Entering sleep" )
406401 if delay > 0 :
407402 await asyncio .sleep (delay )
408403 refresh_task = self ._loop .create_task (self ._perform_refresh ())
409404 refresh_data = await refresh_task
410405 except asyncio .CancelledError as e :
411- logger .debug ("Schedule refresh task cancelled." )
406+ logger .debug (
407+ f"['{ self ._instance_connection_string } ']: Schedule refresh task cancelled."
408+ )
412409 raise e
413410 # bad refresh attempt
414411 except Exception as e :
415412 logger .exception (
413+ f"['{ self ._instance_connection_string } ']: "
416414 "An error occurred while performing refresh. "
417415 "Scheduling another refresh attempt immediately" ,
418416 exc_info = e ,
@@ -438,179 +436,47 @@ async def _refresh_task(
438436 scheduled_task = self ._loop .create_task (_refresh_task (self , delay ))
439437 return scheduled_task
440438
441- async def connect (
439+ async def connect_info (
442440 self ,
443- driver : str ,
444441 ip_type : IPTypes ,
445- ** kwargs : Any ,
446- ) -> Any :
447- """A method that returns a DB-API connection to the database .
442+ ) -> Tuple [ InstanceMetadata , str ]:
443+ """Retrieve instance metadata and ip address required
444+ for making connection to Cloud SQL instance .
448445
449- :type driver: str
450- :param driver: A string representing the driver. e.g. "pymysql"
446+ :type ip_type: IPTypes
447+ :param ip_type: Enum specifying whether to look for public
448+ or private IP address.
451449
452- :returns: A DB-API connection to the primary IP of the database.
453- """
454- logger .debug ("Entered connect method" )
455-
456- # Host and ssl options come from the certificates and metadata, so we don't
457- # want the user to specify them.
458- kwargs .pop ("host" , None )
459- kwargs .pop ("ssl" , None )
460- kwargs .pop ("port" , None )
450+ :rtype instance_data: InstanceMetadata
451+ :returns: Instance metadata for Cloud SQL instance.
461452
462- connect_func = {
463- "pymysql" : self ._connect_with_pymysql ,
464- "pg8000" : self ._connect_with_pg8000 ,
465- "pytds" : self ._connect_with_pytds ,
466- }
453+ :rtype ip_address: str
454+ :returns: A string representing the IP address of
455+ the given Cloud SQL instance.
456+ """
457+ logger .debug (
458+ f"['{ self ._instance_connection_string } ']: Entered connect_info method"
459+ )
467460
468461 instance_data : InstanceMetadata
469462
470463 instance_data = await self ._current
471464 ip_address : str = instance_data .get_preferred_ip (ip_type )
472-
473- try :
474- connector = connect_func [driver ]
475- except KeyError :
476- raise KeyError (f"Driver { driver } is not supported." )
477-
478- connect_partial = partial (
479- connector , ip_address , instance_data .context , ** kwargs
480- )
481-
482- return await self ._loop .run_in_executor (None , connect_partial )
483-
484- def _connect_with_pymysql (
485- self , ip_address : str , ctx : ssl .SSLContext , ** kwargs : Any
486- ) -> "pymysql.connections.Connection" :
487- """Helper function to create a pymysql DB-API connection object.
488-
489- :type ip_address: str
490- :param ip_address: A string containing an IP address for the Cloud SQL
491- instance.
492-
493- :type ctx: ssl.SSLContext
494- :param ctx: An SSLContext object created from the Cloud SQL server CA
495- cert and ephemeral cert.
496-
497- :rtype: pymysql.Connection
498- :returns: A PyMySQL Connection object for the Cloud SQL instance.
499- """
500- try :
501- import pymysql
502- except ImportError :
503- raise ImportError (
504- 'Unable to import module "pymysql." Please install and try again.'
505- )
506-
507- # Create socket and wrap with context.
508- sock = ctx .wrap_socket (
509- socket .create_connection ((ip_address , SERVER_PROXY_PORT )),
510- server_hostname = ip_address ,
511- )
512-
513- # Create pymysql connection object and hand in pre-made connection
514- conn = pymysql .Connection (host = ip_address , defer_connect = True , ** kwargs )
515- conn .connect (sock )
516- return conn
517-
518- def _connect_with_pg8000 (
519- self , ip_address : str , ctx : ssl .SSLContext , ** kwargs : Any
520- ) -> "pg8000.dbapi.Connection" :
521- """Helper function to create a pg8000 DB-API connection object.
522-
523- :type ip_address: str
524- :param ip_address: A string containing an IP address for the Cloud SQL
525- instance.
526-
527- :type ctx: ssl.SSLContext
528- :param ctx: An SSLContext object created from the Cloud SQL server CA
529- cert and ephemeral cert.
530-
531-
532- :rtype: pg8000.dbapi.Connection
533- :returns: A pg8000 Connection object for the Cloud SQL instance.
534- """
535- try :
536- import pg8000
537- except ImportError :
538- raise ImportError (
539- 'Unable to import module "pg8000." Please install and try again.'
540- )
541- user = kwargs .pop ("user" )
542- db = kwargs .pop ("db" )
543- passwd = kwargs .pop ("password" , None )
544- setattr (ctx , "request_ssl" , False )
545- return pg8000 .dbapi .connect (
546- user ,
547- database = db ,
548- password = passwd ,
549- host = ip_address ,
550- port = SERVER_PROXY_PORT ,
551- ssl_context = ctx ,
552- ** kwargs ,
553- )
554-
555- def _connect_with_pytds (
556- self , ip_address : str , ctx : ssl .SSLContext , ** kwargs : Any
557- ) -> "pytds.Connection" :
558- """Helper function to create a pytds DB-API connection object.
559-
560- :type ip_address: str
561- :param ip_address: A string containing an IP address for the Cloud SQL
562- instance.
563-
564- :type ctx: ssl.SSLContext
565- :param ctx: An SSLContext object created from the Cloud SQL server CA
566- cert and ephemeral cert.
567-
568-
569- :rtype: pytds.Connection
570- :returns: A pytds Connection object for the Cloud SQL instance.
571- """
572- try :
573- import pytds
574- except ImportError :
575- raise ImportError (
576- 'Unable to import module "pytds." Please install and try again.'
577- )
578-
579- db = kwargs .pop ("db" , None )
580-
581- # Create socket and wrap with context.
582- sock = ctx .wrap_socket (
583- socket .create_connection ((ip_address , SERVER_PROXY_PORT )),
584- server_hostname = ip_address ,
585- )
586- if kwargs .pop ("active_directory_auth" , False ):
587- if platform .system () == "Windows" :
588- # Ignore username and password if using active directory auth
589- server_name = kwargs .pop ("server_name" )
590- return pytds .connect (
591- database = db ,
592- auth = pytds .login .SspiAuth (port = 1433 , server_name = server_name ),
593- sock = sock ,
594- ** kwargs ,
595- )
596- else :
597- raise PlatformNotSupportedError (
598- "Active Directory authentication is currently only supported on Windows."
599- )
600-
601- user = kwargs .pop ("user" )
602- passwd = kwargs .pop ("password" )
603- return pytds .connect (
604- ip_address , database = db , user = user , password = passwd , sock = sock , ** kwargs
605- )
465+ return instance_data , ip_address
606466
607467 async def close (self ) -> None :
608468 """Cleanup function to make sure ClientSession is closed and tasks have
609469 finished to have a graceful exit.
610470 """
611- logger .debug ("Waiting for _current to be cancelled" )
471+ logger .debug (
472+ f"['{ self ._instance_connection_string } ']: Waiting for _current to be cancelled"
473+ )
612474 self ._current .cancel ()
613- logger .debug ("Waiting for _next to be cancelled" )
475+ logger .debug (
476+ f"['{ self ._instance_connection_string } ']: Waiting for _next to be cancelled"
477+ )
614478 self ._next .cancel ()
615- logger .debug ("Waiting for _client_session to close" )
479+ logger .debug (
480+ f"['{ self ._instance_connection_string } ']: Waiting for _client_session to close"
481+ )
616482 await self ._client_session .close ()
0 commit comments