Skip to content

Commit 3020037

Browse files
authored
chore: refactor perform_refresh and add tests (#144)
* chore(tests): refactor perform_refresh and add tests * chore: move stubs to outside of tests * use run_coroutine_threadsafe to call _perform_refresh in tests * linting
1 parent 108833e commit 3020037

File tree

3 files changed

+99
-32
lines changed

3 files changed

+99
-32
lines changed

google/cloud/sql/connector/instance_connection_manager.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -383,35 +383,32 @@ async def _perform_refresh(self) -> asyncio.Task:
383383

384384
refresh_task = self._loop.create_task(self._get_instance_data())
385385

386-
def _refresh_callback(task: asyncio.Task) -> None:
386+
try:
387+
await refresh_task
388+
except Exception as e:
389+
logger.exception(
390+
"An error occurred while performing refresh. Retrying in 60s.",
391+
exc_info=e,
392+
)
393+
instance_data = None
387394
try:
388-
task.result()
389-
except Exception as e:
390-
logger.exception(
391-
"An error occurred while performing refresh. Retrying in 60s.",
392-
exc_info=e,
393-
)
394-
instance_data = None
395-
try:
396-
instance_data = self._current.result()
397-
except Exception:
398-
# Current result is invalid, no-op
399-
logger.debug("Current instance data is invalid.")
400-
if (
401-
instance_data is None
402-
or instance_data.expiration < datetime.datetime.now()
403-
):
404-
self._current = task
405-
# TODO: Implement force refresh method and a rate-limiter for perform_refresh
406-
# Retry by scheduling a refresh 60s from now.
407-
self._next = self._loop.create_task(self._schedule_refresh(60))
408-
409-
else:
395+
instance_data = await self._current
396+
except Exception:
397+
# Current result is invalid, no-op
398+
logger.debug("Current instance data is invalid.")
399+
if (
400+
instance_data is None
401+
or instance_data.expiration < datetime.datetime.now()
402+
):
410403
self._current = refresh_task
411-
# Ephemeral certificate expires in 1 hour, so we schedule a refresh to happen in 55 minutes.
412-
self._next = self._loop.create_task(self._schedule_refresh())
404+
# TODO: Implement force refresh method and a rate-limiter for perform_refresh
405+
# Retry by scheduling a refresh 60s from now.
406+
self._next = self._loop.create_task(self._schedule_refresh(60))
413407

414-
refresh_task.add_done_callback(_refresh_callback)
408+
else:
409+
self._current = refresh_task
410+
# Ephemeral certificate expires in 1 hour, so we schedule a refresh to happen in 55 minutes.
411+
self._next = self._loop.create_task(self._schedule_refresh())
415412

416413
return refresh_task
417414

noxfile.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def default(session, path):
6565
"py.test",
6666
# "--cov=util",
6767
# "--cov=connector",
68+
"-v",
6869
"--cov-append",
6970
"--cov-config=.coveragerc",
7071
"--cov-report=",

tests/unit/test_instance_connection_manager.py

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"""
1616

1717
import asyncio
18+
import datetime
19+
from typing import Any
1820
import pytest # noqa F401 Needed to run the tests
1921
from google.cloud.sql.connector.instance_connection_manager import (
2022
InstanceConnectionManager,
@@ -32,6 +34,19 @@ def icm(
3234
return icm
3335

3436

37+
class MockMetadata:
38+
def __init__(self, expiration: datetime.datetime) -> None:
39+
self.expiration = expiration
40+
41+
42+
async def _get_metadata_success(*args: Any, **kwargs: Any) -> MockMetadata:
43+
return MockMetadata(datetime.datetime.now() + datetime.timedelta(minutes=10))
44+
45+
46+
async def _get_metadata_error(*args: Any, **kwargs: Any) -> None:
47+
raise Exception("something went wrong...")
48+
49+
3550
def test_InstanceConnectionManager_init(async_loop: asyncio.AbstractEventLoop) -> None:
3651
"""
3752
Test to check whether the __init__ method of InstanceConnectionManager
@@ -53,13 +68,67 @@ def test_InstanceConnectionManager_init(async_loop: asyncio.AbstractEventLoop) -
5368

5469

5570
@pytest.mark.asyncio
56-
async def test_InstanceConnectionManager_perform_refresh(
57-
icm: InstanceConnectionManager, async_loop: asyncio.AbstractEventLoop
71+
async def test_perform_refresh_replaces_result(icm: InstanceConnectionManager) -> None:
72+
"""
73+
Test to check whether _perform_refresh replaces a valid result with another valid result
74+
"""
75+
76+
# stub _get_instance_data to return a "valid" MockMetadata object
77+
setattr(icm, "_get_instance_data", _get_metadata_success)
78+
new_task = asyncio.run_coroutine_threadsafe(
79+
icm._perform_refresh(), icm._loop
80+
).result(timeout=10)
81+
82+
assert icm._current == new_task
83+
assert isinstance(icm._current.result(), MockMetadata)
84+
85+
86+
@pytest.mark.asyncio
87+
async def test_perform_refresh_wont_replace_valid_result_with_invalid(
88+
icm: InstanceConnectionManager,
5889
) -> None:
5990
"""
60-
Test to check whether _perform_refresh works as described given valid
61-
conditions.
91+
Test to check whether _perform_refresh won't replace a valid _current
92+
value with an invalid one
6293
"""
63-
task = await icm._perform_refresh()
6494

65-
assert isinstance(task, asyncio.Task)
95+
# stub _get_instance_data to return a "valid" MockMetadata object
96+
setattr(icm, "_get_instance_data", _get_metadata_success)
97+
icm._current = asyncio.run_coroutine_threadsafe(
98+
icm._perform_refresh(), icm._loop
99+
).result(timeout=10)
100+
old_task = icm._current
101+
102+
# stub _get_instance_data to throw an error, then await _perform_refresh
103+
setattr(icm, "_get_instance_data", _get_metadata_error)
104+
asyncio.run_coroutine_threadsafe(icm._perform_refresh(), icm._loop).result(
105+
timeout=10
106+
)
107+
108+
assert icm._current == old_task
109+
assert isinstance(icm._current.result(), MockMetadata)
110+
111+
112+
@pytest.mark.asyncio
113+
async def test_perform_refresh_replaces_invalid_result(
114+
icm: InstanceConnectionManager,
115+
) -> None:
116+
"""
117+
Test to check whether _perform_refresh will replace an invalid refresh result with
118+
a valid one
119+
"""
120+
121+
# stub _get_instance_data to throw an error
122+
setattr(icm, "_get_instance_data", _get_metadata_error)
123+
icm._current = asyncio.run_coroutine_threadsafe(
124+
icm._perform_refresh(), icm._loop
125+
).result(timeout=10)
126+
127+
# stub _get_instance_data to return a MockMetadata instance
128+
setattr(icm, "_get_instance_data", _get_metadata_success)
129+
new_task = asyncio.run_coroutine_threadsafe(
130+
icm._perform_refresh(), icm._loop
131+
).result(timeout=10)
132+
133+
assert icm._current == new_task
134+
assert isinstance(icm._current.result(), MockMetadata)

0 commit comments

Comments
 (0)