Skip to content

Commit 265218d

Browse files
refactor: rename Instance to RefreshAheadCache (#1068)
1 parent 79a5426 commit 265218d

File tree

6 files changed

+127
-138
lines changed

6 files changed

+127
-138
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@
3232
from google.cloud.sql.connector.client import CloudSQLClient
3333
from google.cloud.sql.connector.exceptions import ConnectorLoopError
3434
from google.cloud.sql.connector.exceptions import DnsNameResolutionError
35-
from google.cloud.sql.connector.instance import Instance
3635
from google.cloud.sql.connector.instance import IPTypes
36+
from google.cloud.sql.connector.instance import RefreshAheadCache
3737
import google.cloud.sql.connector.pg8000 as pg8000
3838
import google.cloud.sql.connector.pymysql as pymysql
3939
import google.cloud.sql.connector.pytds as pytds
@@ -113,7 +113,7 @@ def __init__(
113113
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
114114
loop=self._loop,
115115
)
116-
self._instances: Dict[str, Instance] = {}
116+
self._cache: Dict[str, RefreshAheadCache] = {}
117117
self._client: Optional[CloudSQLClient] = None
118118

119119
# initialize credentials
@@ -255,23 +255,23 @@ async def connect_async(
255255
driver=driver,
256256
)
257257
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
258-
if instance_connection_string in self._instances:
259-
instance = self._instances[instance_connection_string]
260-
if enable_iam_auth != instance._enable_iam_auth:
258+
if instance_connection_string in self._cache:
259+
cache = self._cache[instance_connection_string]
260+
if enable_iam_auth != cache._enable_iam_auth:
261261
raise ValueError(
262262
f"connect() called with 'enable_iam_auth={enable_iam_auth}', "
263-
f"but previously used 'enable_iam_auth={instance._enable_iam_auth}'. "
263+
f"but previously used 'enable_iam_auth={cache._enable_iam_auth}'. "
264264
"If you require both for your use case, please use a new "
265265
"connector.Connector object."
266266
)
267267
else:
268-
instance = Instance(
268+
cache = RefreshAheadCache(
269269
instance_connection_string,
270270
self._client,
271271
self._keys,
272272
enable_iam_auth,
273273
)
274-
self._instances[instance_connection_string] = instance
274+
self._cache[instance_connection_string] = cache
275275

276276
connect_func = {
277277
"pymysql": pymysql.connect,
@@ -300,7 +300,7 @@ async def connect_async(
300300

301301
# attempt to make connection to Cloud SQL instance
302302
try:
303-
instance_data, ip_address = await instance.connect_info(ip_type)
303+
instance_data, ip_address = await cache.connect_info(ip_type)
304304
# resolve DNS name into IP address for PSC
305305
if ip_type.value == "PSC":
306306
addr_info = await self._loop.getaddrinfo(
@@ -339,7 +339,7 @@ async def connect_async(
339339

340340
except Exception:
341341
# with any exception, we attempt a force refresh, then throw the error
342-
await instance.force_refresh()
342+
await cache.force_refresh()
343343
raise
344344

345345
def __enter__(self) -> Any:
@@ -385,11 +385,9 @@ def close(self) -> None:
385385
self._thread.join()
386386

387387
async def close_async(self) -> None:
388-
"""Helper function to cancel Instances' tasks
388+
"""Helper function to cancel the cache's tasks
389389
and close aiohttp.ClientSession."""
390-
await asyncio.gather(
391-
*[instance.close() for instance in self._instances.values()]
392-
)
390+
await asyncio.gather(*[cache.close() for cache in self._cache.values()])
393391
if self._client:
394392
await self._client.close()
395393

google/cloud/sql/connector/instance.py

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -144,32 +144,13 @@ def get_preferred_ip(self, ip_type: IPTypes) -> str:
144144
)
145145

146146

147-
class Instance:
148-
"""A class to manage the details of the connection to a Cloud SQL
149-
instance, including refreshing the credentials.
150-
151-
:param instance_connection_string:
152-
The Google Cloud SQL Instance's connection
153-
string.
154-
:type instance_connection_string: str
155-
156-
:param enable_iam_auth
157-
Enables automatic IAM database authentication for Postgres or MySQL
158-
instances.
159-
:type enable_iam_auth: bool
160-
"""
161-
162-
_enable_iam_auth: bool
163-
_keys: asyncio.Future
164-
_instance_connection_string: str
165-
_instance: str
166-
_project: str
167-
_region: str
147+
class RefreshAheadCache:
148+
"""Cache that refreshes connection info in the background prior to expiration.
168149
169-
_refresh_rate_limiter: AsyncRateLimiter
170-
_refresh_in_progress: asyncio.locks.Event
171-
_current: asyncio.Task # task wraps coroutine that returns ConnectionInfo
172-
_next: asyncio.Task # task wraps coroutine that returns another task
150+
Background tasks are used to schedule refresh attempts to get a new
151+
ephemeral certificate and Cloud SQL metadata (IP addresses, etc.) ahead of
152+
expiration.
153+
"""
173154

174155
def __init__(
175156
self,
@@ -178,6 +159,18 @@ def __init__(
178159
keys: asyncio.Future,
179160
enable_iam_auth: bool = False,
180161
) -> None:
162+
"""Initializes a RefreshAheadCache instance.
163+
164+
Args:
165+
instance_connection_string (str): The Cloud SQL Instance's
166+
connection string (also known as an instance connection name).
167+
client (CloudSQLClient): The Cloud SQL Client instance.
168+
keys (asyncio.Future): A future to the client's public-private key
169+
pair.
170+
enable_iam_auth (bool): Enables automatic IAM database authentication
171+
(Postgres and MySQL) as the default authentication method for all
172+
connections.
173+
"""
181174
# validate and parse instance connection name
182175
self._project, self._region, self._instance = _parse_instance_connection_name(
183176
instance_connection_string
@@ -192,8 +185,8 @@ def __init__(
192185
rate=1 / 30,
193186
)
194187
self._refresh_in_progress = asyncio.locks.Event()
195-
self._current = self._schedule_refresh(0)
196-
self._next = self._current
188+
self._current: asyncio.Task = self._schedule_refresh(0)
189+
self._next: asyncio.Task = self._current
197190

198191
async def force_refresh(self) -> None:
199192
"""
@@ -211,10 +204,11 @@ async def _perform_refresh(self) -> ConnectionInfo:
211204
"""Retrieves instance metadata and ephemeral certificate from the
212205
Cloud SQL Instance.
213206
214-
:rtype: ConnectionInfo
215-
:returns: A dataclass containing a string representing the ephemeral certificate, a dict
216-
containing the instances IP adresses, a string representing a PEM-encoded private key
217-
and a string representing a PEM-encoded certificate authority.
207+
Returns:
208+
A ConnectionInfo instance containing a string representing the
209+
ephemeral certificate, a dict containing the instances IP adresses,
210+
a string representing a PEM-encoded private key and a string
211+
representing a PEM-encoded certificate authority.
218212
"""
219213
self._refresh_in_progress.set()
220214
logger.debug(
@@ -290,15 +284,14 @@ def _schedule_refresh(self, delay: int) -> asyncio.Task:
290284
"""
291285
Schedule task to sleep and then perform refresh to get ConnectionInfo.
292286
293-
:type delay: int
294-
:param delay
295-
Time in seconds to sleep before running _perform_refresh.
287+
Args:
288+
delay (int): Time in seconds to sleep before performing a refresh.
296289
297-
:rtype: asyncio.Task
298-
:returns: A Task representing the scheduled _perform_refresh.
290+
Returns:
291+
An asyncio.Task representing the scheduled refresh.
299292
"""
300293

301-
async def _refresh_task(self: Instance, delay: int) -> ConnectionInfo:
294+
async def _refresh_task(self: RefreshAheadCache, delay: int) -> ConnectionInfo:
302295
"""
303296
A coroutine that sleeps for the specified amount of time before
304297
running _perform_refresh.
@@ -349,16 +342,14 @@ async def connect_info(
349342
"""Retrieve instance metadata and ip address required
350343
for making connection to Cloud SQL instance.
351344
352-
:type ip_type: IPTypes
353-
:param ip_type: Enum specifying whether to look for public
354-
or private IP address.
355-
356-
:rtype instance_data: ConnectionInfo
357-
:returns: Instance metadata for Cloud SQL instance.
345+
Args:
346+
ip_type (IPTypes): Enum specifying type of IP address to lookup and
347+
use for connection.
358348
359-
:rtype ip_address: str
360-
:returns: A string representing the IP address of
361-
the given Cloud SQL instance.
349+
Returns:
350+
A tuple with the first item being the ConnectionInfo instance for
351+
establishing the connection, and the second item being the IP
352+
address of the Cloud SQL instance matching the specified IP type.
362353
"""
363354
logger.debug(
364355
f"['{self._instance_connection_string}']: Entered connect_info method"

tests/conftest.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from unit.mocks import FakeCSQLInstance # type: ignore
2626

2727
from google.cloud.sql.connector.client import CloudSQLClient
28-
from google.cloud.sql.connector.instance import Instance
28+
from google.cloud.sql.connector.instance import RefreshAheadCache
2929
from google.cloud.sql.connector.utils import generate_keys
3030

3131
SCOPES = ["https://www.googleapis.com/auth/sqlservice.admin"]
@@ -137,12 +137,12 @@ async def fake_client(
137137

138138

139139
@pytest.fixture
140-
async def instance(fake_client: CloudSQLClient) -> AsyncGenerator[Instance, None]:
140+
async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache, None]:
141141
keys = asyncio.create_task(generate_keys())
142-
instance = Instance(
142+
cache = RefreshAheadCache(
143143
"test-project:test-region:test-instance",
144144
client=fake_client,
145145
keys=keys,
146146
)
147-
yield instance
148-
await instance.close()
147+
yield cache
148+
await cache.close()

tests/system/test_connector_object.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ def test_multiple_connectors() -> None:
8383
conn.execute(sqlalchemy.text("SELECT 1"))
8484

8585
instance_connection_string = os.environ["MYSQL_CONNECTION_NAME"]
86-
assert instance_connection_string in first_connector._instances
87-
assert instance_connection_string in second_connector._instances
86+
assert instance_connection_string in first_connector._cache
87+
assert instance_connection_string in second_connector._cache
8888
assert (
89-
first_connector._instances[instance_connection_string]
90-
!= second_connector._instances[instance_connection_string]
89+
first_connector._cache[instance_connection_string]
90+
!= second_connector._cache[instance_connection_string]
9191
)
9292
except Exception as e:
9393
logging.exception("Failed to connect with multiple Connector objects!", e)

tests/unit/test_connector.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,18 +25,18 @@
2525
from google.cloud.sql.connector import IPTypes
2626
from google.cloud.sql.connector.client import CloudSQLClient
2727
from google.cloud.sql.connector.exceptions import ConnectorLoopError
28-
from google.cloud.sql.connector.instance import Instance
28+
from google.cloud.sql.connector.instance import RefreshAheadCache
2929

3030

3131
def test_connect_enable_iam_auth_error(
32-
fake_credentials: Credentials, instance: Instance
32+
fake_credentials: Credentials, cache: RefreshAheadCache
3333
) -> None:
3434
"""Test that calling connect() with different enable_iam_auth
3535
argument values throws error."""
3636
connect_string = "test-project:test-region:test-instance"
3737
connector = Connector(credentials=fake_credentials)
38-
# set instance
39-
connector._instances[connect_string] = instance
38+
# set cache
39+
connector._cache[connect_string] = cache
4040
# try to connect using enable_iam_auth=True, should raise error
4141
with pytest.raises(ValueError) as exc_info:
4242
connector.connect(connect_string, "pg8000", enable_iam_auth=True)
@@ -46,8 +46,8 @@ def test_connect_enable_iam_auth_error(
4646
"If you require both for your use case, please use a new "
4747
"connector.Connector object."
4848
)
49-
# remove instance to avoid destructor warnings
50-
connector._instances = {}
49+
# remove cache entrry to avoid destructor warnings
50+
connector._cache = {}
5151

5252

5353
def test_connect_with_unsupported_driver(fake_credentials: Credentials) -> None:

0 commit comments

Comments
 (0)