Skip to content

Commit 964deb0

Browse files
feat: wrap generate_keys() in future (#168)
1 parent 5d95ee2 commit 964deb0

File tree

5 files changed

+17
-18
lines changed

5 files changed

+17
-18
lines changed

google/cloud/alloydb/connector/connector.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ def __init__(
6767
# otherwise use application default credentials
6868
else:
6969
self._credentials, _ = default(scopes=scopes)
70-
self._keys = generate_keys()
70+
self._keys = asyncio.wrap_future(
71+
asyncio.run_coroutine_threadsafe(generate_keys(), self._loop),
72+
loop=self._loop,
73+
)
7174
self._client: Optional[AlloyDBClient] = None
7275

7376
def connect(self, instance_uri: str, driver: str, **kwargs: Any) -> Any:
@@ -123,11 +126,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
123126
if instance_uri in self._instances:
124127
instance = self._instances[instance_uri]
125128
else:
126-
instance = Instance(
127-
instance_uri,
128-
self._client,
129-
self._keys,
130-
)
129+
instance = Instance(instance_uri, self._client, self._keys)
131130
self._instances[instance_uri] = instance
132131

133132
connect_func = {

google/cloud/alloydb/connector/instance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353
self,
5454
instance_uri: str,
5555
client: AlloyDBClient,
56-
keys: Tuple[rsa.RSAPrivateKey, str],
56+
keys: asyncio.Future[Tuple[rsa.RSAPrivateKey, str]],
5757
) -> None:
5858
# validate and parse instance_uri
5959
instance_uri_split = instance_uri.split("/")
@@ -98,7 +98,7 @@ async def _perform_refresh(self) -> RefreshResult:
9898

9999
try:
100100
await self._refresh_rate_limiter.acquire()
101-
priv_key, pub_key = self._keys
101+
priv_key, pub_key = await self._keys
102102
# fetch metadata
103103
metadata_task = asyncio.create_task(
104104
self._client._get_metadata(

google/cloud/alloydb/connector/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _write_to_file(
4747
return (ca_filename, cert_chain_filename, key_filename)
4848

4949

50-
def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]:
50+
async def generate_keys() -> Tuple[rsa.RSAPrivateKey, str]:
5151
priv_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
5252
pub_key = (
5353
priv_key.public_key()

tests/unit/test_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ async def test__get_client_certificate(
7777
Test _get_client_certificate returns successfully.
7878
"""
7979
test_client = AlloyDBClient("", "", credentials, client)
80-
keys = generate_keys()
80+
keys = await generate_keys()
8181
certs = await test_client._get_client_certificate(
8282
"test-project", "test-region", "test-cluster", keys[1]
8383
)

tests/unit/test_instance.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ async def test_Instance_init() -> None:
3131
Test to check whether the __init__ method of Instance
3232
can tell if the instance URI that's passed in is formatted correctly.
3333
"""
34-
keys = generate_keys()
34+
keys = asyncio.create_task(generate_keys())
3535
async with aiohttp.ClientSession() as client:
3636
instance = Instance(
3737
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
@@ -52,7 +52,7 @@ async def test_Instance_init_invalid_instant_uri() -> None:
5252
Test to check whether the __init__ method of Instance
5353
will throw error for invalid instance URI.
5454
"""
55-
keys = generate_keys()
55+
keys = asyncio.create_task(generate_keys())
5656
async with aiohttp.ClientSession() as client:
5757
with pytest.raises(ValueError):
5858
Instance("invalid/instance/uri/", client, keys)
@@ -64,7 +64,7 @@ async def test_Instance_close() -> None:
6464
Test that Instance's close method
6565
cancels tasks gracefully.
6666
"""
67-
keys = generate_keys()
67+
keys = asyncio.create_task(generate_keys())
6868
client = FakeAlloyDBClient()
6969
instance = Instance(
7070
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
@@ -84,7 +84,7 @@ async def test_Instance_close() -> None:
8484
@pytest.mark.asyncio
8585
async def test_perform_refresh() -> None:
8686
"""Test that _perform refresh returns valid RefreshResult"""
87-
keys = generate_keys()
87+
keys = asyncio.create_task(generate_keys())
8888
client = FakeAlloyDBClient()
8989
instance = Instance(
9090
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
@@ -104,7 +104,7 @@ async def test_schedule_refresh_replaces_result() -> None:
104104
Test to check whether _schedule_refresh replaces a valid refresh result
105105
with another refresh result.
106106
"""
107-
keys = generate_keys()
107+
keys = asyncio.create_task(generate_keys())
108108
client = FakeAlloyDBClient()
109109
instance = Instance(
110110
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
@@ -131,7 +131,7 @@ async def test_schedule_refresh_wont_replace_valid_result_with_invalid() -> None
131131
Test to check whether _schedule_refresh won't replace a valid
132132
refresh result with an invalid one.
133133
"""
134-
keys = generate_keys()
134+
keys = asyncio.create_task(generate_keys())
135135
client = FakeAlloyDBClient()
136136
instance = Instance(
137137
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
@@ -160,7 +160,7 @@ async def test_schedule_refresh_expired_cert() -> None:
160160
Test to check whether _schedule_refresh will throw RefreshError on
161161
expired certificate.
162162
"""
163-
keys = generate_keys()
163+
keys = asyncio.create_task(generate_keys())
164164
client = FakeAlloyDBClient()
165165
# set certificate to be expired
166166
client.instance.cert_before = datetime.now() - timedelta(minutes=20)
@@ -182,7 +182,7 @@ async def test_force_refresh_cancels_pending_refresh() -> None:
182182
"""
183183
Test that force_refresh cancels pending task if refresh_in_progress event is not set.
184184
"""
185-
keys = generate_keys()
185+
keys = asyncio.create_task(generate_keys())
186186
client = FakeAlloyDBClient()
187187
instance = Instance(
188188
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",

0 commit comments

Comments
 (0)