Skip to content

Commit b390fac

Browse files
authored
feat: add rate limiter and force refresh function (#146)
* add basic rate limiter using event queue
* add unit tests
* use rate limiter in perform_refresh
* add force refresh method
* use faster rate limiter for perform_refresh tests
* remove initial delay in rate limiter
* call force refresh when connect attempt fails
* use token bucket algorithm
* address review comments
* use asyncio event to indicate when refresh is in progress
* use semaphore instead of queue in rate limiter
* address review comments
* run black
* update type annotations
* address review comments
* use asyncio time instead of time.time()
* update rate limiter implementation
* update docstrings
* add docstring for force_refresh
1 parent 2c35c06 commit b390fac

File tree

5 files changed

+244
-15
lines changed

5 files changed

+244
-15
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616
import asyncio
1717
import concurrent
18+
import logging
1819
from google.cloud.sql.connector.instance_connection_manager import (
1920
InstanceConnectionManager,
2021
IPTypes,
@@ -31,6 +32,8 @@
3132

3233
_instances: Dict[str, InstanceConnectionManager] = {}
3334

35+
logger = logging.getLogger(name=__name__)
36+
3437

3538
def _get_loop() -> asyncio.AbstractEventLoop:
3639
global _loop
@@ -112,5 +115,9 @@ def connect(
112115
timeout = kwargs["connect_timeout"]
113116
else:
114117
timeout = 30 # 30s
115-
116-
return icm.connect(driver, ip_types, timeout, **kwargs)
118+
try:
119+
return icm.connect(driver, ip_types, timeout, **kwargs)
120+
except Exception as e:
121+
# with any other exception, we attempt a force refresh, then throw the error
122+
icm.force_refresh()
123+
raise (e)

google/cloud/sql/connector/instance_connection_manager.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616

1717
# Custom utils import
18+
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
1819
from google.cloud.sql.connector.refresh_utils import _get_ephemeral, _get_metadata
1920
from google.cloud.sql.connector.utils import write_to_file
2021
from google.cloud.sql.connector.version import __version__ as version
@@ -218,8 +219,9 @@ def _client_session(self) -> aiohttp.ClientSession:
218219
_project: str
219220
_region: str
220221

221-
_current: asyncio.Task
222-
_next: asyncio.Task
222+
_refresh_in_progress: asyncio.locks.Event
223+
_current: asyncio.Task # task wraps coroutine that returns InstanceMetadata
224+
_next: asyncio.Task # task wraps coroutine that returns another task
223225

224226
def __init__(
225227
self,
@@ -250,8 +252,13 @@ def __init__(
250252
self._keys = asyncio.wrap_future(keys, loop=self._loop)
251253
self._auth_init()
252254

255+
self._refresh_rate_limiter = AsyncRateLimiter(
256+
max_capacity=2, rate=1 / 30, loop=self._loop
257+
)
258+
253259
async def _set_instance_data() -> None:
254260
logger.debug("Updating instance data")
261+
self._refresh_in_progress = asyncio.locks.Event(loop=self._loop)
255262
self._current = self._loop.create_task(self._get_instance_data())
256263
self._next = self._loop.create_task(self._schedule_refresh())
257264

@@ -350,6 +357,35 @@ def _auth_init(self) -> None:
350357

351358
self._credentials = credentials
352359

360+
async def _force_refresh(self) -> bool:
    """
    Coroutine that attempts an immediate refresh of the instance metadata.

    :rtype: bool
    :returns: True if a refresh completed (either the already-running one
        or a newly scheduled immediate one); False if the forced refresh
        attempt raised an exception (the error is logged, not re-raised).
    """
    if self._refresh_in_progress.is_set():
        # if a new refresh is already in progress, then block on the result
        self._current = await self._next
        return True
    try:
        # cancel the pending (delayed) refresh task so it doesn't fire later
        self._next.cancel()
        # schedule a refresh immediately with no delay
        self._next = self._loop.create_task(self._schedule_refresh(0))
        self._current = await self._next
        return True
    except Exception as e:
        # if anything else goes wrong, log the error and return false
        logger.exception("Error occurred during force refresh attempt", exc_info=e)
        return False
375+
376+
def force_refresh(self, timeout: Optional[int] = None) -> bool:
    """
    Forces a new refresh attempt and returns a boolean value that indicates
    whether the attempt was successful.

    :type timeout: Optional[int]
    :param timeout: Amount of time to wait for the attempted force refresh
        to complete before throwing a timeout error.
    """
    # Submit the refresh coroutine onto the manager's event loop from the
    # (possibly synchronous) calling thread, then block until it finishes
    # or the timeout elapses.
    refresh_future = asyncio.run_coroutine_threadsafe(
        self._force_refresh(), self._loop
    )
    return refresh_future.result(timeout=timeout)
388+
353389
async def seconds_until_refresh(self) -> int:
354390
expiration = (await self._current).expiration
355391

@@ -378,7 +414,8 @@ async def _perform_refresh(self) -> asyncio.Task:
378414
:rtype: concurrent.future.Futures
379415
:returns: A future representing the creation of an SSLcontext.
380416
"""
381-
417+
self._refresh_in_progress.set()
418+
await self._refresh_rate_limiter.acquire()
382419
logger.debug("Entered _perform_refresh")
383420

384421
refresh_task = self._loop.create_task(self._get_instance_data())
@@ -387,7 +424,8 @@ async def _perform_refresh(self) -> asyncio.Task:
387424
await refresh_task
388425
except Exception as e:
389426
logger.exception(
390-
"An error occurred while performing refresh. Retrying in 60s.",
427+
"An error occurred while performing refresh."
428+
"Scheduling another refresh attempt immediately",
391429
exc_info=e,
392430
)
393431
instance_data = None
@@ -401,14 +439,14 @@ async def _perform_refresh(self) -> asyncio.Task:
401439
or instance_data.expiration < datetime.datetime.now()
402440
):
403441
self._current = refresh_task
404-
# TODO: Implement force refresh method and a rate-limiter for perform_refresh
405-
# Retry by scheduling a refresh 60s from now.
406-
self._next = self._loop.create_task(self._schedule_refresh(60))
442+
self._next = self._loop.create_task(self._perform_refresh())
407443

408444
else:
409445
self._current = refresh_task
410446
# Ephemeral certificate expires in 1 hour, so we schedule a refresh to happen in 55 minutes.
411447
self._next = self._loop.create_task(self._schedule_refresh())
448+
finally:
449+
self._refresh_in_progress.clear()
412450

413451
return refresh_task
414452

@@ -419,17 +457,15 @@ async def _schedule_refresh(self, delay: Optional[int] = None) -> asyncio.Task:
419457
:rtype: asyncio.Task
420458
:returns: A Task representing _get_instance_data.
421459
"""
422-
logger.debug("Entering sleep")
423460

424461
if delay is None:
425462
delay = await self.seconds_until_refresh()
426-
427463
try:
464+
logger.debug("Entering sleep")
428465
await asyncio.sleep(delay)
429466
except asyncio.CancelledError as e:
430467
logger.debug("Schedule refresh task cancelled.")
431468
raise e
432-
433469
return await self._perform_refresh()
434470

435471
def connect(
google/cloud/sql/connector/rate_limiter.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""
2+
Copyright 2021 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
import asyncio
17+
18+
19+
class AsyncRateLimiter(object):
    """
    An asyncio-compatible rate limiter which uses the Token Bucket algorithm
    (https://en.wikipedia.org/wiki/Token_bucket) to limit the number of
    function calls over a time interval.

    :type max_capacity: int
    :param max_capacity:
        The maximum capacity of tokens the bucket will store at any one time.
        Default: 1

    :type rate: float
    :param rate:
        The number of tokens that should be added per second.

    :type loop: asyncio.AbstractEventLoop
    :param loop:
        The event loop used for time measurements. If not provided, the
        default event loop will be used.
    """

    def __init__(
        self,
        max_capacity: int = 1,
        rate: float = 1 / 60,
        loop: asyncio.AbstractEventLoop = None,
    ) -> None:
        self.rate = rate
        self.max_capacity = max_capacity
        self._loop = loop or asyncio.get_event_loop()
        # NOTE: the explicit ``loop=`` argument to asyncio.Lock() and
        # asyncio.sleep() was deprecated in Python 3.8 and removed in 3.10;
        # the lock (and the sleep in _wait_for_next_token) bind to the
        # running loop instead.
        self._lock = asyncio.Lock()
        # Start with a full bucket so the first max_capacity acquisitions
        # go through without any delay.
        self._tokens: float = max_capacity
        self._last_token_update = self._loop.time()

    def _update_token_count(self) -> None:
        """
        Lazily refill the bucket: add ``rate`` tokens for every second
        elapsed since the last update, capped at ``max_capacity``.

        Because refilling is lazy, a long idle gap can allow a burst of up
        to ``max_capacity`` immediate acquisitions afterwards.
        """
        now = self._loop.time()
        time_elapsed = now - self._last_token_update
        new_tokens = time_elapsed * self.rate
        self._tokens = min(new_tokens + self._tokens, self.max_capacity)
        self._last_token_update = now

    async def _wait_for_next_token(self) -> None:
        """
        Sleep just long enough for the bucket to refill to one full token.
        """
        token_deficit = 1 - self._tokens
        if token_deficit > 0:
            wait_time = token_deficit / self.rate
            await asyncio.sleep(wait_time)

    async def acquire(self) -> None:
        """
        Waits for a token to become available, if necessary, then subtracts
        a token and allows the request to go through.

        The lock serializes concurrent callers, so waiters proceed one at a
        time in the order they acquired the lock.
        """
        async with self._lock:
            self._update_token_count()
            if self._tokens < 1:
                await self._wait_for_next_token()
                self._update_token_count()
            self._tokens -= 1

tests/unit/test_instance_connection_manager.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import asyncio
1818
import datetime
19+
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
1920
from typing import Any
2021
import pytest # noqa F401 Needed to run the tests
2122
from google.cloud.sql.connector.instance_connection_manager import (
@@ -34,6 +35,11 @@ def icm(
3435
return icm
3536

3637

38+
@pytest.fixture
def test_rate_limiter(async_loop: asyncio.AbstractEventLoop) -> AsyncRateLimiter:
    # Rate limiter allowing one call every 2 seconds — much faster than the
    # production limiter so refresh tests don't have to wait 30s per token.
    return AsyncRateLimiter(max_capacity=1, rate=1 / 2, loop=async_loop)
41+
42+
3743
class MockMetadata:
3844
def __init__(self, expiration: datetime.datetime) -> None:
3945
self.expiration = expiration
@@ -68,10 +74,14 @@ def test_InstanceConnectionManager_init(async_loop: asyncio.AbstractEventLoop) -
6874

6975

7076
@pytest.mark.asyncio
71-
async def test_perform_refresh_replaces_result(icm: InstanceConnectionManager) -> None:
77+
async def test_perform_refresh_replaces_result(
78+
icm: InstanceConnectionManager, test_rate_limiter: AsyncRateLimiter
79+
) -> None:
7280
"""
7381
Test to check whether _perform_refresh replaces a valid result with another valid result
7482
"""
83+
# allow more frequent refreshes for tests
84+
setattr(icm, "_refresh_rate_limiter", test_rate_limiter)
7585

7686
# stub _get_instance_data to return a "valid" MockMetadata object
7787
setattr(icm, "_get_instance_data", _get_metadata_success)
@@ -85,12 +95,14 @@ async def test_perform_refresh_replaces_result(icm: InstanceConnectionManager) -
8595

8696
@pytest.mark.asyncio
8797
async def test_perform_refresh_wont_replace_valid_result_with_invalid(
88-
icm: InstanceConnectionManager,
98+
icm: InstanceConnectionManager, test_rate_limiter: AsyncRateLimiter
8999
) -> None:
90100
"""
91101
Test to check whether _perform_refresh won't replace a valid _current
92102
value with an invalid one
93103
"""
104+
# allow more frequent refreshes for tests
105+
setattr(icm, "_refresh_rate_limiter", test_rate_limiter)
94106

95107
# stub _get_instance_data to return a "valid" MockMetadata object
96108
setattr(icm, "_get_instance_data", _get_metadata_success)
@@ -111,12 +123,14 @@ async def test_perform_refresh_wont_replace_valid_result_with_invalid(
111123

112124
@pytest.mark.asyncio
113125
async def test_perform_refresh_replaces_invalid_result(
114-
icm: InstanceConnectionManager,
126+
icm: InstanceConnectionManager, test_rate_limiter: AsyncRateLimiter
115127
) -> None:
116128
"""
117129
Test to check whether _perform_refresh will replace an invalid refresh result with
118130
a valid one
119131
"""
132+
# allow more frequent refreshes for tests
133+
setattr(icm, "_refresh_rate_limiter", test_rate_limiter)
120134

121135
# stub _get_instance_data to throw an error
122136
setattr(icm, "_get_instance_data", _get_metadata_error)
@@ -132,3 +146,28 @@ async def test_perform_refresh_replaces_invalid_result(
132146

133147
assert icm._current == new_task
134148
assert isinstance(icm._current.result(), MockMetadata)
149+
150+
151+
@pytest.mark.asyncio
async def test_force_refresh_cancels_pending_refresh(
    icm: InstanceConnectionManager,
    test_rate_limiter: AsyncRateLimiter,
) -> None:
    """
    Test that force_refresh cancels pending task if refresh_in_progress event is not set.
    """
    # allow more frequent refreshes for tests
    setattr(icm, "_refresh_rate_limiter", test_rate_limiter)

    # stub _get_instance_data to return a MockMetadata instance
    setattr(icm, "_get_instance_data", _get_metadata_success)

    # since the pending refresh isn't for another 55 min, the refresh_in_progress event
    # shouldn't be set
    pending_refresh = icm._next
    assert icm._refresh_in_progress.is_set() is False

    icm.force_refresh()

    # the delayed refresh task should have been cancelled and replaced by an
    # immediate one whose (successful) result is now in _current
    assert pending_refresh.cancelled() is True
    assert isinstance(icm._current.result(), MockMetadata)

0 commit comments

Comments
 (0)