Skip to content

Commit 5712e9e

Browse files
chore: lazy init keys with lazy refresh (#1110)
RSA key-pair generation should be done on first connection attempt when lazy refresh strategy is configured.
1 parent deb732e commit 5712e9e

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,20 +108,32 @@ def __init__(
108108
RefreshStrategy.BACKGROUND ("BACKGROUND").
109109
Default: RefreshStrategy.BACKGROUND
110110
"""
111+
# if refresh_strategy is str, convert to RefreshStrategy enum
112+
if isinstance(refresh_strategy, str):
113+
refresh_strategy = RefreshStrategy._from_str(refresh_strategy)
114+
self._refresh_strategy = refresh_strategy
111115
# if event loop is given, use for background tasks
112116
if loop:
113117
self._loop: asyncio.AbstractEventLoop = loop
114118
self._thread: Optional[Thread] = None
115-
self._keys: asyncio.Future = loop.create_task(generate_keys())
119+
# if lazy refresh is specified we should lazy init keys
120+
if self._refresh_strategy == RefreshStrategy.LAZY:
121+
self._keys: Optional[asyncio.Future] = None
122+
else:
123+
self._keys = loop.create_task(generate_keys())
116124
# if no event loop is given, spin up new loop in background thread
117125
else:
118126
self._loop = asyncio.new_event_loop()
119127
self._thread = Thread(target=self._loop.run_forever, daemon=True)
120128
self._thread.start()
121-
self._keys = asyncio.wrap_future(
122-
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
123-
loop=self._loop,
124-
)
129+
# if lazy refresh is specified we should lazy init keys
130+
if self._refresh_strategy == RefreshStrategy.LAZY:
131+
self._keys = None
132+
else:
133+
self._keys = asyncio.wrap_future(
134+
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
135+
loop=self._loop,
136+
)
125137
self._cache: Dict[str, Union[RefreshAheadCache, LazyRefreshCache]] = {}
126138
self._client: Optional[CloudSQLClient] = None
127139

@@ -148,10 +160,6 @@ def __init__(
148160
if isinstance(ip_type, str):
149161
ip_type = IPTypes._from_str(ip_type)
150162
self._ip_type = ip_type
151-
# if refresh_strategy is str, convert to RefreshStrategy enum
152-
if isinstance(refresh_strategy, str):
153-
refresh_strategy = RefreshStrategy._from_str(refresh_strategy)
154-
self._refresh_strategy = refresh_strategy
155163
self._universe_domain = universe_domain
156164
# construct service endpoint for Cloud SQL Admin API calls
157165
if not sqladmin_api_endpoint:
@@ -258,6 +266,8 @@ async def connect_async(
258266
DnsNameResolutionError: Could not resolve PSC IP address from DNS
259267
host name.
260268
"""
269+
if self._keys is None:
270+
self._keys = asyncio.create_task(generate_keys())
261271
if self._client is None:
262272
# lazy init client as it has to be initialized in async context
263273
self._client = CloudSQLClient(

tests/unit/test_connector.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ def test_Connector_Init(fake_credentials: Credentials) -> None:
114114
connector.close()
115115

116116

117+
def test_Connector_Init_with_lazy_refresh(fake_credentials: Credentials) -> None:
118+
"""Test that Connector with lazy refresh sets keys to None."""
119+
with Connector(credentials=fake_credentials, refresh_strategy="lazy") as connector:
120+
assert connector._keys is None
121+
122+
117123
def test_Connector_Init_with_credentials(fake_credentials: Credentials) -> None:
118124
"""Test that Connector uses custom credentials when given them."""
119125
with patch(

0 commit comments

Comments
 (0)