diff --git a/packages/toolbox-core/README.md b/packages/toolbox-core/README.md index 791b87a9..7a553291 100644 --- a/packages/toolbox-core/README.md +++ b/packages/toolbox-core/README.md @@ -337,7 +337,7 @@ For Toolbox servers hosted on Google Cloud (e.g., Cloud Run) and requiring ```python from toolbox_core import auth_methods - auth_token_provider = auth_methods.aget_google_id_token # can also use sync method + auth_token_provider = auth_methods.aget_google_id_token(URL) # can also use sync method async with ToolboxClient( URL, client_headers={"Authorization": auth_token_provider}, diff --git a/packages/toolbox-core/pyproject.toml b/packages/toolbox-core/pyproject.toml index 27d32e65..0312aaec 100644 --- a/packages/toolbox-core/pyproject.toml +++ b/packages/toolbox-core/pyproject.toml @@ -13,6 +13,8 @@ dependencies = [ "pydantic>=2.7.0,<3.0.0", "aiohttp>=3.8.6,<4.0.0", "deprecated>=1.2.15,<2.0.0", + "google-auth>=2.0.0,<3.0.0", + "requests>=2.19.0,<3.0.0" ] classifiers = [ @@ -52,6 +54,7 @@ test = [ "pytest-mock==3.14.1", "google-cloud-secret-manager==2.24.0", "google-cloud-storage==3.2.0", + "aioresponses==0.7.8" ] [build-system] requires = ["setuptools"] diff --git a/packages/toolbox-core/requirements.txt b/packages/toolbox-core/requirements.txt index c07e83b8..7b3d2a83 100644 --- a/packages/toolbox-core/requirements.txt +++ b/packages/toolbox-core/requirements.txt @@ -1,3 +1,5 @@ aiohttp==3.12.14 pydantic==2.11.7 -deprecated==1.2.18 \ No newline at end of file +deprecated==1.2.18 +requests==2.32.4 +google-auth==2.40.3 \ No newline at end of file diff --git a/packages/toolbox-core/src/toolbox_core/auth_methods.py b/packages/toolbox-core/src/toolbox_core/auth_methods.py index 003ddcf2..9d1c77a0 100644 --- a/packages/toolbox-core/src/toolbox_core/auth_methods.py +++ b/packages/toolbox-core/src/toolbox_core/auth_methods.py @@ -13,159 +13,162 @@ # limitations under the License. """ -This module provides functions to obtain Google ID tokens, formatted as "Bearer" tokens, -for use in the "Authorization" header of HTTP requests. +This module provides functions to obtain Google ID tokens for a specific audience. -Example User Experience: +The tokens are returned as "Bearer" strings for direct use in HTTP Authorization +headers. It uses a simple in-memory cache to avoid refetching on every call. + +Example Usage: from toolbox_core import auth_methods -auth_token_provider = auth_methods.aget_google_id_token +URL = "https://toolbox-service-url" async with ToolboxClient( URL, - client_headers={"Authorization": auth_token_provider}, -) as toolbox: + client_headers={"Authorization": auth_methods.aget_google_id_token}) +as toolbox: tools = await toolbox.load_toolset() """ +import asyncio from datetime import datetime, timedelta, timezone -from functools import partial -from typing import Any, Dict, Optional +from typing import Any, Callable, Coroutine, Dict, Optional import google.auth -from google.auth._credentials_async import Credentials -from google.auth._default_async import default_async -from google.auth.transport import _aiohttp_requests +from google.auth.exceptions import GoogleAuthError from google.auth.transport.requests import AuthorizedSession, Request +from google.oauth2 import id_token -# --- Constants and Configuration --- -# Prefix for Authorization header tokens +# --- Constants --- BEARER_TOKEN_PREFIX = "Bearer " -# Margin in seconds to refresh token before its actual expiry -CACHE_REFRESH_MARGIN_SECONDS = 60 +CACHE_REFRESH_MARGIN = timedelta(seconds=60) + +_token_cache: Dict[str, Any] = { + "token": None, + "expires_at": datetime.min.replace(tzinfo=timezone.utc), +} -# --- Global Cache Storage --- -# Stores the cached Google ID token and its expiry timestamp -_cached_google_id_token: Dict[str, Any] = {"token": None, "expires_at": 0} +def _is_token_valid() -> bool: + """Checks if the cached token exists and is not nearing expiry.""" + if not _token_cache["token"]: + return False + return datetime.now(timezone.utc) < ( + _token_cache["expires_at"] - CACHE_REFRESH_MARGIN + ) -# --- Helper Functions --- -def _is_cached_token_valid( - cache: Dict[str, Any], margin_seconds: int = CACHE_REFRESH_MARGIN_SECONDS -) -> bool: +def _update_cache(new_token: str) -> None: """ - Checks if a token in the cache is valid (exists and not expired). + Validates a new token, extracts its expiry, and updates the cache. Args: - cache: The dictionary containing 'token' and 'expires_at'. - margin_seconds: The time in seconds before expiry to consider the token invalid. + new_token: The new JWT ID token string. - Returns: - True if the token is valid, False otherwise. + Raises: + ValueError: If the token is invalid or its expiry cannot be determined. """ - if not cache.get("token"): - return False + try: + # verify_oauth2_token not only decodes but also validates the token's + # signature and claims against Google's public keys. + # It's a synchronous, CPU-bound operation, safe for async contexts. + claims = id_token.verify_oauth2_token(new_token, Request()) - expires_at_value = cache.get("expires_at") - if not isinstance(expires_at_value, datetime): - return False + expiry_timestamp = claims.get("exp") + if not expiry_timestamp: + raise ValueError("Token does not contain an 'exp' claim.") - # Ensure expires_at_value is timezone-aware (UTC). - if ( - expires_at_value.tzinfo is None - or expires_at_value.tzinfo.utcoffset(expires_at_value) is None - ): - expires_at_value = expires_at_value.replace(tzinfo=timezone.utc) + _token_cache["token"] = new_token + _token_cache["expires_at"] = datetime.fromtimestamp( + expiry_timestamp, tz=timezone.utc + ) - current_time_utc = datetime.now(timezone.utc) - if current_time_utc + timedelta(seconds=margin_seconds) < expires_at_value: - return True + except (ValueError, GoogleAuthError) as e: + # Clear cache on failure to prevent using a stale or invalid token + _token_cache["token"] = None + _token_cache["expires_at"] = datetime.min.replace(tzinfo=timezone.utc) + raise ValueError(f"Failed to validate and cache the new token: {e}") from e - return False +def get_google_token_from_aud(audience: Optional[str] = None) -> str: + if _is_token_valid(): + return BEARER_TOKEN_PREFIX + _token_cache["token"] -def _update_token_cache( - cache: Dict[str, Any], new_id_token: Optional[str], expiry: Optional[datetime] -) -> None: - """ - Updates the global token cache with a new token and its expiry. + # Get local user credentials + credentials, _ = google.auth.default() + session = AuthorizedSession(credentials) + request = Request(session) + credentials.refresh(request) - Args: - cache: The dictionary containing 'token' and 'expires_at'. - new_id_token: The new ID token string to cache. - """ - if new_id_token: - cache["token"] = new_id_token - expiry_timestamp = expiry - if expiry_timestamp: - cache["expires_at"] = expiry_timestamp - else: - # If expiry can't be determined, treat as immediately expired to force refresh - cache["expires_at"] = 0 - else: - # Clear cache if no new token is provided - cache["token"] = None - cache["expires_at"] = 0 - - -# --- Public API Functions --- -def get_google_id_token() -> str: + if hasattr(credentials, "id_token"): + new_id_token = getattr(credentials, "id_token", None) + if new_id_token: + _update_cache(new_id_token) + return BEARER_TOKEN_PREFIX + new_id_token + + if audience is None: + raise Exception( + "You are not authenticating using User Credentials." + " Please set the audience string to the Toolbox service URL to get the Google ID token." + ) + + # Get credentials for Google Cloud environments or for service account key files + try: + request = Request() + new_token = id_token.fetch_id_token(request, audience) + _update_cache(new_token) + return BEARER_TOKEN_PREFIX + _token_cache["token"] + + except GoogleAuthError as e: + raise GoogleAuthError( + f"Failed to fetch Google ID token for audience '{audience}': {e}" + ) from e + + +def get_google_id_token(audience: Optional[str] = None) -> Callable[[], str]: """ - Synchronously fetches a Google ID token. + Returns a SYNC function that, when called, fetches a Google ID token. + This function uses Application Default Credentials for local systems + and standard google auth libraries for Google Cloud environments. + It caches the token in memory. - The token is formatted as a 'Bearer' token string and is suitable for use - in an HTTP Authorization header. This function uses Application Default - Credentials. + Args: + audience: The audience for the ID token (e.g., a service URL or client ID). Returns: - A string in the format "Bearer ". + A function that when executed returns string in the format "Bearer ". Raises: - Exception: If fetching the Google ID token fails. + GoogleAuthError: If fetching credentials or the token fails. + ValueError: If the fetched token is invalid. """ - if _is_cached_token_valid(_cached_google_id_token): - return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"] - credentials, _ = google.auth.default() - session = AuthorizedSession(credentials) - request = Request(session) - credentials.refresh(request) - new_id_token = getattr(credentials, "id_token", None) - expiry = getattr(credentials, "expiry") + def _token_getter() -> str: + return get_google_token_from_aud(audience) - _update_token_cache(_cached_google_id_token, new_id_token, expiry) - if new_id_token: - return BEARER_TOKEN_PREFIX + new_id_token - else: - raise Exception("Failed to fetch Google ID token.") + return _token_getter -async def aget_google_id_token() -> str: +def aget_google_id_token( + audience: Optional[str] = None, +) -> Callable[[], Coroutine[Any, Any, str]]: """ - Asynchronously fetches a Google ID token. + Returns an ASYNC function that, when called, fetches a Google ID token. + This function uses Application Default Credentials for local systems + and standard google auth libraries for Google Cloud environments. + It caches the token in memory. - The token is formatted as a 'Bearer' token string and is suitable for use - in an HTTP Authorization header. This function uses Application Default - Credentials. + Args: + audience: The audience for the ID token (e.g., a service URL or client ID). Returns: - A string in the format "Bearer ". + An async function that when executed returns string in the format "Bearer ". Raises: - Exception: If fetching the Google ID token fails. + GoogleAuthError: If fetching credentials or the token fails. + ValueError: If the fetched token is invalid. """ - if _is_cached_token_valid(_cached_google_id_token): - return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"] - - credentials, _ = default_async() - await credentials.refresh(_aiohttp_requests.Request()) - credentials.before_request = partial(Credentials.before_request, credentials) - new_id_token = getattr(credentials, "id_token", None) - expiry = getattr(credentials, "expiry") - _update_token_cache(_cached_google_id_token, new_id_token, expiry) + async def _token_getter() -> str: + return await asyncio.to_thread(get_google_token_from_aud, audience) - if new_id_token: - return BEARER_TOKEN_PREFIX + new_id_token - else: - raise Exception("Failed to fetch async Google ID token.") + return _token_getter diff --git a/packages/toolbox-core/tests/test_auth_methods.py b/packages/toolbox-core/tests/test_auth_methods.py index 68d0fef2..db145a2c 100644 --- a/packages/toolbox-core/tests/test_auth_methods.py +++ b/packages/toolbox-core/tests/test_auth_methods.py @@ -12,394 +12,219 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch +import time +from unittest.mock import ANY, MagicMock, patch import pytest +from google.auth.exceptions import GoogleAuthError from toolbox_core import auth_methods # Constants for test values -MOCK_GOOGLE_ID_TOKEN = "test_id_token_123" +MOCK_ID_TOKEN = "test_id_token_123" MOCK_PROJECT_ID = "test-project" +MOCK_AUDIENCE = "https://test-audience.com" # A realistic expiry timestamp (e.g., 1 hour from now) -MOCK_EXPIRY_DATETIME = auth_methods.datetime.now( - auth_methods.timezone.utc -) + auth_methods.timedelta(hours=1) - - -# Expected exception messages from auth_methods.py -FETCH_TOKEN_FAILURE_MSG = "Failed to fetch Google ID token." -FETCH_ASYNC_TOKEN_FAILURE_MSG = "Failed to fetch async Google ID token." -# These will now match the actual messages from refresh.side_effect -NETWORK_ERROR_MSG = "Network error" -TIMEOUT_ERROR_MSG = "Timeout error" +MOCK_EXPIRY_TIMESTAMP = int(time.time()) + 3600 +MOCK_EXPIRY_DATETIME = auth_methods.datetime.fromtimestamp( + MOCK_EXPIRY_TIMESTAMP, tz=auth_methods.timezone.utc +) @pytest.fixture(autouse=True) -def reset_cache_after_each_test(): - """Fixture to reset the cache before each test.""" - # Store initial state - original_cache_state = auth_methods._cached_google_id_token.copy() - auth_methods._cached_google_id_token = {"token": None, "expires_at": 0} +def reset_cache(): + """Fixture to reset the module's token cache before each test.""" + original_cache = auth_methods._token_cache.copy() + # Reset to the initial empty state as defined in the new module + auth_methods._token_cache["token"] = None + auth_methods._token_cache["expires_at"] = auth_methods.datetime.min.replace( + tzinfo=auth_methods.timezone.utc + ) yield - # Restore initial state (optional, but good for isolation) - auth_methods._cached_google_id_token = original_cache_state + auth_methods._token_cache = original_cache +@pytest.mark.asyncio class TestAsyncAuthMethods: """Tests for asynchronous Google ID token fetching.""" - @pytest.mark.asyncio - @patch("toolbox_core.auth_methods._aiohttp_requests.Request") - @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) + @patch("toolbox_core.auth_methods.id_token.verify_oauth2_token") + @patch("toolbox_core.auth_methods.id_token.fetch_id_token") + @patch( + "toolbox_core.auth_methods.google.auth.default", + return_value=(MagicMock(id_token=None), MOCK_PROJECT_ID), + ) async def test_aget_google_id_token_success_first_call( - self, mock_default_async, mock_async_req_class + self, mock_default, mock_fetch, mock_verify ): """Tests successful fetching of an async token on the first call.""" - mock_creds_instance = AsyncMock() - mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN - type(mock_creds_instance).expiry = PropertyMock( - return_value=MOCK_EXPIRY_DATETIME - ) - mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_async_req_instance = MagicMock() - mock_async_req_class.return_value = mock_async_req_instance - token = await auth_methods.aget_google_id_token() - - mock_default_async.assert_called_once_with() - mock_async_req_class.assert_called_once_with() - mock_creds_instance.refresh.assert_called_once_with(mock_async_req_instance) - - assert ( - mock_creds_instance.before_request.func - is auth_methods.Credentials.before_request - ) - assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" - assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN - assert ( - auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME - ) - - @pytest.mark.asyncio - @patch("toolbox_core.auth_methods._aiohttp_requests.Request") - @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) - async def test_aget_google_id_token_success_uses_cache( - self, mock_default_async, mock_async_req_class - ): + mock_fetch.return_value = MOCK_ID_TOKEN + mock_verify.return_value = {"exp": MOCK_EXPIRY_TIMESTAMP} + + token_getter = auth_methods.aget_google_id_token(MOCK_AUDIENCE) + token = await token_getter() + + mock_default.assert_called_once() + mock_fetch.assert_called_once_with(ANY, MOCK_AUDIENCE) + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_ID_TOKEN}" + assert auth_methods._token_cache["token"] == MOCK_ID_TOKEN + assert auth_methods._token_cache["expires_at"] == MOCK_EXPIRY_DATETIME + + @patch("toolbox_core.auth_methods.google.auth.default") + async def test_aget_google_id_token_success_uses_cache(self, mock_default): """Tests that subsequent calls use the cached token if valid.""" - auth_methods._cached_google_id_token["token"] = MOCK_GOOGLE_ID_TOKEN - auth_methods._cached_google_id_token["expires_at"] = auth_methods.datetime.now( + auth_methods._token_cache["token"] = MOCK_ID_TOKEN + auth_methods._token_cache["expires_at"] = auth_methods.datetime.now( auth_methods.timezone.utc - ) + auth_methods.timedelta( - seconds=auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100 - ) # Ensure it's valid - - token = await auth_methods.aget_google_id_token() + ) + auth_methods.timedelta(hours=1) - mock_default_async.assert_not_called() - mock_async_req_class.assert_not_called() + token_getter = auth_methods.aget_google_id_token(MOCK_AUDIENCE) + token = await token_getter() - assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" - assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN + mock_default.assert_not_called() + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_ID_TOKEN}" - @pytest.mark.asyncio - @patch("toolbox_core.auth_methods._aiohttp_requests.Request") - @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) + @patch("toolbox_core.auth_methods.id_token.verify_oauth2_token") + @patch("toolbox_core.auth_methods.id_token.fetch_id_token") + @patch( + "toolbox_core.auth_methods.google.auth.default", + return_value=(MagicMock(id_token=None), MOCK_PROJECT_ID), + ) async def test_aget_google_id_token_refreshes_expired_cache( - self, mock_default_async, mock_async_req_class + self, mock_default, mock_fetch, mock_verify ): """Tests that an expired cached token triggers a refresh.""" - auth_methods._cached_google_id_token["token"] = "expired_token" - auth_methods._cached_google_id_token["expires_at"] = auth_methods.datetime.now( + auth_methods._token_cache["token"] = "expired_token" + auth_methods._token_cache["expires_at"] = auth_methods.datetime.now( auth_methods.timezone.utc - ) - auth_methods.timedelta( - seconds=100 - ) # Expired - - mock_creds_instance = AsyncMock() - mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN # New token after refresh - type(mock_creds_instance).expiry = PropertyMock( - return_value=MOCK_EXPIRY_DATETIME - ) - mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_async_req_instance = MagicMock() - mock_async_req_class.return_value = mock_async_req_instance - - token = await auth_methods.aget_google_id_token() - - mock_default_async.assert_called_once_with() - mock_async_req_class.assert_called_once_with() - mock_creds_instance.refresh.assert_called_once_with(mock_async_req_instance) - assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" - assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN - assert ( - auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME - ) - - @pytest.mark.asyncio - @patch("toolbox_core.auth_methods._aiohttp_requests.Request") - @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) - async def test_aget_google_id_token_fetch_failure( - self, mock_default_async, mock_async_req_class + ) - auth_methods.timedelta(seconds=100) + + mock_fetch.return_value = MOCK_ID_TOKEN + mock_verify.return_value = {"exp": MOCK_EXPIRY_TIMESTAMP} + + token_getter = auth_methods.aget_google_id_token(MOCK_AUDIENCE) + token = await token_getter() + + mock_default.assert_called_once() + mock_fetch.assert_called_once() + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_ID_TOKEN}" + assert auth_methods._token_cache["token"] == MOCK_ID_TOKEN + + @patch("toolbox_core.auth_methods.id_token.fetch_id_token") + @patch( + "toolbox_core.auth_methods.google.auth.default", + return_value=(MagicMock(id_token=None), MOCK_PROJECT_ID), + ) + async def test_aget_raises_if_no_audience_and_no_local_token( + self, mock_default, mock_fetch ): - """Tests error handling when fetching the token fails (no id_token returned).""" - mock_creds_instance = AsyncMock() - mock_creds_instance.id_token = None # Simulate no ID token after refresh - type(mock_creds_instance).expiry = PropertyMock( - return_value=MOCK_EXPIRY_DATETIME - ) # Still need expiry for update_cache - mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - mock_async_req_class.return_value = MagicMock() - - with pytest.raises(Exception, match=FETCH_ASYNC_TOKEN_FAILURE_MSG): - await auth_methods.aget_google_id_token() - - assert auth_methods._cached_google_id_token["token"] is None - assert auth_methods._cached_google_id_token["expires_at"] == 0 - mock_async_req_class.assert_called_once_with() - mock_creds_instance.refresh.assert_called_once() - - @pytest.mark.asyncio - @patch("toolbox_core.auth_methods._aiohttp_requests.Request") - @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) - async def test_aget_google_id_token_refresh_raises_exception( - self, mock_default_async, mock_async_req_class - ): - """Tests exception handling when credentials refresh fails.""" - mock_creds_instance = AsyncMock() - mock_creds_instance.refresh.side_effect = Exception(NETWORK_ERROR_MSG) - mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - mock_async_req_class.return_value = MagicMock() - - with pytest.raises(Exception, match=NETWORK_ERROR_MSG): - await auth_methods.aget_google_id_token() - - assert auth_methods._cached_google_id_token["token"] is None - assert auth_methods._cached_google_id_token["expires_at"] == 0 - mock_async_req_class.assert_called_once_with() - mock_creds_instance.refresh.assert_called_once() - - @pytest.mark.asyncio - @patch("toolbox_core.auth_methods._aiohttp_requests.Request") - @patch("toolbox_core.auth_methods.default_async", new_callable=MagicMock) - async def test_aget_google_id_token_no_expiry_info( - self, mock_default_async, mock_async_req_class - ): - """Tests that a token without expiry info is still cached but effectively expired.""" - mock_creds_instance = AsyncMock() - mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN - type(mock_creds_instance).expiry = PropertyMock( - return_value=None - ) # Simulate no expiry info - mock_default_async.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_async_req_class.return_value = MagicMock() - - token = await auth_methods.aget_google_id_token() + """Tests that the async function propagates the missing audience exception.""" + error_msg = "You are not authenticating using User Credentials." + with pytest.raises(Exception, match=error_msg): + token_getter = auth_methods.aget_google_id_token() + await token_getter() - assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" - assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN - assert ( - auth_methods._cached_google_id_token["expires_at"] == 0 - ) # Should be 0 if no expiry - mock_async_req_class.assert_called_once_with() + mock_default.assert_called_once() + mock_fetch.assert_not_called() class TestSyncAuthMethods: """Tests for synchronous Google ID token fetching.""" + @patch("toolbox_core.auth_methods.id_token.verify_oauth2_token") @patch("toolbox_core.auth_methods.Request") @patch("toolbox_core.auth_methods.AuthorizedSession") @patch("toolbox_core.auth_methods.google.auth.default") - def test_get_google_id_token_success_first_call( - self, - mock_sync_default, - mock_auth_session_class, - mock_sync_req_class, + def test_get_google_id_token_success_local_creds( + self, mock_default, mock_session, mock_request, mock_verify ): - """Tests successful fetching of a sync token on the first call.""" - mock_creds_instance = MagicMock() - mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN - type(mock_creds_instance).expiry = PropertyMock( - return_value=MOCK_EXPIRY_DATETIME - ) - mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - + """Tests successful fetching via local credentials.""" + mock_creds = MagicMock() + mock_creds.id_token = MOCK_ID_TOKEN + mock_default.return_value = (mock_creds, MOCK_PROJECT_ID) + mock_verify.return_value = {"exp": MOCK_EXPIRY_TIMESTAMP} mock_session_instance = MagicMock() - mock_auth_session_class.return_value = mock_session_instance - - mock_sync_request_instance = MagicMock() - mock_sync_req_class.return_value = mock_sync_request_instance - - token = auth_methods.get_google_id_token() + mock_session.return_value = mock_session_instance + mock_request_instance = MagicMock() + mock_request.return_value = mock_request_instance - mock_sync_default.assert_called_once_with() - mock_auth_session_class.assert_called_once_with(mock_creds_instance) - mock_sync_req_class.assert_called_once_with(mock_session_instance) - mock_creds_instance.refresh.assert_called_once_with(mock_sync_request_instance) + token_getter = auth_methods.get_google_id_token(MOCK_AUDIENCE) + token = token_getter() - assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" - assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN - assert ( - auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME - ) + mock_default.assert_called_once_with() + mock_session.assert_called_once_with(mock_creds) + mock_creds.refresh.assert_called_once_with(mock_request_instance) + mock_verify.assert_called_once_with(MOCK_ID_TOKEN, ANY) + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_ID_TOKEN}" + assert auth_methods._token_cache["token"] == MOCK_ID_TOKEN + assert auth_methods._token_cache["expires_at"] == MOCK_EXPIRY_DATETIME - @patch("toolbox_core.auth_methods.Request") - @patch("toolbox_core.auth_methods.AuthorizedSession") @patch("toolbox_core.auth_methods.google.auth.default") - def test_get_google_id_token_success_uses_cache( - self, - mock_sync_default, - mock_auth_session_class, - mock_sync_req_class, - ): + def test_get_google_id_token_success_uses_cache(self, mock_default): """Tests that subsequent calls use the cached token if valid.""" - auth_methods._cached_google_id_token["token"] = MOCK_GOOGLE_ID_TOKEN - auth_methods._cached_google_id_token["expires_at"] = auth_methods.datetime.now( - auth_methods.timezone.utc - ) + auth_methods.timedelta( - seconds=auth_methods.CACHE_REFRESH_MARGIN_SECONDS + 100 - ) # Ensure it's valid - - token = auth_methods.get_google_id_token() - - mock_sync_default.assert_not_called() - mock_auth_session_class.assert_not_called() - mock_sync_req_class.assert_not_called() - - assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" - assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN - - @patch("toolbox_core.auth_methods.Request") - @patch("toolbox_core.auth_methods.AuthorizedSession") - @patch("toolbox_core.auth_methods.google.auth.default") - def test_get_google_id_token_refreshes_expired_cache( - self, - mock_sync_default, - mock_auth_session_class, - mock_sync_req_class, - ): - """Tests that an expired cached token triggers a refresh.""" - # Prime the cache with an expired token - auth_methods._cached_google_id_token["token"] = "expired_token_sync" - auth_methods._cached_google_id_token["expires_at"] = auth_methods.datetime.now( + auth_methods._token_cache["token"] = MOCK_ID_TOKEN + auth_methods._token_cache["expires_at"] = auth_methods.datetime.now( auth_methods.timezone.utc - ) - auth_methods.timedelta( - seconds=100 - ) # Expired + ) + auth_methods.timedelta(hours=1) - mock_creds_instance = MagicMock() - mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN # New token after refresh - type(mock_creds_instance).expiry = PropertyMock( - return_value=MOCK_EXPIRY_DATETIME - ) - mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) + token_getter = auth_methods.get_google_id_token(MOCK_AUDIENCE) + token = token_getter() - mock_session_instance = MagicMock() - mock_auth_session_class.return_value = mock_session_instance + mock_default.assert_not_called() + assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_ID_TOKEN}" - mock_sync_request_instance = MagicMock() - mock_sync_req_class.return_value = mock_sync_request_instance - - token = auth_methods.get_google_id_token() - - mock_sync_default.assert_called_once_with() - mock_auth_session_class.assert_called_once_with(mock_creds_instance) - mock_sync_req_class.assert_called_once_with(mock_session_instance) - mock_creds_instance.refresh.assert_called_once_with(mock_sync_request_instance) - assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" - assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN - assert ( - auth_methods._cached_google_id_token["expires_at"] == MOCK_EXPIRY_DATETIME - ) - - @patch("toolbox_core.auth_methods.Request") - @patch("toolbox_core.auth_methods.AuthorizedSession") - @patch("toolbox_core.auth_methods.google.auth.default") + @patch("toolbox_core.auth_methods.id_token.verify_oauth2_token") + @patch("toolbox_core.auth_methods.id_token.fetch_id_token") + @patch( + "toolbox_core.auth_methods.google.auth.default", + return_value=(MagicMock(id_token=None), MOCK_PROJECT_ID), + ) def test_get_google_id_token_fetch_failure( - self, mock_sync_default, mock_auth_session_class, mock_sync_req_class + self, mock_default, mock_fetch, mock_verify ): - """Tests error handling when fetching the token fails (no id_token returned).""" - mock_creds_instance = MagicMock() - mock_creds_instance.id_token = None # Simulate no ID token after refresh - type(mock_creds_instance).expiry = PropertyMock( - return_value=MOCK_EXPIRY_DATETIME - ) # Still need expiry for update_cache - mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_session_instance = MagicMock() - mock_auth_session_class.return_value = mock_session_instance - - mock_sync_req_class.return_value = MagicMock() - - with pytest.raises(Exception, match=FETCH_TOKEN_FAILURE_MSG): - auth_methods.get_google_id_token() - - assert auth_methods._cached_google_id_token["token"] is None - assert auth_methods._cached_google_id_token["expires_at"] == 0 - mock_sync_default.assert_called_once_with() - mock_auth_session_class.assert_called_once_with(mock_creds_instance) - mock_sync_req_class.assert_called_once_with(mock_session_instance) - mock_creds_instance.refresh.assert_called_once() - - @patch("toolbox_core.auth_methods.Request") - @patch("toolbox_core.auth_methods.AuthorizedSession") - @patch("toolbox_core.auth_methods.google.auth.default") - def test_get_google_id_token_refresh_raises_exception( - self, mock_sync_default, mock_auth_session_class, mock_sync_req_class + """Tests error handling when fetching the token fails.""" + mock_fetch.side_effect = GoogleAuthError("Fetch failed") + + with pytest.raises(GoogleAuthError, match="Fetch failed"): + auth_methods.get_google_id_token(MOCK_AUDIENCE)() + + assert auth_methods._token_cache["token"] is None + mock_default.assert_called_once() + mock_fetch.assert_called_once() + mock_verify.assert_not_called() + + @patch("toolbox_core.auth_methods.id_token.verify_oauth2_token") + @patch("toolbox_core.auth_methods.id_token.fetch_id_token") + @patch( + "toolbox_core.auth_methods.google.auth.default", + return_value=(MagicMock(id_token=None), MOCK_PROJECT_ID), + ) + def test_get_google_id_token_validation_failure( + self, mock_default, mock_fetch, mock_verify ): - """Tests exception handling when credentials refresh fails.""" - mock_creds_instance = MagicMock() - mock_creds_instance.refresh.side_effect = Exception(TIMEOUT_ERROR_MSG) - mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_session_instance = MagicMock() - mock_auth_session_class.return_value = mock_session_instance - - mock_sync_req_class.return_value = MagicMock() - - with pytest.raises(Exception, match=TIMEOUT_ERROR_MSG): - auth_methods.get_google_id_token() - - assert auth_methods._cached_google_id_token["token"] is None - assert auth_methods._cached_google_id_token["expires_at"] == 0 - mock_sync_default.assert_called_once_with() - mock_auth_session_class.assert_called_once_with(mock_creds_instance) - mock_sync_req_class.assert_called_once_with(mock_session_instance) - mock_creds_instance.refresh.assert_called_once() - - @patch("toolbox_core.auth_methods.Request") - @patch("toolbox_core.auth_methods.AuthorizedSession") - @patch("toolbox_core.auth_methods.google.auth.default") - def test_get_google_id_token_no_expiry_info( - self, - mock_sync_default, - mock_auth_session_class, - mock_sync_req_class, + """Tests that an invalid token from fetch raises a ValueError.""" + mock_fetch.return_value = MOCK_ID_TOKEN + mock_verify.side_effect = ValueError("Invalid signature") + + with pytest.raises( + ValueError, match="Failed to validate and cache the new token" + ): + auth_methods.get_google_id_token(MOCK_AUDIENCE)() + + assert auth_methods._token_cache["token"] is None + + @patch("toolbox_core.auth_methods.id_token.fetch_id_token") + @patch( + "toolbox_core.auth_methods.google.auth.default", + return_value=(MagicMock(id_token=None), MOCK_PROJECT_ID), + ) + def test_get_raises_if_no_audience_and_no_local_token( + self, mock_default, mock_fetch ): - """Tests that a token without expiry info is still cached but effectively expired.""" - mock_creds_instance = MagicMock() - mock_creds_instance.id_token = MOCK_GOOGLE_ID_TOKEN - type(mock_creds_instance).expiry = PropertyMock( - return_value=None - ) # Simulate no expiry info - mock_sync_default.return_value = (mock_creds_instance, MOCK_PROJECT_ID) - - mock_session_instance = MagicMock() - mock_auth_session_class.return_value = mock_session_instance - - mock_sync_request_instance = MagicMock() - mock_sync_req_class.return_value = mock_sync_request_instance - - token = auth_methods.get_google_id_token() + """Tests exception is raised if audience is required but not provided.""" + error_msg = "You are not authenticating using User Credentials." + with pytest.raises(Exception, match=error_msg): + auth_methods.get_google_id_token()() - assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_GOOGLE_ID_TOKEN}" - assert auth_methods._cached_google_id_token["token"] == MOCK_GOOGLE_ID_TOKEN - assert ( - auth_methods._cached_google_id_token["expires_at"] == 0 - ) # Should be 0 if no expiry - mock_sync_default.assert_called_once_with() - mock_auth_session_class.assert_called_once_with(mock_creds_instance) - mock_sync_req_class.assert_called_once_with(mock_session_instance) + mock_default.assert_called_once() + mock_fetch.assert_not_called() diff --git a/packages/toolbox-langchain/README.md b/packages/toolbox-langchain/README.md index 9bf95f48..eafd010a 100644 --- a/packages/toolbox-langchain/README.md +++ b/packages/toolbox-langchain/README.md @@ -272,7 +272,7 @@ For Toolbox servers hosted on Google Cloud (e.g., Cloud Run) and requiring from toolbox_langchain import ToolboxClient from toolbox_core import auth_methods - auth_token_provider = auth_methods.aget_google_id_token # can also use sync method + auth_token_provider = auth_methods.aget_google_id_token(URL) # can also use sync method async with ToolboxClient( URL, client_headers={"Authorization": auth_token_provider}, diff --git a/packages/toolbox-llamaindex/README.md b/packages/toolbox-llamaindex/README.md index 6a993f68..04f7ac0d 100644 --- a/packages/toolbox-llamaindex/README.md +++ b/packages/toolbox-llamaindex/README.md @@ -251,7 +251,7 @@ For Toolbox servers hosted on Google Cloud (e.g., Cloud Run) and requiring from toolbox_llamaindex import ToolboxClient from toolbox_core import auth_methods - auth_token_provider = auth_methods.aget_google_id_token # can also use sync method + auth_token_provider = auth_methods.aget_google_id_token(URL) async with ToolboxClient( URL, client_headers={"Authorization": auth_token_provider},