Skip to content

Commit 6cf6d6b

Browse files
authored
fix!: fix auth_methods module (#313)
* fix: fix auth_methods module * lint * add tests * change docs * add optional arg * lint * fix sync method * fix docs * Fix async methods * lint * fix tests * lint * added test dep * fix patches * add missing import * fix patches
1 parent a5897c8 commit 6cf6d6b

File tree

7 files changed

+279
-446
lines changed

7 files changed

+279
-446
lines changed

packages/toolbox-core/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ For Toolbox servers hosted on Google Cloud (e.g., Cloud Run) and requiring
337337
```python
338338
from toolbox_core import auth_methods
339339

340-
auth_token_provider = auth_methods.aget_google_id_token # can also use sync method
340+
auth_token_provider = auth_methods.aget_google_id_token(URL) # can also use sync method
341341
async with ToolboxClient(
342342
URL,
343343
client_headers={"Authorization": auth_token_provider},

packages/toolbox-core/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ dependencies = [
1313
"pydantic>=2.7.0,<3.0.0",
1414
"aiohttp>=3.8.6,<4.0.0",
1515
"deprecated>=1.2.15,<2.0.0",
16+
"google-auth>=2.0.0,<3.0.0",
17+
"requests>=2.19.0,<3.0.0"
1618
]
1719

1820
classifiers = [
@@ -52,6 +54,7 @@ test = [
5254
"pytest-mock==3.14.1",
5355
"google-cloud-secret-manager==2.24.0",
5456
"google-cloud-storage==3.2.0",
57+
"aioresponses==0.7.8"
5558
]
5659
[build-system]
5760
requires = ["setuptools"]
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
aiohttp==3.12.14
22
pydantic==2.11.7
3-
deprecated==1.2.18
3+
deprecated==1.2.18
4+
requests==2.32.4
5+
google-auth==2.40.3

packages/toolbox-core/src/toolbox_core/auth_methods.py

Lines changed: 109 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -13,159 +13,162 @@
1313
# limitations under the License.
1414

1515
"""
16-
This module provides functions to obtain Google ID tokens, formatted as "Bearer" tokens,
17-
for use in the "Authorization" header of HTTP requests.
16+
This module provides functions to obtain Google ID tokens for a specific audience.
1817
19-
Example User Experience:
18+
The tokens are returned as "Bearer" strings for direct use in HTTP Authorization
19+
headers. It uses a simple in-memory cache to avoid refetching on every call.
20+
21+
Example Usage:
2022
from toolbox_core import auth_methods
2123
22-
auth_token_provider = auth_methods.aget_google_id_token
24+
URL = "https://toolbox-service-url"
2325
async with ToolboxClient(
2426
URL,
25-
client_headers={"Authorization": auth_token_provider},
26-
) as toolbox:
27+
client_headers={"Authorization": auth_methods.aget_google_id_token})
28+
as toolbox:
2729
tools = await toolbox.load_toolset()
2830
"""
2931

32+
import asyncio
3033
from datetime import datetime, timedelta, timezone
31-
from functools import partial
32-
from typing import Any, Dict, Optional
34+
from typing import Any, Callable, Coroutine, Dict, Optional
3335

3436
import google.auth
35-
from google.auth._credentials_async import Credentials
36-
from google.auth._default_async import default_async
37-
from google.auth.transport import _aiohttp_requests
37+
from google.auth.exceptions import GoogleAuthError
3838
from google.auth.transport.requests import AuthorizedSession, Request
39+
from google.oauth2 import id_token
3940

40-
# --- Constants and Configuration ---
41-
# Prefix for Authorization header tokens
41+
# --- Constants ---
4242
BEARER_TOKEN_PREFIX = "Bearer "
43-
# Margin in seconds to refresh token before its actual expiry
44-
CACHE_REFRESH_MARGIN_SECONDS = 60
43+
CACHE_REFRESH_MARGIN = timedelta(seconds=60)
44+
45+
_token_cache: Dict[str, Any] = {
46+
"token": None,
47+
"expires_at": datetime.min.replace(tzinfo=timezone.utc),
48+
}
4549

4650

47-
# --- Global Cache Storage ---
48-
# Stores the cached Google ID token and its expiry timestamp
49-
_cached_google_id_token: Dict[str, Any] = {"token": None, "expires_at": 0}
51+
def _is_token_valid() -> bool:
52+
"""Checks if the cached token exists and is not nearing expiry."""
53+
if not _token_cache["token"]:
54+
return False
55+
return datetime.now(timezone.utc) < (
56+
_token_cache["expires_at"] - CACHE_REFRESH_MARGIN
57+
)
5058

5159

52-
# --- Helper Functions ---
53-
def _is_cached_token_valid(
54-
cache: Dict[str, Any], margin_seconds: int = CACHE_REFRESH_MARGIN_SECONDS
55-
) -> bool:
60+
def _update_cache(new_token: str) -> None:
5661
"""
57-
Checks if a token in the cache is valid (exists and not expired).
62+
Validates a new token, extracts its expiry, and updates the cache.
5863
5964
Args:
60-
cache: The dictionary containing 'token' and 'expires_at'.
61-
margin_seconds: The time in seconds before expiry to consider the token invalid.
65+
new_token: The new JWT ID token string.
6266
63-
Returns:
64-
True if the token is valid, False otherwise.
67+
Raises:
68+
ValueError: If the token is invalid or its expiry cannot be determined.
6569
"""
66-
if not cache.get("token"):
67-
return False
70+
try:
71+
# verify_oauth2_token not only decodes but also validates the token's
72+
# signature and claims against Google's public keys.
73+
# It's a synchronous, CPU-bound operation, safe for async contexts.
74+
claims = id_token.verify_oauth2_token(new_token, Request())
6875

69-
expires_at_value = cache.get("expires_at")
70-
if not isinstance(expires_at_value, datetime):
71-
return False
76+
expiry_timestamp = claims.get("exp")
77+
if not expiry_timestamp:
78+
raise ValueError("Token does not contain an 'exp' claim.")
7279

73-
# Ensure expires_at_value is timezone-aware (UTC).
74-
if (
75-
expires_at_value.tzinfo is None
76-
or expires_at_value.tzinfo.utcoffset(expires_at_value) is None
77-
):
78-
expires_at_value = expires_at_value.replace(tzinfo=timezone.utc)
80+
_token_cache["token"] = new_token
81+
_token_cache["expires_at"] = datetime.fromtimestamp(
82+
expiry_timestamp, tz=timezone.utc
83+
)
7984

80-
current_time_utc = datetime.now(timezone.utc)
81-
if current_time_utc + timedelta(seconds=margin_seconds) < expires_at_value:
82-
return True
85+
except (ValueError, GoogleAuthError) as e:
86+
# Clear cache on failure to prevent using a stale or invalid token
87+
_token_cache["token"] = None
88+
_token_cache["expires_at"] = datetime.min.replace(tzinfo=timezone.utc)
89+
raise ValueError(f"Failed to validate and cache the new token: {e}") from e
8390

84-
return False
8591

92+
def get_google_token_from_aud(audience: Optional[str] = None) -> str:
93+
if _is_token_valid():
94+
return BEARER_TOKEN_PREFIX + _token_cache["token"]
8695

87-
def _update_token_cache(
88-
cache: Dict[str, Any], new_id_token: Optional[str], expiry: Optional[datetime]
89-
) -> None:
90-
"""
91-
Updates the global token cache with a new token and its expiry.
96+
# Get local user credentials
97+
credentials, _ = google.auth.default()
98+
session = AuthorizedSession(credentials)
99+
request = Request(session)
100+
credentials.refresh(request)
92101

93-
Args:
94-
cache: The dictionary containing 'token' and 'expires_at'.
95-
new_id_token: The new ID token string to cache.
96-
"""
97-
if new_id_token:
98-
cache["token"] = new_id_token
99-
expiry_timestamp = expiry
100-
if expiry_timestamp:
101-
cache["expires_at"] = expiry_timestamp
102-
else:
103-
# If expiry can't be determined, treat as immediately expired to force refresh
104-
cache["expires_at"] = 0
105-
else:
106-
# Clear cache if no new token is provided
107-
cache["token"] = None
108-
cache["expires_at"] = 0
109-
110-
111-
# --- Public API Functions ---
112-
def get_google_id_token() -> str:
102+
if hasattr(credentials, "id_token"):
103+
new_id_token = getattr(credentials, "id_token", None)
104+
if new_id_token:
105+
_update_cache(new_id_token)
106+
return BEARER_TOKEN_PREFIX + new_id_token
107+
108+
if audience is None:
109+
raise Exception(
110+
"You are not authenticating using User Credentials."
111+
" Please set the audience string to the Toolbox service URL to get the Google ID token."
112+
)
113+
114+
# Get credentials for Google Cloud environments or for service account key files
115+
try:
116+
request = Request()
117+
new_token = id_token.fetch_id_token(request, audience)
118+
_update_cache(new_token)
119+
return BEARER_TOKEN_PREFIX + _token_cache["token"]
120+
121+
except GoogleAuthError as e:
122+
raise GoogleAuthError(
123+
f"Failed to fetch Google ID token for audience '{audience}': {e}"
124+
) from e
125+
126+
127+
def get_google_id_token(audience: Optional[str] = None) -> Callable[[], str]:
113128
"""
114-
Synchronously fetches a Google ID token.
129+
Returns a SYNC function that, when called, fetches a Google ID token.
130+
This function uses Application Default Credentials for local systems
131+
and standard google auth libraries for Google Cloud environments.
132+
It caches the token in memory.
115133
116-
The token is formatted as a 'Bearer' token string and is suitable for use
117-
in an HTTP Authorization header. This function uses Application Default
118-
Credentials.
134+
Args:
135+
audience: The audience for the ID token (e.g., a service URL or client ID).
119136
120137
Returns:
121-
A string in the format "Bearer <google_id_token>".
138+
A function that when executed returns string in the format "Bearer <google_id_token>".
122139
123140
Raises:
124-
Exception: If fetching the Google ID token fails.
141+
GoogleAuthError: If fetching credentials or the token fails.
142+
ValueError: If the fetched token is invalid.
125143
"""
126-
if _is_cached_token_valid(_cached_google_id_token):
127-
return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"]
128144

129-
credentials, _ = google.auth.default()
130-
session = AuthorizedSession(credentials)
131-
request = Request(session)
132-
credentials.refresh(request)
133-
new_id_token = getattr(credentials, "id_token", None)
134-
expiry = getattr(credentials, "expiry")
145+
def _token_getter() -> str:
146+
return get_google_token_from_aud(audience)
135147

136-
_update_token_cache(_cached_google_id_token, new_id_token, expiry)
137-
if new_id_token:
138-
return BEARER_TOKEN_PREFIX + new_id_token
139-
else:
140-
raise Exception("Failed to fetch Google ID token.")
148+
return _token_getter
141149

142150

143-
async def aget_google_id_token() -> str:
151+
def aget_google_id_token(
152+
audience: Optional[str] = None,
153+
) -> Callable[[], Coroutine[Any, Any, str]]:
144154
"""
145-
Asynchronously fetches a Google ID token.
155+
Returns an ASYNC function that, when called, fetches a Google ID token.
156+
This function uses Application Default Credentials for local systems
157+
and standard google auth libraries for Google Cloud environments.
158+
It caches the token in memory.
146159
147-
The token is formatted as a 'Bearer' token string and is suitable for use
148-
in an HTTP Authorization header. This function uses Application Default
149-
Credentials.
160+
Args:
161+
audience: The audience for the ID token (e.g., a service URL or client ID).
150162
151163
Returns:
152-
A string in the format "Bearer <google_id_token>".
164+
An async function that when executed returns string in the format "Bearer <google_id_token>".
153165
154166
Raises:
155-
Exception: If fetching the Google ID token fails.
167+
GoogleAuthError: If fetching credentials or the token fails.
168+
ValueError: If the fetched token is invalid.
156169
"""
157-
if _is_cached_token_valid(_cached_google_id_token):
158-
return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"]
159-
160-
credentials, _ = default_async()
161-
await credentials.refresh(_aiohttp_requests.Request())
162-
credentials.before_request = partial(Credentials.before_request, credentials)
163-
new_id_token = getattr(credentials, "id_token", None)
164-
expiry = getattr(credentials, "expiry")
165170

166-
_update_token_cache(_cached_google_id_token, new_id_token, expiry)
171+
async def _token_getter() -> str:
172+
return await asyncio.to_thread(get_google_token_from_aud, audience)
167173

168-
if new_id_token:
169-
return BEARER_TOKEN_PREFIX + new_id_token
170-
else:
171-
raise Exception("Failed to fetch async Google ID token.")
174+
return _token_getter

0 commit comments

Comments
 (0)