Skip to content

Commit 8fdc7d7

Browse files
WIP: replace aiohttp.ClientSession with AlloyDBAdminAsyncClient
1 parent 9dac5b2 commit 8fdc7d7

File tree

5 files changed

+119
-186
lines changed

5 files changed

+119
-186
lines changed

google/cloud/alloydb/connector/async_connector.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from google.cloud.alloydb.connector.lazy import LazyRefreshCache
3232
from google.cloud.alloydb.connector.types import CacheTypes
3333
from google.cloud.alloydb.connector.utils import generate_keys
34+
import traceback
3435

3536
if TYPE_CHECKING:
3637
from google.auth.credentials import Credentials
@@ -182,6 +183,7 @@ async def connect(
182183
conn_info = await cache.connect_info()
183184
ip_address = conn_info.get_preferred_ip(ip_type)
184185
except Exception:
186+
print(f"RISHABH DEBUG: exception = {traceback.print_exc()}")
185187
# with an error from AlloyDB API call or IP type, invalidate the
186188
# cache and re-raise the error
187189
await self._remove_cached(instance_uri)

google/cloud/alloydb/connector/client.py

Lines changed: 64 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@
1818
import logging
1919
from typing import Optional, TYPE_CHECKING
2020

21-
import aiohttp
2221
from cryptography import x509
22+
from google.api_core.client_options import ClientOptions
23+
from google.api_core.gapic_v1.client_info import ClientInfo
2324
from google.auth.credentials import TokenState
2425
from google.auth.transport import requests
2526

27+
from google.cloud import alloydb_v1beta
2628
from google.cloud.alloydb.connector.connection_info import ConnectionInfo
2729
from google.cloud.alloydb.connector.version import __version__ as version
30+
from google.protobuf import duration_pb2
2831

2932
if TYPE_CHECKING:
3033
from google.auth.credentials import Credentials
@@ -55,7 +58,7 @@ def __init__(
5558
alloydb_api_endpoint: str,
5659
quota_project: Optional[str],
5760
credentials: Credentials,
58-
client: Optional[aiohttp.ClientSession] = None,
61+
client: Optional[alloydb_v1beta.AlloyDBAdminAsyncClient] = None,
5962
driver: Optional[str] = None,
6063
user_agent: Optional[str] = None,
6164
) -> None:
@@ -72,21 +75,23 @@ def __init__(
7275
A credentials object created from the google-auth Python library.
7376
Must have the AlloyDB Admin scopes. For more info check out
7477
https://google-auth.readthedocs.io/en/latest/.
75-
client (aiohttp.ClientSession): Async client used to make requests to
76-
AlloyDB APIs.
78+
client (alloydb_v1.AlloyDBAdminAsyncClient): Async client used to
79+
make requests to AlloyDB APIs.
7780
Optional, defaults to None and creates new client.
7881
driver (str): Database driver to be used by the client.
7982
"""
8083
user_agent = _format_user_agent(driver, user_agent)
81-
headers = {
82-
"x-goog-api-client": user_agent,
83-
"User-Agent": user_agent,
84-
"Content-Type": "application/json",
85-
}
86-
if quota_project:
87-
headers["x-goog-user-project"] = quota_project
8884

89-
self._client = client if client else aiohttp.ClientSession(headers=headers)
85+
self._client = client if client else alloydb_v1beta.AlloyDBAdminAsyncClient(
86+
credentials=credentials,
87+
client_options=ClientOptions(
88+
api_endpoint=alloydb_api_endpoint,
89+
quota_project_id=quota_project,
90+
),
91+
client_info=ClientInfo(
92+
user_agent=user_agent,
93+
),
94+
)
9095
self._credentials = credentials
9196
self._alloydb_api_endpoint = alloydb_api_endpoint
9297
# asyncpg does not currently support using metadata exchange
@@ -118,35 +123,33 @@ async def _get_metadata(
118123
Returns:
119124
dict: IP addresses of the AlloyDB instance.
120125
"""
121-
headers = {
122-
"Authorization": f"Bearer {self._credentials.token}",
123-
}
124-
125-
url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}/connectionInfo"
126-
127-
resp = await self._client.get(url, headers=headers)
128-
# try to get response json for better error message
129-
try:
130-
resp_dict = await resp.json()
131-
if resp.status >= 400:
132-
# if detailed error message is in json response, use as error message
133-
message = resp_dict.get("error", {}).get("message")
134-
if message:
135-
resp.reason = message
136-
# skip, raise_for_status will catch all errors in finally block
137-
except Exception:
138-
pass
139-
finally:
140-
resp.raise_for_status()
126+
parent = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}/instances/{name}"
127+
128+
req = alloydb_v1beta.GetConnectionInfoRequest(parent=parent)
129+
resp = await self._client.get_connection_info(request=req)
130+
resp = await resp
131+
# # try to get response json for better error message
132+
# try:
133+
# resp_dict = await resp.json()
134+
# if resp.status >= 400:
135+
# # if detailed error message is in json response, use as error message
136+
# message = resp_dict.get("error", {}).get("message")
137+
# if message:
138+
# resp.reason = message
139+
# # skip, raise_for_status will catch all errors in finally block
140+
# except Exception:
141+
# pass
142+
# finally:
143+
# resp.raise_for_status()
141144

142145
# Remove trailing period from PSC DNS name.
143-
psc_dns = resp_dict.get("pscDnsName")
146+
psc_dns = resp.psc_dns_name
144147
if psc_dns:
145148
psc_dns = psc_dns.rstrip(".")
146149

147150
return {
148-
"PRIVATE": resp_dict.get("ipAddress"),
149-
"PUBLIC": resp_dict.get("publicIpAddress"),
151+
"PRIVATE": resp.ip_address,
152+
"PUBLIC": resp.public_ip_address,
150153
"PSC": psc_dns,
151154
}
152155

@@ -175,34 +178,32 @@ async def _get_client_certificate(
175178
tuple[str, list[str]]: tuple containing the CA certificate
176179
and certificate chain for the AlloyDB instance.
177180
"""
178-
headers = {
179-
"Authorization": f"Bearer {self._credentials.token}",
180-
}
181-
182-
url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate"
183-
184-
data = {
185-
"publicKey": pub_key,
186-
"certDuration": "3600s",
187-
"useMetadataExchange": self._use_metadata,
188-
}
189-
190-
resp = await self._client.post(url, headers=headers, json=data)
191-
# try to get response json for better error message
192-
try:
193-
resp_dict = await resp.json()
194-
if resp.status >= 400:
195-
# if detailed error message is in json response, use as error message
196-
message = resp_dict.get("error", {}).get("message")
197-
if message:
198-
resp.reason = message
199-
# skip, raise_for_status will catch all errors in finally block
200-
except Exception:
201-
pass
202-
finally:
203-
resp.raise_for_status()
204-
205-
return (resp_dict["caCert"], resp_dict["pemCertificateChain"])
181+
parent = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}"
182+
dur = duration_pb2.Duration()
183+
dur.seconds = 3600
184+
req = alloydb_v1beta.GenerateClientCertificateRequest(
185+
parent=parent,
186+
cert_duration=dur,
187+
public_key=pub_key,
188+
use_metadata_exchange=self._use_metadata,
189+
)
190+
resp = await self._client.generate_client_certificate(request=req)
191+
resp = await resp
192+
# # try to get response json for better error message
193+
# try:
194+
# resp_dict = await resp.json()
195+
# if resp.status >= 400:
196+
# # if detailed error message is in json response, use as error message
197+
# message = resp_dict.get("error", {}).get("message")
198+
# if message:
199+
# resp.reason = message
200+
# # skip, raise_for_status will catch all errors in finally block
201+
# except Exception:
202+
# pass
203+
# finally:
204+
# resp.raise_for_status()
205+
206+
return (resp.ca_cert, resp.pem_certificate_chain)
206207

207208
async def get_connection_info(
208209
self,
@@ -271,5 +272,4 @@ async def get_connection_info(
271272
async def close(self) -> None:
272273
"""Close AlloyDBClient gracefully."""
273274
logger.debug("Waiting for connector's http client to close")
274-
await self._client.close()
275275
logger.debug("Closed connector's http client")

tests/unit/test_async_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
import asyncio
1616
from typing import Union
1717

18-
from aiohttp import ClientResponseError
1918
from mock import patch
2019
from mocks import FakeAlloyDBClient
2120
from mocks import FakeConnectionInfo
2221
from mocks import FakeCredentials
2322
import pytest
2423

24+
from google.api_core.exceptions import RetryError
2525
from google.cloud.alloydb.connector import AsyncConnector
2626
from google.cloud.alloydb.connector import IPTypes
2727
from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
@@ -309,7 +309,7 @@ async def test_Connector_remove_cached_bad_instance(
309309
"""
310310
instance_uri = "projects/test-project/locations/test-region/clusters/test-cluster/instances/bad-test-instance"
311311
async with AsyncConnector(credentials=credentials) as connector:
312-
with pytest.raises(ClientResponseError):
312+
with pytest.raises(RetryError):
313313
await connector.connect(instance_uri, "asyncpg")
314314
assert instance_uri not in connector._cache
315315

0 commit comments

Comments
 (0)