Skip to content

Commit 726de7e

Browse files
committed
fix: fix auth_methods module
1 parent 1b0c666 commit 726de7e

File tree

1 file changed

+80
-123
lines changed

1 file changed

+80
-123
lines changed

packages/toolbox-core/src/toolbox_core/auth_methods.py

Lines changed: 80 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -13,159 +13,116 @@
1313
# limitations under the License.
1414

1515
"""
16-
This module provides functions to obtain Google ID tokens, formatted as "Bearer" tokens,
17-
for use in the "Authorization" header of HTTP requests.
16+
This module provides functions to obtain Google ID tokens for a specific audience.
1817
19-
Example User Experience:
18+
The tokens are returned as "Bearer" strings for direct use in HTTP Authorization
19+
headers. It uses a simple in-memory cache to avoid refetching on every call.
20+
21+
Example Usage:
2022
from toolbox_core import auth_methods
23+
from functools import partial
2124
22-
auth_token_provider = auth_methods.aget_google_id_token
23-
toolbox = ToolboxClient(
24-
URL,
25-
client_headers={"Authorization": auth_token_provider},
25+
auth_token_provider = functools.partial(
26+
auth_methods.aget_google_id_token,
27+
"https://toolbox-service-url"
2628
)
27-
tools = await toolbox.load_toolset()
29+
client = ToolboxClient(URL, client_headers={"Authorization": auth_token_provider})
30+
await client.make_request()
2831
"""
2932

3033
from datetime import datetime, timedelta, timezone
31-
from functools import partial
32-
from typing import Any, Dict, Optional
33-
34+
from typing import Any, Dict
3435
import google.auth
35-
from google.auth._credentials_async import Credentials
36-
from google.auth._default_async import default_async
37-
from google.auth.transport import _aiohttp_requests
38-
from google.auth.transport.requests import AuthorizedSession, Request
36+
from google.auth.exceptions import GoogleAuthError
37+
from google.auth.transport.requests import Request, AuthorizedSession
38+
from google.oauth2 import id_token
39+
import asyncio
3940

40-
# --- Constants and Configuration ---
41-
# Prefix for Authorization header tokens
41+
# --- Constants ---
4242
BEARER_TOKEN_PREFIX = "Bearer "
43-
# Margin in seconds to refresh token before its actual expiry
44-
CACHE_REFRESH_MARGIN_SECONDS = 60
45-
43+
CACHE_REFRESH_MARGIN = timedelta(seconds=60)
4644

47-
# --- Global Cache Storage ---
48-
# Stores the cached Google ID token and its expiry timestamp
49-
_cached_google_id_token: Dict[str, Any] = {"token": None, "expires_at": 0}
50-
51-
52-
# --- Helper Functions ---
53-
def _is_cached_token_valid(
54-
cache: Dict[str, Any], margin_seconds: int = CACHE_REFRESH_MARGIN_SECONDS
55-
) -> bool:
56-
"""
57-
Checks if a token in the cache is valid (exists and not expired).
58-
59-
Args:
60-
cache: The dictionary containing 'token' and 'expires_at'.
61-
margin_seconds: The time in seconds before expiry to consider the token invalid.
45+
_token_cache: Dict[str, Any] = {"token": None, "expires_at": datetime.min.replace(tzinfo=timezone.utc)}
6246

63-
Returns:
64-
True if the token is valid, False otherwise.
65-
"""
66-
if not cache.get("token"):
47+
def _is_token_valid() -> bool:
48+
"""Checks if the cached token exists and is not nearing expiry."""
49+
if not _token_cache["token"]:
6750
return False
51+
return datetime.now(timezone.utc) < (_token_cache["expires_at"] - CACHE_REFRESH_MARGIN)
6852

69-
expires_at_value = cache.get("expires_at")
70-
if not isinstance(expires_at_value, datetime):
71-
return False
72-
73-
# Ensure expires_at_value is timezone-aware (UTC).
74-
if (
75-
expires_at_value.tzinfo is None
76-
or expires_at_value.tzinfo.utcoffset(expires_at_value) is None
77-
):
78-
expires_at_value = expires_at_value.replace(tzinfo=timezone.utc)
79-
80-
current_time_utc = datetime.now(timezone.utc)
81-
if current_time_utc + timedelta(seconds=margin_seconds) < expires_at_value:
82-
return True
83-
84-
return False
85-
86-
87-
def _update_token_cache(
88-
cache: Dict[str, Any], new_id_token: Optional[str], expiry: Optional[datetime]
89-
) -> None:
53+
def _update_cache(new_token: str) -> None:
9054
"""
91-
Updates the global token cache with a new token and its expiry.
92-
55+
Validates a new token, extracts its expiry, and updates the cache.
56+
9357
Args:
94-
cache: The dictionary containing 'token' and 'expires_at'.
95-
new_id_token: The new ID token string to cache.
58+
new_token: The new JWT ID token string.
59+
60+
Raises:
61+
ValueError: If the token is invalid or its expiry cannot be determined.
9662
"""
97-
if new_id_token:
98-
cache["token"] = new_id_token
99-
expiry_timestamp = expiry
100-
if expiry_timestamp:
101-
cache["expires_at"] = expiry_timestamp
102-
else:
103-
# If expiry can't be determined, treat as immediately expired to force refresh
104-
cache["expires_at"] = 0
105-
else:
106-
# Clear cache if no new token is provided
107-
cache["token"] = None
108-
cache["expires_at"] = 0
63+
try:
64+
# verify_oauth2_token not only decodes but also validates the token's
65+
# signature and claims against Google's public keys.
66+
# It's a synchronous, CPU-bound operation, safe for async contexts.
67+
claims = id_token.verify_oauth2_token(new_token, Request())
68+
69+
expiry_timestamp = claims.get("exp")
70+
if not expiry_timestamp:
71+
raise ValueError("Token does not contain an 'exp' claim.")
72+
73+
_token_cache["token"] = new_token
74+
_token_cache["expires_at"] = datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc)
75+
76+
except (ValueError, GoogleAuthError) as e:
77+
# Clear cache on failure to prevent using a stale or invalid token
78+
_token_cache["token"] = None
79+
_token_cache["expires_at"] = datetime.min.replace(tzinfo=timezone.utc)
80+
raise ValueError(f"Failed to validate and cache the new token: {e}") from e
10981

11082

11183
# --- Public API Functions ---
112-
def get_google_id_token() -> str:
84+
85+
def get_google_id_token(audience: str) -> str:
11386
"""
114-
Synchronously fetches a Google ID token.
87+
Synchronously fetches a Google ID token for a specific audience.
11588
116-
The token is formatted as a 'Bearer' token string and is suitable for use
117-
in an HTTP Authorization header. This function uses Application Default
118-
Credentials.
89+
This function uses Application Default Credentials and caches the token in memory.
90+
91+
Args:
92+
audience: The audience for the ID token (e.g., a service URL or client ID).
11993
12094
Returns:
12195
A string in the format "Bearer <google_id_token>".
12296
12397
Raises:
124-
Exception: If fetching the Google ID token fails.
98+
GoogleAuthError: If fetching credentials or the token fails.
99+
ValueError: If the fetched token is invalid.
125100
"""
126-
if _is_cached_token_valid(_cached_google_id_token):
127-
return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"]
128-
101+
if _is_token_valid():
102+
return BEARER_TOKEN_PREFIX + _token_cache["token"]
103+
104+
# Get local user credentials
129105
credentials, _ = google.auth.default()
130106
session = AuthorizedSession(credentials)
131107
request = Request(session)
132108
credentials.refresh(request)
133-
new_id_token = getattr(credentials, "id_token", None)
134-
expiry = getattr(credentials, "expiry")
135-
136-
_update_token_cache(_cached_google_id_token, new_id_token, expiry)
137-
if new_id_token:
138-
return BEARER_TOKEN_PREFIX + new_id_token
139-
else:
140-
raise Exception("Failed to fetch Google ID token.")
141-
142-
143-
async def aget_google_id_token() -> str:
144-
"""
145-
Asynchronously fetches a Google ID token.
146-
147-
The token is formatted as a 'Bearer' token string and is suitable for use
148-
in an HTTP Authorization header. This function uses Application Default
149-
Credentials.
150-
151-
Returns:
152-
A string in the format "Bearer <google_id_token>".
153-
154-
Raises:
155-
Exception: If fetching the Google ID token fails.
156-
"""
157-
if _is_cached_token_valid(_cached_google_id_token):
158-
return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"]
159-
160-
credentials, _ = default_async()
161-
await credentials.refresh(_aiohttp_requests.Request())
162-
credentials.before_request = partial(Credentials.before_request, credentials)
163-
new_id_token = getattr(credentials, "id_token", None)
164-
expiry = getattr(credentials, "expiry")
165-
166-
_update_token_cache(_cached_google_id_token, new_id_token, expiry)
167109

168-
if new_id_token:
169-
return BEARER_TOKEN_PREFIX + new_id_token
170-
else:
171-
raise Exception("Failed to fetch async Google ID token.")
110+
if hasattr(credentials, "id_token"):
111+
new_id_token = getattr(credentials, "id_token", None)
112+
if new_id_token:
113+
_update_cache(new_id_token)
114+
return BEARER_TOKEN_PREFIX + new_id_token
115+
116+
# Get credentials for Google Cloud environments
117+
try:
118+
request = Request()
119+
new_token = id_token.fetch_id_token(request, audience)
120+
_update_cache(new_token)
121+
return BEARER_TOKEN_PREFIX + _token_cache["token"]
122+
123+
except GoogleAuthError as e:
124+
raise GoogleAuthError(f"Failed to fetch Google ID token for audience '{audience}': {e}") from e
125+
126+
async def aget_google_id_token(audience: str) -> str:
127+
token = await asyncio.to_thread(get_google_id_token, audience)
128+
return token

0 commit comments

Comments
 (0)