Skip to content

Commit cef1ed1

Browse files
feat: expose Connector object to allow multiple connection configurations. (#210)
* create connector object * update default connect method
1 parent 117e894 commit cef1ed1

File tree

2 files changed

+110
-69
lines changed

2 files changed

+110
-69
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 107 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -23,43 +23,118 @@
2323
from google.cloud.sql.connector.utils import generate_keys
2424

2525
from threading import Thread
26-
from typing import Any, Dict, Optional
26+
from typing import Any, Dict
2727

28-
# This thread is used to background processing
29-
_thread: Optional[Thread] = None
30-
_loop: Optional[asyncio.AbstractEventLoop] = None
31-
_keys: Optional[concurrent.futures.Future] = None
28+
logger = logging.getLogger(name=__name__)
3229

33-
_instances: Dict[str, InstanceConnectionManager] = {}
30+
_default_connector = None
3431

35-
logger = logging.getLogger(name=__name__)
3632

33+
class Connector:
34+
"""A class to configure and create connections to Cloud SQL instances.
3735
38-
def _get_loop() -> asyncio.AbstractEventLoop:
39-
global _loop
40-
if _loop is None:
41-
_loop = asyncio.new_event_loop()
42-
_thread = Thread(target=_loop.run_forever, daemon=True)
43-
_thread.start()
44-
return _loop
36+
:type ip_types: IPTypes
37+
:param ip_types
38+
The IP type (public or private) used to connect. IP types
39+
can be either IPTypes.PUBLIC or IPTypes.PRIVATE.
4540
41+
:type enable_iam_auth: bool
42+
:param enable_iam_auth
43+
Enables IAM based authentication (Postgres only).
4644
47-
def _get_keys(loop: asyncio.AbstractEventLoop) -> concurrent.futures.Future:
48-
global _keys
49-
if _keys is None:
50-
_keys = asyncio.run_coroutine_threadsafe(generate_keys(), loop)
51-
return _keys
45+
:type timeout: int
46+
:param timeout:
47+
The time limit for a connection before raising a TimeoutError.
5248
49+
"""
5350

54-
def connect(
55-
instance_connection_string: str,
56-
driver: str,
57-
ip_types: IPTypes = IPTypes.PUBLIC,
58-
enable_iam_auth: bool = False,
59-
**kwargs: Any
60-
) -> Any:
61-
"""Prepares and returns a database connection object and starts a
62-
background thread to refresh the certificates and metadata.
51+
def __init__(
52+
self,
53+
ip_types: IPTypes = IPTypes.PUBLIC,
54+
enable_iam_auth: bool = False,
55+
timeout: int = 30,
56+
) -> None:
57+
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
58+
self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True)
59+
self._thread.start()
60+
self._keys: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
61+
generate_keys(), self._loop
62+
)
63+
self._instances: Dict[str, InstanceConnectionManager] = {}
64+
65+
# set default params for connections
66+
self._timeout = timeout
67+
self._enable_iam_auth = enable_iam_auth
68+
self._ip_types = ip_types
69+
70+
def connect(
71+
self, instance_connection_string: str, driver: str, **kwargs: Any
72+
) -> Any:
73+
"""Prepares and returns a database connection object and starts a
74+
background thread to refresh the certificates and metadata.
75+
76+
:type instance_connection_string: str
77+
:param instance_connection_string:
78+
A string containing the GCP project name, region name, and instance
79+
name separated by colons.
80+
81+
Example: example-proj:example-region-us6:example-instance
82+
83+
:type driver: str
84+
:param: driver:
85+
A string representing the driver to connect with. Supported drivers are
86+
pymysql, pg8000, and pytds.
87+
88+
:param kwargs:
89+
Pass in any driver-specific arguments needed to connect to the Cloud
90+
SQL instance.
91+
92+
:rtype: Connection
93+
:returns:
94+
A DB-API connection to the specified Cloud SQL instance.
95+
"""
96+
97+
# Initiate event loop and run in background thread.
98+
#
99+
# Create an InstanceConnectionManager object from the connection string.
100+
# The InstanceConnectionManager should verify arguments.
101+
#
102+
# Use the InstanceConnectionManager to establish an SSL Connection.
103+
#
104+
# Return a DBAPI connection
105+
106+
if instance_connection_string in self._instances:
107+
icm = self._instances[instance_connection_string]
108+
else:
109+
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
110+
icm = InstanceConnectionManager(
111+
instance_connection_string,
112+
driver,
113+
self._keys,
114+
self._loop,
115+
enable_iam_auth,
116+
)
117+
self._instances[instance_connection_string] = icm
118+
119+
ip_types = kwargs.pop("ip_types", self._ip_types)
120+
if "timeout" in kwargs:
121+
return icm.connect(driver, ip_types, **kwargs)
122+
elif "connect_timeout" in kwargs:
123+
timeout = kwargs["connect_timeout"]
124+
else:
125+
timeout = self._timeout
126+
try:
127+
return icm.connect(driver, ip_types, timeout, **kwargs)
128+
except Exception as e:
129+
# with any other exception, we attempt a force refresh, then throw the error
130+
icm.force_refresh()
131+
raise (e)
132+
133+
134+
def connect(instance_connection_string: str, driver: str, **kwargs: Any) -> Any:
135+
"""Uses a Connector object with default settings and returns a database
136+
connection object with a background thread to refresh the certificates and metadata.
137+
For more advanced configurations, callers should instantiate Connector on their own.
63138
64139
:type instance_connection_string: str
65140
:param instance_connection_string:
@@ -73,14 +148,6 @@ def connect(
73148
A string representing the driver to connect with. Supported drivers are
74149
pymysql, pg8000, and pytds.
75150
76-
:type ip_types: IPTypes
77-
The IP type (public or private) used to connect. IP types
78-
can be either IPTypes.PUBLIC or IPTypes.PRIVATE.
79-
80-
:param enable_iam_auth
81-
Enables IAM based authentication (Postgres only).
82-
:type enable_iam_auth: bool
83-
84151
:param kwargs:
85152
Pass in any driver-specific arguments needed to connect to the Cloud
86153
SQL instance.
@@ -89,35 +156,7 @@ def connect(
89156
:returns:
90157
A DB-API connection to the specified Cloud SQL instance.
91158
"""
92-
93-
# Initiate event loop and run in background thread.
94-
#
95-
# Create an InstanceConnectionManager object from the connection string.
96-
# The InstanceConnectionManager should verify arguments.
97-
#
98-
# Use the InstanceConnectionManager to establish an SSL Connection.
99-
#
100-
# Return a DBAPI connection
101-
102-
loop = _get_loop()
103-
if instance_connection_string in _instances:
104-
icm = _instances[instance_connection_string]
105-
else:
106-
keys = _get_keys(loop)
107-
icm = InstanceConnectionManager(
108-
instance_connection_string, driver, keys, loop, enable_iam_auth
109-
)
110-
_instances[instance_connection_string] = icm
111-
112-
if "timeout" in kwargs:
113-
return icm.connect(driver, ip_types, **kwargs)
114-
elif "connect_timeout" in kwargs:
115-
timeout = kwargs["connect_timeout"]
116-
else:
117-
timeout = 30 # 30s
118-
try:
119-
return icm.connect(driver, ip_types, timeout, **kwargs)
120-
except Exception as e:
121-
# with any other exception, we attempt a force refresh, then throw the error
122-
icm.force_refresh()
123-
raise (e)
159+
global _default_connector
160+
if _default_connector is None:
161+
_default_connector = Connector()
162+
return _default_connector.connect(instance_connection_string, driver, **kwargs)

tests/unit/test_connector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ async def timeout_stub(*args: Any, **kwargs: Any) -> None:
4343

4444
mock_instances = {}
4545
mock_instances[connect_string] = icm
46-
with patch.dict(connector._instances, mock_instances):
46+
mock_connector = connector.Connector()
47+
connector._default_connector = mock_connector
48+
with patch.dict(mock_connector._instances, mock_instances):
4749
pytest.raises(
4850
TimeoutError,
4951
connector.connect,

0 commit comments

Comments
 (0)