Skip to content

fix!: fix auth_methods module #313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 17, 2025
Merged
2 changes: 1 addition & 1 deletion packages/toolbox-core/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ For Toolbox servers hosted on Google Cloud (e.g., Cloud Run) and requiring
```python
from toolbox_core import auth_methods

auth_token_provider = auth_methods.aget_google_id_token # can also use sync method
auth_token_provider = auth_methods.aget_google_id_token(URL) # can also use sync method
async with ToolboxClient(
URL,
client_headers={"Authorization": auth_token_provider},
Expand Down
3 changes: 3 additions & 0 deletions packages/toolbox-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ dependencies = [
"pydantic>=2.7.0,<3.0.0",
"aiohttp>=3.8.6,<4.0.0",
"deprecated>=1.2.15,<2.0.0",
"google-auth>=2.0.0,<3.0.0",
"requests>=2.19.0,<3.0.0"
]

classifiers = [
Expand Down Expand Up @@ -52,6 +54,7 @@ test = [
"pytest-mock==3.14.1",
"google-cloud-secret-manager==2.24.0",
"google-cloud-storage==3.2.0",
"aioresponses==0.7.8"
]
[build-system]
requires = ["setuptools"]
Expand Down
4 changes: 3 additions & 1 deletion packages/toolbox-core/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
aiohttp==3.12.14
pydantic==2.11.7
deprecated==1.2.18
deprecated==1.2.18
requests==2.32.4
google-auth==2.40.3
215 changes: 109 additions & 106 deletions packages/toolbox-core/src/toolbox_core/auth_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,159 +13,162 @@
# limitations under the License.

"""
This module provides functions to obtain Google ID tokens, formatted as "Bearer" tokens,
for use in the "Authorization" header of HTTP requests.
This module provides functions to obtain Google ID tokens for a specific audience.

Example User Experience:
The tokens are returned as "Bearer" strings for direct use in HTTP Authorization
headers. It uses a simple in-memory cache to avoid refetching on every call.

Example Usage:
from toolbox_core import auth_methods

auth_token_provider = auth_methods.aget_google_id_token
URL = "https://toolbox-service-url"
async with ToolboxClient(
URL,
client_headers={"Authorization": auth_token_provider},
) as toolbox:
client_headers={"Authorization": auth_methods.aget_google_id_token})
as toolbox:
tools = await toolbox.load_toolset()
"""

import asyncio
from datetime import datetime, timedelta, timezone
from functools import partial
from typing import Any, Dict, Optional
from typing import Any, Callable, Coroutine, Dict, Optional

import google.auth
from google.auth._credentials_async import Credentials
from google.auth._default_async import default_async
from google.auth.transport import _aiohttp_requests
from google.auth.exceptions import GoogleAuthError
from google.auth.transport.requests import AuthorizedSession, Request
from google.oauth2 import id_token

# --- Constants and Configuration ---
# Prefix for Authorization header tokens
# --- Constants ---
BEARER_TOKEN_PREFIX = "Bearer "
# Margin in seconds to refresh token before its actual expiry
CACHE_REFRESH_MARGIN_SECONDS = 60
CACHE_REFRESH_MARGIN = timedelta(seconds=60)

_token_cache: Dict[str, Any] = {
"token": None,
"expires_at": datetime.min.replace(tzinfo=timezone.utc),
}


# --- Global Cache Storage ---
# Stores the cached Google ID token and its expiry timestamp
_cached_google_id_token: Dict[str, Any] = {"token": None, "expires_at": 0}
def _is_token_valid() -> bool:
"""Checks if the cached token exists and is not nearing expiry."""
if not _token_cache["token"]:
return False
return datetime.now(timezone.utc) < (
_token_cache["expires_at"] - CACHE_REFRESH_MARGIN
)


# --- Helper Functions ---
def _is_cached_token_valid(
cache: Dict[str, Any], margin_seconds: int = CACHE_REFRESH_MARGIN_SECONDS
) -> bool:
def _update_cache(new_token: str) -> None:
"""
Checks if a token in the cache is valid (exists and not expired).
Validates a new token, extracts its expiry, and updates the cache.

Args:
cache: The dictionary containing 'token' and 'expires_at'.
margin_seconds: The time in seconds before expiry to consider the token invalid.
new_token: The new JWT ID token string.

Returns:
True if the token is valid, False otherwise.
Raises:
ValueError: If the token is invalid or its expiry cannot be determined.
"""
if not cache.get("token"):
return False
try:
# verify_oauth2_token not only decodes but also validates the token's
# signature and claims against Google's public keys.
# It's a synchronous, CPU-bound operation, safe for async contexts.
claims = id_token.verify_oauth2_token(new_token, Request())

expires_at_value = cache.get("expires_at")
if not isinstance(expires_at_value, datetime):
return False
expiry_timestamp = claims.get("exp")
if not expiry_timestamp:
raise ValueError("Token does not contain an 'exp' claim.")

# Ensure expires_at_value is timezone-aware (UTC).
if (
expires_at_value.tzinfo is None
or expires_at_value.tzinfo.utcoffset(expires_at_value) is None
):
expires_at_value = expires_at_value.replace(tzinfo=timezone.utc)
_token_cache["token"] = new_token
_token_cache["expires_at"] = datetime.fromtimestamp(
expiry_timestamp, tz=timezone.utc
)

current_time_utc = datetime.now(timezone.utc)
if current_time_utc + timedelta(seconds=margin_seconds) < expires_at_value:
return True
except (ValueError, GoogleAuthError) as e:
# Clear cache on failure to prevent using a stale or invalid token
_token_cache["token"] = None
_token_cache["expires_at"] = datetime.min.replace(tzinfo=timezone.utc)
raise ValueError(f"Failed to validate and cache the new token: {e}") from e

return False

def get_google_token_from_aud(audience: Optional[str] = None) -> str:
if _is_token_valid():
return BEARER_TOKEN_PREFIX + _token_cache["token"]

def _update_token_cache(
cache: Dict[str, Any], new_id_token: Optional[str], expiry: Optional[datetime]
) -> None:
"""
Updates the global token cache with a new token and its expiry.
# Get local user credentials
credentials, _ = google.auth.default()
session = AuthorizedSession(credentials)
request = Request(session)
credentials.refresh(request)

Args:
cache: The dictionary containing 'token' and 'expires_at'.
new_id_token: The new ID token string to cache.
"""
if new_id_token:
cache["token"] = new_id_token
expiry_timestamp = expiry
if expiry_timestamp:
cache["expires_at"] = expiry_timestamp
else:
# If expiry can't be determined, treat as immediately expired to force refresh
cache["expires_at"] = 0
else:
# Clear cache if no new token is provided
cache["token"] = None
cache["expires_at"] = 0


# --- Public API Functions ---
def get_google_id_token() -> str:
if hasattr(credentials, "id_token"):
new_id_token = getattr(credentials, "id_token", None)
if new_id_token:
_update_cache(new_id_token)
return BEARER_TOKEN_PREFIX + new_id_token

if audience is None:
raise Exception(
"You are not authenticating using User Credentials."
" Please set the audience string to the Toolbox service URL to get the Google ID token."
)

# Get credentials for Google Cloud environments or for service account key files
try:
request = Request()
new_token = id_token.fetch_id_token(request, audience)
_update_cache(new_token)
return BEARER_TOKEN_PREFIX + _token_cache["token"]

except GoogleAuthError as e:
raise GoogleAuthError(
f"Failed to fetch Google ID token for audience '{audience}': {e}"
) from e


def get_google_id_token(audience: Optional[str] = None) -> Callable[[], str]:
"""
Synchronously fetches a Google ID token.
Returns a SYNC function that, when called, fetches a Google ID token.
This function uses Application Default Credentials for local systems
and standard google auth libraries for Google Cloud environments.
It caches the token in memory.

The token is formatted as a 'Bearer' token string and is suitable for use
in an HTTP Authorization header. This function uses Application Default
Credentials.
Args:
audience: The audience for the ID token (e.g., a service URL or client ID).

Returns:
A string in the format "Bearer <google_id_token>".
A function that when executed returns string in the format "Bearer <google_id_token>".

Raises:
Exception: If fetching the Google ID token fails.
GoogleAuthError: If fetching credentials or the token fails.
ValueError: If the fetched token is invalid.
"""
if _is_cached_token_valid(_cached_google_id_token):
return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"]

credentials, _ = google.auth.default()
session = AuthorizedSession(credentials)
request = Request(session)
credentials.refresh(request)
new_id_token = getattr(credentials, "id_token", None)
expiry = getattr(credentials, "expiry")
def _token_getter() -> str:
return get_google_token_from_aud(audience)

_update_token_cache(_cached_google_id_token, new_id_token, expiry)
if new_id_token:
return BEARER_TOKEN_PREFIX + new_id_token
else:
raise Exception("Failed to fetch Google ID token.")
return _token_getter


async def aget_google_id_token() -> str:
def aget_google_id_token(
audience: Optional[str] = None,
) -> Callable[[], Coroutine[Any, Any, str]]:
"""
Asynchronously fetches a Google ID token.
Returns an ASYNC function that, when called, fetches a Google ID token.
This function uses Application Default Credentials for local systems
and standard google auth libraries for Google Cloud environments.
It caches the token in memory.

The token is formatted as a 'Bearer' token string and is suitable for use
in an HTTP Authorization header. This function uses Application Default
Credentials.
Args:
audience: The audience for the ID token (e.g., a service URL or client ID).

Returns:
A string in the format "Bearer <google_id_token>".
An async function that when executed returns string in the format "Bearer <google_id_token>".

Raises:
Exception: If fetching the Google ID token fails.
GoogleAuthError: If fetching credentials or the token fails.
ValueError: If the fetched token is invalid.
"""
if _is_cached_token_valid(_cached_google_id_token):
return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"]

credentials, _ = default_async()
await credentials.refresh(_aiohttp_requests.Request())
credentials.before_request = partial(Credentials.before_request, credentials)
new_id_token = getattr(credentials, "id_token", None)
expiry = getattr(credentials, "expiry")

_update_token_cache(_cached_google_id_token, new_id_token, expiry)
async def _token_getter() -> str:
return await asyncio.to_thread(get_google_token_from_aud, audience)

if new_id_token:
return BEARER_TOKEN_PREFIX + new_id_token
else:
raise Exception("Failed to fetch async Google ID token.")
return _token_getter
Loading