|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | """
|
16 |
| -This module provides functions to obtain Google ID tokens, formatted as "Bearer" tokens, |
17 |
| -for use in the "Authorization" header of HTTP requests. |
| 16 | +This module provides functions to obtain Google ID tokens for a specific audience. |
18 | 17 |
|
19 |
| -Example User Experience: |
| 18 | +The tokens are returned as "Bearer" strings for direct use in HTTP Authorization |
| 19 | +headers. It uses a simple in-memory cache to avoid refetching on every call. |
| 20 | +
|
| 21 | +Example Usage: |
20 | 22 | from toolbox_core import auth_methods
|
21 | 23 |
|
22 |
| -auth_token_provider = auth_methods.aget_google_id_token |
| 24 | +URL = "https://toolbox-service-url" |
23 | 25 | async with ToolboxClient(
|
24 | 26 | URL,
|
25 |
| - client_headers={"Authorization": auth_token_provider}, |
26 |
| -) as toolbox: |
| 27 | + client_headers={"Authorization": auth_methods.aget_google_id_token}) |
| 28 | +as toolbox: |
27 | 29 | tools = await toolbox.load_toolset()
|
28 | 30 | """
|
29 | 31 |
|
| 32 | +import asyncio |
30 | 33 | from datetime import datetime, timedelta, timezone
|
31 |
| -from functools import partial |
32 |
| -from typing import Any, Dict, Optional |
| 34 | +from typing import Any, Callable, Coroutine, Dict, Optional |
33 | 35 |
|
34 | 36 | import google.auth
|
35 |
| -from google.auth._credentials_async import Credentials |
36 |
| -from google.auth._default_async import default_async |
37 |
| -from google.auth.transport import _aiohttp_requests |
| 37 | +from google.auth.exceptions import GoogleAuthError |
38 | 38 | from google.auth.transport.requests import AuthorizedSession, Request
|
| 39 | +from google.oauth2 import id_token |
39 | 40 |
|
40 |
| -# --- Constants and Configuration --- |
41 |
| -# Prefix for Authorization header tokens |
| 41 | +# --- Constants --- |
42 | 42 | BEARER_TOKEN_PREFIX = "Bearer "
|
43 |
| -# Margin in seconds to refresh token before its actual expiry |
44 |
| -CACHE_REFRESH_MARGIN_SECONDS = 60 |
| 43 | +CACHE_REFRESH_MARGIN = timedelta(seconds=60) |
| 44 | + |
| 45 | +_token_cache: Dict[str, Any] = { |
| 46 | + "token": None, |
| 47 | + "expires_at": datetime.min.replace(tzinfo=timezone.utc), |
| 48 | +} |
45 | 49 |
|
46 | 50 |
|
47 |
| -# --- Global Cache Storage --- |
48 |
| -# Stores the cached Google ID token and its expiry timestamp |
49 |
| -_cached_google_id_token: Dict[str, Any] = {"token": None, "expires_at": 0} |
| 51 | +def _is_token_valid() -> bool: |
| 52 | + """Checks if the cached token exists and is not nearing expiry.""" |
| 53 | + if not _token_cache["token"]: |
| 54 | + return False |
| 55 | + return datetime.now(timezone.utc) < ( |
| 56 | + _token_cache["expires_at"] - CACHE_REFRESH_MARGIN |
| 57 | + ) |
50 | 58 |
|
51 | 59 |
|
52 |
| -# --- Helper Functions --- |
53 |
| -def _is_cached_token_valid( |
54 |
| - cache: Dict[str, Any], margin_seconds: int = CACHE_REFRESH_MARGIN_SECONDS |
55 |
| -) -> bool: |
| 60 | +def _update_cache(new_token: str) -> None: |
56 | 61 | """
|
57 |
| - Checks if a token in the cache is valid (exists and not expired). |
| 62 | + Validates a new token, extracts its expiry, and updates the cache. |
58 | 63 |
|
59 | 64 | Args:
|
60 |
| - cache: The dictionary containing 'token' and 'expires_at'. |
61 |
| - margin_seconds: The time in seconds before expiry to consider the token invalid. |
| 65 | + new_token: The new JWT ID token string. |
62 | 66 |
|
63 |
| - Returns: |
64 |
| - True if the token is valid, False otherwise. |
| 67 | + Raises: |
| 68 | + ValueError: If the token is invalid or its expiry cannot be determined. |
65 | 69 | """
|
66 |
| - if not cache.get("token"): |
67 |
| - return False |
| 70 | + try: |
| 71 | + # verify_oauth2_token not only decodes but also validates the token's |
| 72 | + # signature and claims against Google's public keys. |
| 73 | + # It's a synchronous, CPU-bound operation, safe for async contexts. |
| 74 | + claims = id_token.verify_oauth2_token(new_token, Request()) |
68 | 75 |
|
69 |
| - expires_at_value = cache.get("expires_at") |
70 |
| - if not isinstance(expires_at_value, datetime): |
71 |
| - return False |
| 76 | + expiry_timestamp = claims.get("exp") |
| 77 | + if not expiry_timestamp: |
| 78 | + raise ValueError("Token does not contain an 'exp' claim.") |
72 | 79 |
|
73 |
| - # Ensure expires_at_value is timezone-aware (UTC). |
74 |
| - if ( |
75 |
| - expires_at_value.tzinfo is None |
76 |
| - or expires_at_value.tzinfo.utcoffset(expires_at_value) is None |
77 |
| - ): |
78 |
| - expires_at_value = expires_at_value.replace(tzinfo=timezone.utc) |
| 80 | + _token_cache["token"] = new_token |
| 81 | + _token_cache["expires_at"] = datetime.fromtimestamp( |
| 82 | + expiry_timestamp, tz=timezone.utc |
| 83 | + ) |
79 | 84 |
|
80 |
| - current_time_utc = datetime.now(timezone.utc) |
81 |
| - if current_time_utc + timedelta(seconds=margin_seconds) < expires_at_value: |
82 |
| - return True |
| 85 | + except (ValueError, GoogleAuthError) as e: |
| 86 | + # Clear cache on failure to prevent using a stale or invalid token |
| 87 | + _token_cache["token"] = None |
| 88 | + _token_cache["expires_at"] = datetime.min.replace(tzinfo=timezone.utc) |
| 89 | + raise ValueError(f"Failed to validate and cache the new token: {e}") from e |
83 | 90 |
|
84 |
| - return False |
85 | 91 |
|
| 92 | +def get_google_token_from_aud(audience: Optional[str] = None) -> str: |
| 93 | + if _is_token_valid(): |
| 94 | + return BEARER_TOKEN_PREFIX + _token_cache["token"] |
86 | 95 |
|
87 |
| -def _update_token_cache( |
88 |
| - cache: Dict[str, Any], new_id_token: Optional[str], expiry: Optional[datetime] |
89 |
| -) -> None: |
90 |
| - """ |
91 |
| - Updates the global token cache with a new token and its expiry. |
| 96 | + # Get local user credentials |
| 97 | + credentials, _ = google.auth.default() |
| 98 | + session = AuthorizedSession(credentials) |
| 99 | + request = Request(session) |
| 100 | + credentials.refresh(request) |
92 | 101 |
|
93 |
| - Args: |
94 |
| - cache: The dictionary containing 'token' and 'expires_at'. |
95 |
| - new_id_token: The new ID token string to cache. |
96 |
| - """ |
97 |
| - if new_id_token: |
98 |
| - cache["token"] = new_id_token |
99 |
| - expiry_timestamp = expiry |
100 |
| - if expiry_timestamp: |
101 |
| - cache["expires_at"] = expiry_timestamp |
102 |
| - else: |
103 |
| - # If expiry can't be determined, treat as immediately expired to force refresh |
104 |
| - cache["expires_at"] = 0 |
105 |
| - else: |
106 |
| - # Clear cache if no new token is provided |
107 |
| - cache["token"] = None |
108 |
| - cache["expires_at"] = 0 |
109 |
| - |
110 |
| - |
111 |
| -# --- Public API Functions --- |
112 |
| -def get_google_id_token() -> str: |
| 102 | + if hasattr(credentials, "id_token"): |
| 103 | + new_id_token = getattr(credentials, "id_token", None) |
| 104 | + if new_id_token: |
| 105 | + _update_cache(new_id_token) |
| 106 | + return BEARER_TOKEN_PREFIX + new_id_token |
| 107 | + |
| 108 | + if audience is None: |
| 109 | + raise Exception( |
| 110 | + "You are not authenticating using User Credentials." |
| 111 | + " Please set the audience string to the Toolbox service URL to get the Google ID token." |
| 112 | + ) |
| 113 | + |
| 114 | + # Get credentials for Google Cloud environments or for service account key files |
| 115 | + try: |
| 116 | + request = Request() |
| 117 | + new_token = id_token.fetch_id_token(request, audience) |
| 118 | + _update_cache(new_token) |
| 119 | + return BEARER_TOKEN_PREFIX + _token_cache["token"] |
| 120 | + |
| 121 | + except GoogleAuthError as e: |
| 122 | + raise GoogleAuthError( |
| 123 | + f"Failed to fetch Google ID token for audience '{audience}': {e}" |
| 124 | + ) from e |
| 125 | + |
| 126 | + |
| 127 | +def get_google_id_token(audience: Optional[str] = None) -> Callable[[], str]: |
113 | 128 | """
|
114 |
| - Synchronously fetches a Google ID token. |
| 129 | + Returns a SYNC function that, when called, fetches a Google ID token. |
| 130 | + This function uses Application Default Credentials for local systems |
| 131 | + and standard google auth libraries for Google Cloud environments. |
| 132 | + It caches the token in memory. |
115 | 133 |
|
116 |
| - The token is formatted as a 'Bearer' token string and is suitable for use |
117 |
| - in an HTTP Authorization header. This function uses Application Default |
118 |
| - Credentials. |
| 134 | + Args: |
| 135 | + audience: The audience for the ID token (e.g., a service URL or client ID). |
119 | 136 |
|
120 | 137 | Returns:
|
121 |
| - A string in the format "Bearer <google_id_token>". |
| 138 | + A function that when executed returns string in the format "Bearer <google_id_token>". |
122 | 139 |
|
123 | 140 | Raises:
|
124 |
| - Exception: If fetching the Google ID token fails. |
| 141 | + GoogleAuthError: If fetching credentials or the token fails. |
| 142 | + ValueError: If the fetched token is invalid. |
125 | 143 | """
|
126 |
| - if _is_cached_token_valid(_cached_google_id_token): |
127 |
| - return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"] |
128 | 144 |
|
129 |
| - credentials, _ = google.auth.default() |
130 |
| - session = AuthorizedSession(credentials) |
131 |
| - request = Request(session) |
132 |
| - credentials.refresh(request) |
133 |
| - new_id_token = getattr(credentials, "id_token", None) |
134 |
| - expiry = getattr(credentials, "expiry") |
| 145 | + def _token_getter() -> str: |
| 146 | + return get_google_token_from_aud(audience) |
135 | 147 |
|
136 |
| - _update_token_cache(_cached_google_id_token, new_id_token, expiry) |
137 |
| - if new_id_token: |
138 |
| - return BEARER_TOKEN_PREFIX + new_id_token |
139 |
| - else: |
140 |
| - raise Exception("Failed to fetch Google ID token.") |
| 148 | + return _token_getter |
141 | 149 |
|
142 | 150 |
|
143 |
| -async def aget_google_id_token() -> str: |
| 151 | +def aget_google_id_token( |
| 152 | + audience: Optional[str] = None, |
| 153 | +) -> Callable[[], Coroutine[Any, Any, str]]: |
144 | 154 | """
|
145 |
| - Asynchronously fetches a Google ID token. |
| 155 | + Returns an ASYNC function that, when called, fetches a Google ID token. |
| 156 | + This function uses Application Default Credentials for local systems |
| 157 | + and standard google auth libraries for Google Cloud environments. |
| 158 | + It caches the token in memory. |
146 | 159 |
|
147 |
| - The token is formatted as a 'Bearer' token string and is suitable for use |
148 |
| - in an HTTP Authorization header. This function uses Application Default |
149 |
| - Credentials. |
| 160 | + Args: |
| 161 | + audience: The audience for the ID token (e.g., a service URL or client ID). |
150 | 162 |
|
151 | 163 | Returns:
|
152 |
| - A string in the format "Bearer <google_id_token>". |
| 164 | + An async function that when executed returns string in the format "Bearer <google_id_token>". |
153 | 165 |
|
154 | 166 | Raises:
|
155 |
| - Exception: If fetching the Google ID token fails. |
| 167 | + GoogleAuthError: If fetching credentials or the token fails. |
| 168 | + ValueError: If the fetched token is invalid. |
156 | 169 | """
|
157 |
| - if _is_cached_token_valid(_cached_google_id_token): |
158 |
| - return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"] |
159 |
| - |
160 |
| - credentials, _ = default_async() |
161 |
| - await credentials.refresh(_aiohttp_requests.Request()) |
162 |
| - credentials.before_request = partial(Credentials.before_request, credentials) |
163 |
| - new_id_token = getattr(credentials, "id_token", None) |
164 |
| - expiry = getattr(credentials, "expiry") |
165 | 170 |
|
166 |
| - _update_token_cache(_cached_google_id_token, new_id_token, expiry) |
| 171 | + async def _token_getter() -> str: |
| 172 | + return await asyncio.to_thread(get_google_token_from_aud, audience) |
167 | 173 |
|
168 |
| - if new_id_token: |
169 |
| - return BEARER_TOKEN_PREFIX + new_id_token |
170 |
| - else: |
171 |
| - raise Exception("Failed to fetch async Google ID token.") |
| 174 | + return _token_getter |
0 commit comments