Skip to content

Commit ae471df

Browse files
chore: refactor connect into connect_info (#304)
1 parent 372e401 commit ae471df

File tree

7 files changed

+242
-197
lines changed

7 files changed

+242
-197
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,14 @@
2121
InstanceConnectionManager,
2222
IPTypes,
2323
)
24+
import google.cloud.sql.connector.pymysql as pymysql
25+
import google.cloud.sql.connector.pg8000 as pg8000
26+
import google.cloud.sql.connector.pytds as pytds
2427
from google.cloud.sql.connector.utils import generate_keys
2528
from google.auth.credentials import Credentials
2629
from threading import Thread
2730
from typing import Any, Dict, Optional, Type
31+
from functools import partial
2832

2933
logger = logging.getLogger(name=__name__)
3034

@@ -159,6 +163,18 @@ async def connect_async(
159163
)
160164
self._instances[instance_connection_string] = icm
161165

166+
connect_func = {
167+
"pymysql": pymysql.connect,
168+
"pg8000": pg8000.connect,
169+
"pytds": pytds.connect,
170+
}
171+
172+
# only accept supported database drivers
173+
try:
174+
connector = connect_func[driver]
175+
except KeyError:
176+
raise KeyError(f"Driver '{driver}' is not supported.")
177+
162178
if "ip_types" in kwargs:
163179
ip_type = kwargs.pop("ip_types")
164180
logger.warning(
@@ -171,11 +187,23 @@ async def connect_async(
171187
if "connect_timeout" in kwargs:
172188
timeout = kwargs.pop("connect_timeout")
173189

190+
# Host and ssl options come from the certificates and metadata, so we don't
191+
# want the user to specify them.
192+
kwargs.pop("host", None)
193+
kwargs.pop("ssl", None)
194+
kwargs.pop("port", None)
195+
196+
# helper function to wrap in timeout
197+
async def get_connection() -> Any:
198+
instance_data, ip_address = await icm.connect_info(ip_type)
199+
connect_partial = partial(
200+
connector, ip_address, instance_data.context, **kwargs
201+
)
202+
return await self._loop.run_in_executor(None, connect_partial)
203+
174204
# attempt to make connection to Cloud SQL instance for given timeout
175205
try:
176-
return await asyncio.wait_for(
177-
icm.connect(driver, ip_type, **kwargs), timeout
178-
)
206+
return await asyncio.wait_for(get_connection(), timeout)
179207
except asyncio.TimeoutError:
180208
raise TimeoutError(f"Connection timed out after {timeout}s")
181209
except Exception as e:

google/cloud/sql/connector/instance_connection_manager.py

Lines changed: 39 additions & 173 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,20 @@
3535
from google.auth.credentials import Credentials, with_scopes_if_required
3636
import google.auth.transport.requests
3737
import OpenSSL
38-
import platform
3938
import ssl
40-
import socket
4139
from tempfile import TemporaryDirectory
4240
from typing import (
4341
Any,
4442
Awaitable,
4543
Dict,
4644
Optional,
47-
TYPE_CHECKING,
45+
Tuple,
4846
)
49-
50-
from functools import partial
5147
import logging
5248

53-
if TYPE_CHECKING:
54-
import pymysql
55-
import pg8000
56-
import pytds
5749
logger = logging.getLogger(name=__name__)
5850

5951
APPLICATION_NAME = "cloud-sql-python-connector"
60-
SERVER_PROXY_PORT = 3307
6152

6253

6354
class IPTypes(Enum):
@@ -320,13 +311,15 @@ async def _perform_refresh(self) -> InstanceMetadata:
320311
and a string representing a PEM-encoded certificate authority.
321312
"""
322313
self._refresh_in_progress.set()
323-
logger.debug("Entered _perform_refresh")
314+
logger.debug(
315+
f"['{self._instance_connection_string}']: Entered _perform_refresh"
316+
)
324317

325318
try:
326319
await self._refresh_rate_limiter.acquire()
327320
priv_key, pub_key = await self._keys
328321

329-
logger.debug("Creating context")
322+
logger.debug(f"['{self._instance_connection_string}']: Creating context")
330323

331324
metadata_task = self._loop.create_task(
332325
_get_metadata(
@@ -366,7 +359,9 @@ async def _perform_refresh(self) -> InstanceMetadata:
366359
expiration = token_expiration
367360

368361
except Exception as e:
369-
logger.debug("Error occurred during _perform_refresh.")
362+
logger.debug(
363+
f"['{self._instance_connection_string}']: Error occurred during _perform_refresh."
364+
)
370365
raise e
371366

372367
finally:
@@ -402,17 +397,20 @@ async def _refresh_task(
402397
"""
403398
refresh_task: asyncio.Task
404399
try:
405-
logger.debug("Entering sleep")
400+
logger.debug(f"['{self._instance_connection_string}']: Entering sleep")
406401
if delay > 0:
407402
await asyncio.sleep(delay)
408403
refresh_task = self._loop.create_task(self._perform_refresh())
409404
refresh_data = await refresh_task
410405
except asyncio.CancelledError as e:
411-
logger.debug("Schedule refresh task cancelled.")
406+
logger.debug(
407+
f"['{self._instance_connection_string}']: Schedule refresh task cancelled."
408+
)
412409
raise e
413410
# bad refresh attempt
414411
except Exception as e:
415412
logger.exception(
413+
f"['{self._instance_connection_string}']: "
416414
"An error occurred while performing refresh. "
417415
"Scheduling another refresh attempt immediately",
418416
exc_info=e,
@@ -438,179 +436,47 @@ async def _refresh_task(
438436
scheduled_task = self._loop.create_task(_refresh_task(self, delay))
439437
return scheduled_task
440438

441-
async def connect(
439+
async def connect_info(
442440
self,
443-
driver: str,
444441
ip_type: IPTypes,
445-
**kwargs: Any,
446-
) -> Any:
447-
"""A method that returns a DB-API connection to the database.
442+
) -> Tuple[InstanceMetadata, str]:
443+
"""Retrieve instance metadata and ip address required
444+
for making connection to Cloud SQL instance.
448445
449-
:type driver: str
450-
:param driver: A string representing the driver. e.g. "pymysql"
446+
:type ip_type: IPTypes
447+
:param ip_type: Enum specifying whether to look for public
448+
or private IP address.
451449
452-
:returns: A DB-API connection to the primary IP of the database.
453-
"""
454-
logger.debug("Entered connect method")
455-
456-
# Host and ssl options come from the certificates and metadata, so we don't
457-
# want the user to specify them.
458-
kwargs.pop("host", None)
459-
kwargs.pop("ssl", None)
460-
kwargs.pop("port", None)
450+
:rtype instance_data: InstanceMetadata
451+
:returns: Instance metadata for Cloud SQL instance.
461452
462-
connect_func = {
463-
"pymysql": self._connect_with_pymysql,
464-
"pg8000": self._connect_with_pg8000,
465-
"pytds": self._connect_with_pytds,
466-
}
453+
:rtype ip_address: str
454+
:returns: A string representing the IP address of
455+
the given Cloud SQL instance.
456+
"""
457+
logger.debug(
458+
f"['{self._instance_connection_string}']: Entered connect_info method"
459+
)
467460

468461
instance_data: InstanceMetadata
469462

470463
instance_data = await self._current
471464
ip_address: str = instance_data.get_preferred_ip(ip_type)
472-
473-
try:
474-
connector = connect_func[driver]
475-
except KeyError:
476-
raise KeyError(f"Driver {driver} is not supported.")
477-
478-
connect_partial = partial(
479-
connector, ip_address, instance_data.context, **kwargs
480-
)
481-
482-
return await self._loop.run_in_executor(None, connect_partial)
483-
484-
def _connect_with_pymysql(
485-
self, ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
486-
) -> "pymysql.connections.Connection":
487-
"""Helper function to create a pymysql DB-API connection object.
488-
489-
:type ip_address: str
490-
:param ip_address: A string containing an IP address for the Cloud SQL
491-
instance.
492-
493-
:type ctx: ssl.SSLContext
494-
:param ctx: An SSLContext object created from the Cloud SQL server CA
495-
cert and ephemeral cert.
496-
497-
:rtype: pymysql.Connection
498-
:returns: A PyMySQL Connection object for the Cloud SQL instance.
499-
"""
500-
try:
501-
import pymysql
502-
except ImportError:
503-
raise ImportError(
504-
'Unable to import module "pymysql." Please install and try again.'
505-
)
506-
507-
# Create socket and wrap with context.
508-
sock = ctx.wrap_socket(
509-
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
510-
server_hostname=ip_address,
511-
)
512-
513-
# Create pymysql connection object and hand in pre-made connection
514-
conn = pymysql.Connection(host=ip_address, defer_connect=True, **kwargs)
515-
conn.connect(sock)
516-
return conn
517-
518-
def _connect_with_pg8000(
519-
self, ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
520-
) -> "pg8000.dbapi.Connection":
521-
"""Helper function to create a pg8000 DB-API connection object.
522-
523-
:type ip_address: str
524-
:param ip_address: A string containing an IP address for the Cloud SQL
525-
instance.
526-
527-
:type ctx: ssl.SSLContext
528-
:param ctx: An SSLContext object created from the Cloud SQL server CA
529-
cert and ephemeral cert.
530-
531-
532-
:rtype: pg8000.dbapi.Connection
533-
:returns: A pg8000 Connection object for the Cloud SQL instance.
534-
"""
535-
try:
536-
import pg8000
537-
except ImportError:
538-
raise ImportError(
539-
'Unable to import module "pg8000." Please install and try again.'
540-
)
541-
user = kwargs.pop("user")
542-
db = kwargs.pop("db")
543-
passwd = kwargs.pop("password", None)
544-
setattr(ctx, "request_ssl", False)
545-
return pg8000.dbapi.connect(
546-
user,
547-
database=db,
548-
password=passwd,
549-
host=ip_address,
550-
port=SERVER_PROXY_PORT,
551-
ssl_context=ctx,
552-
**kwargs,
553-
)
554-
555-
def _connect_with_pytds(
556-
self, ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
557-
) -> "pytds.Connection":
558-
"""Helper function to create a pytds DB-API connection object.
559-
560-
:type ip_address: str
561-
:param ip_address: A string containing an IP address for the Cloud SQL
562-
instance.
563-
564-
:type ctx: ssl.SSLContext
565-
:param ctx: An SSLContext object created from the Cloud SQL server CA
566-
cert and ephemeral cert.
567-
568-
569-
:rtype: pytds.Connection
570-
:returns: A pytds Connection object for the Cloud SQL instance.
571-
"""
572-
try:
573-
import pytds
574-
except ImportError:
575-
raise ImportError(
576-
'Unable to import module "pytds." Please install and try again.'
577-
)
578-
579-
db = kwargs.pop("db", None)
580-
581-
# Create socket and wrap with context.
582-
sock = ctx.wrap_socket(
583-
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
584-
server_hostname=ip_address,
585-
)
586-
if kwargs.pop("active_directory_auth", False):
587-
if platform.system() == "Windows":
588-
# Ignore username and password if using active directory auth
589-
server_name = kwargs.pop("server_name")
590-
return pytds.connect(
591-
database=db,
592-
auth=pytds.login.SspiAuth(port=1433, server_name=server_name),
593-
sock=sock,
594-
**kwargs,
595-
)
596-
else:
597-
raise PlatformNotSupportedError(
598-
"Active Directory authentication is currently only supported on Windows."
599-
)
600-
601-
user = kwargs.pop("user")
602-
passwd = kwargs.pop("password")
603-
return pytds.connect(
604-
ip_address, database=db, user=user, password=passwd, sock=sock, **kwargs
605-
)
465+
return instance_data, ip_address
606466

607467
async def close(self) -> None:
608468
"""Cleanup function to make sure ClientSession is closed and tasks have
609469
finished to have a graceful exit.
610470
"""
611-
logger.debug("Waiting for _current to be cancelled")
471+
logger.debug(
472+
f"['{self._instance_connection_string}']: Waiting for _current to be cancelled"
473+
)
612474
self._current.cancel()
613-
logger.debug("Waiting for _next to be cancelled")
475+
logger.debug(
476+
f"['{self._instance_connection_string}']: Waiting for _next to be cancelled"
477+
)
614478
self._next.cancel()
615-
logger.debug("Waiting for _client_session to close")
479+
logger.debug(
480+
f"['{self._instance_connection_string}']: Waiting for _client_session to close"
481+
)
616482
await self._client_session.close()
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import ssl
2+
from typing import Any, TYPE_CHECKING
3+
4+
SERVER_PROXY_PORT = 3307
5+
6+
if TYPE_CHECKING:
7+
import pg8000
8+
9+
10+
def connect(
11+
ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
12+
) -> "pg8000.dbapi.Connection":
13+
"""Helper function to create a pg8000 DB-API connection object.
14+
15+
:type ip_address: str
16+
:param ip_address: A string containing an IP address for the Cloud SQL
17+
instance.
18+
19+
:type ctx: ssl.SSLContext
20+
:param ctx: An SSLContext object created from the Cloud SQL server CA
21+
cert and ephemeral cert.
22+
23+
24+
:rtype: pg8000.dbapi.Connection
25+
:returns: A pg8000 Connection object for the Cloud SQL instance.
26+
"""
27+
try:
28+
import pg8000
29+
except ImportError:
30+
raise ImportError(
31+
'Unable to import module "pg8000." Please install and try again.'
32+
)
33+
user = kwargs.pop("user")
34+
db = kwargs.pop("db")
35+
passwd = kwargs.pop("password", None)
36+
setattr(ctx, "request_ssl", False)
37+
return pg8000.dbapi.connect(
38+
user,
39+
database=db,
40+
password=passwd,
41+
host=ip_address,
42+
port=SERVER_PROXY_PORT,
43+
ssl_context=ctx,
44+
**kwargs,
45+
)

0 commit comments

Comments
 (0)