13
13
# limitations under the License.
14
14
15
15
"""
16
- This module provides functions to obtain Google ID tokens, formatted as "Bearer" tokens,
17
- for use in the "Authorization" header of HTTP requests.
16
+ This module provides functions to obtain Google ID tokens for a specific audience.
18
17
19
- Example User Experience:
18
+ The tokens are returned as "Bearer" strings for direct use in HTTP Authorization
19
+ headers. It uses a simple in-memory cache to avoid refetching on every call.
20
+
21
+ Example Usage:
20
22
from toolbox_core import auth_methods
23
+ from functools import partial
21
24
22
- auth_token_provider = auth_methods.aget_google_id_token
23
- toolbox = ToolboxClient(
24
- URL,
25
- client_headers={"Authorization": auth_token_provider},
25
+ auth_token_provider = functools.partial(
26
+ auth_methods.aget_google_id_token,
27
+ "https://toolbox-service-url"
26
28
)
27
- tools = await toolbox.load_toolset()
29
+ client = ToolboxClient(URL, client_headers={"Authorization": auth_token_provider})
30
+ await client.make_request()
28
31
"""
29
32
30
33
from datetime import datetime , timedelta , timezone
31
- from functools import partial
32
- from typing import Any , Dict , Optional
33
-
34
+ from typing import Any , Dict
34
35
import google .auth
35
- from google .auth ._credentials_async import Credentials
36
- from google .auth ._default_async import default_async
37
- from google .auth . transport import _aiohttp_requests
38
- from google . auth . transport . requests import AuthorizedSession , Request
36
+ from google .auth .exceptions import GoogleAuthError
37
+ from google .auth .transport . requests import Request , AuthorizedSession
38
+ from google .oauth2 import id_token
39
+ import asyncio
39
40
40
- # --- Constants and Configuration ---
41
- # Prefix for Authorization header tokens
41
+ # --- Constants ---
42
42
BEARER_TOKEN_PREFIX = "Bearer "
43
- # Margin in seconds to refresh token before its actual expiry
44
- CACHE_REFRESH_MARGIN_SECONDS = 60
45
-
43
+ CACHE_REFRESH_MARGIN = timedelta (seconds = 60 )
46
44
47
- # --- Global Cache Storage ---
48
- # Stores the cached Google ID token and its expiry timestamp
49
- _cached_google_id_token : Dict [str , Any ] = {"token" : None , "expires_at" : 0 }
50
-
51
-
52
- # --- Helper Functions ---
53
- def _is_cached_token_valid (
54
- cache : Dict [str , Any ], margin_seconds : int = CACHE_REFRESH_MARGIN_SECONDS
55
- ) -> bool :
56
- """
57
- Checks if a token in the cache is valid (exists and not expired).
58
-
59
- Args:
60
- cache: The dictionary containing 'token' and 'expires_at'.
61
- margin_seconds: The time in seconds before expiry to consider the token invalid.
45
+ _token_cache : Dict [str , Any ] = {"token" : None , "expires_at" : datetime .min .replace (tzinfo = timezone .utc )}
62
46
63
- Returns:
64
- True if the token is valid, False otherwise.
65
- """
66
- if not cache .get ("token" ):
47
+ def _is_token_valid () -> bool :
48
+ """Checks if the cached token exists and is not nearing expiry."""
49
+ if not _token_cache ["token" ]:
67
50
return False
51
+ return datetime .now (timezone .utc ) < (_token_cache ["expires_at" ] - CACHE_REFRESH_MARGIN )
68
52
69
- expires_at_value = cache .get ("expires_at" )
70
- if not isinstance (expires_at_value , datetime ):
71
- return False
72
-
73
- # Ensure expires_at_value is timezone-aware (UTC).
74
- if (
75
- expires_at_value .tzinfo is None
76
- or expires_at_value .tzinfo .utcoffset (expires_at_value ) is None
77
- ):
78
- expires_at_value = expires_at_value .replace (tzinfo = timezone .utc )
79
-
80
- current_time_utc = datetime .now (timezone .utc )
81
- if current_time_utc + timedelta (seconds = margin_seconds ) < expires_at_value :
82
- return True
83
-
84
- return False
85
-
86
-
87
- def _update_token_cache (
88
- cache : Dict [str , Any ], new_id_token : Optional [str ], expiry : Optional [datetime ]
89
- ) -> None :
53
+ def _update_cache (new_token : str ) -> None :
90
54
"""
91
- Updates the global token cache with a new token and its expiry.
92
-
55
+ Validates a new token, extracts its expiry, and updates the cache .
56
+
93
57
Args:
94
- cache: The dictionary containing 'token' and 'expires_at'.
95
- new_id_token: The new ID token string to cache.
58
+ new_token: The new JWT ID token string.
59
+
60
+ Raises:
61
+ ValueError: If the token is invalid or its expiry cannot be determined.
96
62
"""
97
- if new_id_token :
98
- cache ["token" ] = new_id_token
99
- expiry_timestamp = expiry
100
- if expiry_timestamp :
101
- cache ["expires_at" ] = expiry_timestamp
102
- else :
103
- # If expiry can't be determined, treat as immediately expired to force refresh
104
- cache ["expires_at" ] = 0
105
- else :
106
- # Clear cache if no new token is provided
107
- cache ["token" ] = None
108
- cache ["expires_at" ] = 0
63
+ try :
64
+ # verify_oauth2_token not only decodes but also validates the token's
65
+ # signature and claims against Google's public keys.
66
+ # It's a synchronous, CPU-bound operation, safe for async contexts.
67
+ claims = id_token .verify_oauth2_token (new_token , Request ())
68
+
69
+ expiry_timestamp = claims .get ("exp" )
70
+ if not expiry_timestamp :
71
+ raise ValueError ("Token does not contain an 'exp' claim." )
72
+
73
+ _token_cache ["token" ] = new_token
74
+ _token_cache ["expires_at" ] = datetime .fromtimestamp (expiry_timestamp , tz = timezone .utc )
75
+
76
+ except (ValueError , GoogleAuthError ) as e :
77
+ # Clear cache on failure to prevent using a stale or invalid token
78
+ _token_cache ["token" ] = None
79
+ _token_cache ["expires_at" ] = datetime .min .replace (tzinfo = timezone .utc )
80
+ raise ValueError (f"Failed to validate and cache the new token: { e } " ) from e
109
81
110
82
111
83
# --- Public API Functions ---
112
- def get_google_id_token () -> str :
84
+
85
+ def get_google_id_token (audience : str ) -> str :
113
86
"""
114
- Synchronously fetches a Google ID token.
87
+ Synchronously fetches a Google ID token for a specific audience .
115
88
116
- The token is formatted as a 'Bearer' token string and is suitable for use
117
- in an HTTP Authorization header. This function uses Application Default
118
- Credentials.
89
+ This function uses Application Default Credentials and caches the token in memory.
90
+
91
+ Args:
92
+ audience: The audience for the ID token (e.g., a service URL or client ID).
119
93
120
94
Returns:
121
95
A string in the format "Bearer <google_id_token>".
122
96
123
97
Raises:
124
- Exception: If fetching the Google ID token fails.
98
+ GoogleAuthError: If fetching credentials or the token fails.
99
+ ValueError: If the fetched token is invalid.
125
100
"""
126
- if _is_cached_token_valid (_cached_google_id_token ):
127
- return BEARER_TOKEN_PREFIX + _cached_google_id_token ["token" ]
128
-
101
+ if _is_token_valid ():
102
+ return BEARER_TOKEN_PREFIX + _token_cache ["token" ]
103
+
104
+ # Get local user credentials
129
105
credentials , _ = google .auth .default ()
130
106
session = AuthorizedSession (credentials )
131
107
request = Request (session )
132
108
credentials .refresh (request )
133
- new_id_token = getattr (credentials , "id_token" , None )
134
- expiry = getattr (credentials , "expiry" )
135
-
136
- _update_token_cache (_cached_google_id_token , new_id_token , expiry )
137
- if new_id_token :
138
- return BEARER_TOKEN_PREFIX + new_id_token
139
- else :
140
- raise Exception ("Failed to fetch Google ID token." )
141
-
142
-
143
- async def aget_google_id_token () -> str :
144
- """
145
- Asynchronously fetches a Google ID token.
146
-
147
- The token is formatted as a 'Bearer' token string and is suitable for use
148
- in an HTTP Authorization header. This function uses Application Default
149
- Credentials.
150
-
151
- Returns:
152
- A string in the format "Bearer <google_id_token>".
153
-
154
- Raises:
155
- Exception: If fetching the Google ID token fails.
156
- """
157
- if _is_cached_token_valid (_cached_google_id_token ):
158
- return BEARER_TOKEN_PREFIX + _cached_google_id_token ["token" ]
159
-
160
- credentials , _ = default_async ()
161
- await credentials .refresh (_aiohttp_requests .Request ())
162
- credentials .before_request = partial (Credentials .before_request , credentials )
163
- new_id_token = getattr (credentials , "id_token" , None )
164
- expiry = getattr (credentials , "expiry" )
165
-
166
- _update_token_cache (_cached_google_id_token , new_id_token , expiry )
167
109
168
- if new_id_token :
169
- return BEARER_TOKEN_PREFIX + new_id_token
170
- else :
171
- raise Exception ("Failed to fetch async Google ID token." )
110
+ if hasattr (credentials , "id_token" ):
111
+ new_id_token = getattr (credentials , "id_token" , None )
112
+ if new_id_token :
113
+ _update_cache (new_id_token )
114
+ return BEARER_TOKEN_PREFIX + new_id_token
115
+
116
+ # Get credentials for Google Cloud environments
117
+ try :
118
+ request = Request ()
119
+ new_token = id_token .fetch_id_token (request , audience )
120
+ _update_cache (new_token )
121
+ return BEARER_TOKEN_PREFIX + _token_cache ["token" ]
122
+
123
+ except GoogleAuthError as e :
124
+ raise GoogleAuthError (f"Failed to fetch Google ID token for audience '{ audience } ': { e } " ) from e
125
+
126
+ async def aget_google_id_token (audience : str ) -> str :
127
+ token = await asyncio .to_thread (get_google_id_token , audience )
128
+ return token
0 commit comments