Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 34 additions & 30 deletions google/cloud/alloydbconnector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,19 @@ async def connect(

enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)

# callable to be used for auto IAM authn
def get_authentication_token() -> str:
"""Get OAuth2 access token to be used for IAM database authentication"""
# refresh credentials if expired
if not self._credentials.valid:
request = google.auth.transport.requests.Request()
self._credentials.refresh(request)
return self._credentials.token

# if enable_iam_auth is set, use auth token as database password
if enable_iam_auth:
kwargs["password"] = get_authentication_token

# use existing connection info if possible
if instance_uri in self._cache:
cache = self._cache[instance_uri]
Expand Down Expand Up @@ -185,36 +198,27 @@ async def connect(
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes(ip_type.upper())
try:
conn_info = await cache.connect_info()
ip_address = conn_info.get_preferred_ip(ip_type)
except Exception:
# with an error from AlloyDB API call or IP type, invalidate the
# cache and re-raise the error
await self._remove_cached(instance_uri)
raise
logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433")

# callable to be used for auto IAM authn
def get_authentication_token() -> str:
"""Get OAuth2 access token to be used for IAM database authentication"""
# refresh credentials if expired
if not self._credentials.valid:
request = google.auth.transport.requests.Request()
self._credentials.refresh(request)
return self._credentials.token

# if enable_iam_auth is set, use auth token as database password
if enable_iam_auth:
kwargs["password"] = get_authentication_token
try:
return await connector(
ip_address, await conn_info.create_ssl_context(), **kwargs
)
except Exception:
# we attempt a force refresh, then throw the error
await cache.force_refresh()
raise
for i in range(2):
try:
conn_info = await cache.connect_info()
ip_address = conn_info.get_preferred_ip(ip_type)
except Exception:
# with an error from AlloyDB API call or IP type, invalidate the
# cache and re-raise the error
await self._remove_cached(instance_uri)
raise
logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433")

try:
return await connector(
ip_address, await conn_info.create_ssl_context(), **kwargs
)
except Exception:
# we attempt a force refresh and retry the connection before
# throwing an error
if i == 1:
raise
await cache.force_refresh(True)

async def _remove_cached(self, instance_uri: str) -> None:
"""Stops all background refreshes and deletes the connection
Expand Down
53 changes: 28 additions & 25 deletions google/cloud/alloydbconnector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,31 +226,34 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes(ip_type.upper())
try:
conn_info = await cache.connect_info()
ip_address = conn_info.get_preferred_ip(ip_type)
except Exception:
# with an error from AlloyDB API call or IP type, invalidate the
# cache and re-raise the error
await self._remove_cached(instance_uri)
raise
logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433")

# synchronous drivers are blocking and run using executor
try:
metadata_partial = partial(
self.metadata_exchange,
ip_address,
await conn_info.create_ssl_context(),
enable_iam_auth,
)
sock = await self._loop.run_in_executor(None, metadata_partial)
connect_partial = partial(connector, sock, **kwargs)
return await self._loop.run_in_executor(None, connect_partial)
except Exception:
# we attempt a force refresh, then throw the error
await cache.force_refresh()
raise
for i in range(2):
try:
conn_info = await cache.connect_info()
ip_address = conn_info.get_preferred_ip(ip_type)
except Exception:
# with an error from AlloyDB API call or IP type, invalidate the
# cache and re-raise the error
await self._remove_cached(instance_uri)
raise
logger.debug(f"['{instance_uri}']: Connecting to {ip_address}:5433")

# synchronous drivers are blocking and run using executor
try:
metadata_partial = partial(
self.metadata_exchange,
ip_address,
await conn_info.create_ssl_context(),
enable_iam_auth,
)
sock = await self._loop.run_in_executor(None, metadata_partial)
connect_partial = partial(connector, sock, **kwargs)
return await self._loop.run_in_executor(None, connect_partial)
except Exception:
# we attempt a force refresh and retry the connection before
# throwing an error
if i == 1:
raise
await cache.force_refresh(True)

def metadata_exchange(
self, ip_address: str, ctx: ssl.SSLContext, enable_iam_auth: bool
Expand Down
4 changes: 2 additions & 2 deletions google/cloud/alloydbconnector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ async def _refresh_operation(self, delay: int) -> ConnectionInfo:

return refresh_result

async def force_refresh(self) -> None:
async def force_refresh(self, block: bool = False) -> None:
"""
Schedules a new refresh operation immediately to be used
for future connection attempts.
Expand All @@ -218,7 +218,7 @@ async def force_refresh(self) -> None:
self._next.cancel()
self._next = self._schedule_refresh(0)
# block all sequential connection attempts on the next refresh result if current is invalid
if not await _is_valid(self._current):
if block or not await _is_valid(self._current):
self._current = self._next

async def connect_info(self) -> ConnectionInfo:
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/alloydbconnector/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
self._cached: Optional[ConnectionInfo] = None
self._needs_refresh = False

async def force_refresh(self) -> None:
async def force_refresh(self, block: bool = False) -> None:
"""
Invalidates the cache and configures the next call to
connect_info() to retrieve a fresh ConnectionInfo instance.
Expand Down
2 changes: 1 addition & 1 deletion google/cloud/alloydbconnector/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(self, instance_uri: str, static_conn_info: io.TextIOBase) -> None:
cert_chain, ca_cert, priv_key_bytes, ip_addrs, expiration
)

async def force_refresh(self) -> None:
async def force_refresh(self, block: bool = False) -> None:
"""
This is a no-op as the cache holds only static connection information
and does no refresh.
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ class FakeConnectionInfo:
def __init__(self) -> None:
self._close_called = False
self._force_refresh_called = False
self._force_refresh_blocking = False

def connect_info(self) -> Any:
f = asyncio.Future()
Expand All @@ -400,8 +401,9 @@ def get_preferred_ip(self, ip_type: Any) -> tuple[str, Any]:
async def create_ssl_context(self) -> None:
return None

async def force_refresh(self) -> None:
async def force_refresh(self, block: bool) -> None:
self._force_refresh_called = True
self._force_refresh_blocking = block

async def close(self) -> None:
self._close_called = True
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/test_async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,3 +397,30 @@ async def test_connect_when_closed(credentials: FakeCredentials) -> None:
exc_info.value.args[0]
== "Connection attempt failed because the connector has already been closed."
)


async def test_connect_after_force_refresh(
credentials: FakeCredentials, fake_client: FakeAlloyDBClient
) -> None:
"""
Test that connector.connect can succeed after force refreshing its cache.
"""
async with AsyncConnector(credentials) as connector:
fake = FakeConnectionInfo()
connector._cache[TEST_INSTANCE_NAME] = fake

connector._client = fake_client
# patch db connection creation
with patch("google.cloud.alloydbconnector.asyncpg.connect") as mock_connect:
mock_connect.side_effect = [Exception(), True]
connection = await connector.connect(
TEST_INSTANCE_NAME,
"asyncpg",
user="test-user",
password="test-password",
db="test-db",
)
# check connection after force refreshing cache
assert connection is True
assert fake._force_refresh_called is True
assert fake._force_refresh_blocking is True
23 changes: 23 additions & 0 deletions tests/unit/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,3 +337,26 @@ def test_connect_when_closed(credentials: FakeCredentials) -> None:
exc_info.value.args[0]
== "Connection attempt failed because the connector has already been closed."
)


@pytest.mark.usefixtures("proxy_server")
def test_connect_after_force_refresh(
credentials: FakeCredentials, fake_client: FakeAlloyDBClient
) -> None:
"""
Test that connector.connect can succeed after force refreshing its cache.
"""
with Connector(credentials) as connector:
connector._client = fake_client
# patch db connection creation
with patch("google.cloud.alloydbconnector.pg8000.connect") as mock_connect:
mock_connect.side_effect = [Exception(), True]
connection = connector.connect(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
"pg8000",
user="test-user",
password="test-password",
db="test-db",
)
# check connection is returned after force refreshing cache
assert connection is True
22 changes: 22 additions & 0 deletions tests/unit/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,25 @@ async def test_force_refresh_cancels_pending_refresh() -> None:
assert isinstance(await cache._current, ConnectionInfo)
# close instance
await cache.close()


@pytest.mark.asyncio
async def test_force_refresh_after_blocking_sets_current_to_next() -> None:
"""
Test that force_refresh sets the current task to the next task.
"""
keys = asyncio.create_task(generate_keys())
client = FakeAlloyDBClient()
cache = RefreshAheadCache(
"projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance",
client,
keys,
)
# make sure initial refresh is finished
await cache._current

assert cache._current != cache._next
await cache.force_refresh(True)
assert cache._current == cache._next
# close instance
await cache.close()
Loading