Skip to content

Commit 4e832c8

Browse files
authored
feat: Cache google id tokens (#254)
* feat: Cache google id tokens * lint * lint * fix tests * get expiry directly from creds * fix tests * lint * remove pyjwt from requirements
1 parent 79005ad commit 4e832c8

File tree

2 files changed

+505
-203
lines changed

2 files changed

+505
-203
lines changed

packages/toolbox-core/src/toolbox_core/auth_methods.py

Lines changed: 125 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,48 +12,104 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# The tokens obtained by these functions are formatted as "Bearer" tokens
16-
# and are intended to be passed in the "Authorization" header of HTTP requests.
17-
#
18-
# Example User Experience:
19-
# from toolbox_core import auth_methods
20-
#
21-
# auth_token_provider = auth_methods.aget_google_id_token
22-
# toolbox = ToolboxClient(
23-
# URL,
24-
# client_headers={"Authorization": auth_token_provider},
25-
# )
26-
# tools = await toolbox.load_toolset()
15+
"""
16+
This module provides functions to obtain Google ID tokens, formatted as "Bearer" tokens,
17+
for use in the "Authorization" header of HTTP requests.
18+
19+
Example User Experience:
20+
from toolbox_core import auth_methods
2721
22+
auth_token_provider = auth_methods.aget_google_id_token
23+
toolbox = ToolboxClient(
24+
URL,
25+
client_headers={"Authorization": auth_token_provider},
26+
)
27+
tools = await toolbox.load_toolset()
28+
"""
2829

30+
from datetime import datetime, timedelta, timezone
2931
from functools import partial
32+
from typing import Any, Dict, Optional
3033

3134
import google.auth
3235
from google.auth._credentials_async import Credentials
3336
from google.auth._default_async import default_async
3437
from google.auth.transport import _aiohttp_requests
3538
from google.auth.transport.requests import AuthorizedSession, Request
3639

40+
# --- Constants and Configuration ---
41+
# Prefix for Authorization header tokens
42+
BEARER_TOKEN_PREFIX = "Bearer "
43+
# Margin in seconds to refresh token before its actual expiry
44+
CACHE_REFRESH_MARGIN_SECONDS = 60
45+
46+
47+
# --- Global Cache Storage ---
48+
# Stores the cached Google ID token and its expiry timestamp
49+
_cached_google_id_token: Dict[str, Any] = {"token": None, "expires_at": 0}
3750

38-
async def aget_google_id_token():
51+
52+
# --- Helper Functions ---
53+
def _is_cached_token_valid(
54+
cache: Dict[str, Any], margin_seconds: int = CACHE_REFRESH_MARGIN_SECONDS
55+
) -> bool:
3956
"""
40-
Asynchronously fetches a Google ID token.
57+
Checks if a token in the cache is valid (exists and not expired).
4158
42-
The token is formatted as a 'Bearer' token string and is suitable for use
43-
in an HTTP Authorization header. This function uses Application Default
44-
Credentials.
59+
Args:
60+
cache: The dictionary containing 'token' and 'expires_at'.
61+
margin_seconds: The time in seconds before expiry to consider the token invalid.
4562
4663
Returns:
47-
A string in the format "Bearer <google_id_token>".
64+
True if the token is valid, False otherwise.
65+
"""
66+
if not cache.get("token"):
67+
return False
68+
69+
expires_at_value = cache.get("expires_at")
70+
if not isinstance(expires_at_value, datetime):
71+
return False
72+
73+
# Ensure expires_at_value is timezone-aware (UTC).
74+
if (
75+
expires_at_value.tzinfo is None
76+
or expires_at_value.tzinfo.utcoffset(expires_at_value) is None
77+
):
78+
expires_at_value = expires_at_value.replace(tzinfo=timezone.utc)
79+
80+
current_time_utc = datetime.now(timezone.utc)
81+
if current_time_utc + timedelta(seconds=margin_seconds) < expires_at_value:
82+
return True
83+
84+
return False
85+
86+
87+
def _update_token_cache(
88+
cache: Dict[str, Any], new_id_token: Optional[str], expiry: Optional[datetime]
89+
) -> None:
90+
"""
91+
Updates the global token cache with a new token and its expiry.
92+
93+
Args:
94+
cache: The dictionary containing 'token' and 'expires_at'.
95+
new_id_token: The new ID token string to cache.
4896
"""
49-
creds, _ = default_async()
50-
await creds.refresh(_aiohttp_requests.Request())
51-
creds.before_request = partial(Credentials.before_request, creds)
52-
token = creds.id_token
53-
return f"Bearer {token}"
97+
if new_id_token:
98+
cache["token"] = new_id_token
99+
expiry_timestamp = expiry
100+
if expiry_timestamp:
101+
cache["expires_at"] = expiry_timestamp
102+
else:
103+
# If expiry can't be determined, treat as immediately expired to force refresh
104+
cache["expires_at"] = 0
105+
else:
106+
# Clear cache if no new token is provided
107+
cache["token"] = None
108+
cache["expires_at"] = 0
54109

55110

56-
def get_google_id_token():
111+
# --- Public API Functions ---
112+
def get_google_id_token() -> str:
57113
"""
58114
Synchronously fetches a Google ID token.
59115
@@ -63,10 +119,53 @@ def get_google_id_token():
63119
64120
Returns:
65121
A string in the format "Bearer <google_id_token>".
122+
123+
Raises:
124+
Exception: If fetching the Google ID token fails.
66125
"""
126+
if _is_cached_token_valid(_cached_google_id_token):
127+
return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"]
128+
67129
credentials, _ = google.auth.default()
68130
session = AuthorizedSession(credentials)
69131
request = Request(session)
70132
credentials.refresh(request)
71-
token = credentials.id_token
72-
return f"Bearer {token}"
133+
new_id_token = getattr(credentials, "id_token", None)
134+
expiry = getattr(credentials, "expiry")
135+
136+
_update_token_cache(_cached_google_id_token, new_id_token, expiry)
137+
if new_id_token:
138+
return BEARER_TOKEN_PREFIX + new_id_token
139+
else:
140+
raise Exception("Failed to fetch Google ID token.")
141+
142+
143+
async def aget_google_id_token() -> str:
144+
"""
145+
Asynchronously fetches a Google ID token.
146+
147+
The token is formatted as a 'Bearer' token string and is suitable for use
148+
in an HTTP Authorization header. This function uses Application Default
149+
Credentials.
150+
151+
Returns:
152+
A string in the format "Bearer <google_id_token>".
153+
154+
Raises:
155+
Exception: If fetching the Google ID token fails.
156+
"""
157+
if _is_cached_token_valid(_cached_google_id_token):
158+
return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"]
159+
160+
credentials, _ = default_async()
161+
await credentials.refresh(_aiohttp_requests.Request())
162+
credentials.before_request = partial(Credentials.before_request, credentials)
163+
new_id_token = getattr(credentials, "id_token", None)
164+
expiry = getattr(credentials, "expiry")
165+
166+
_update_token_cache(_cached_google_id_token, new_id_token, expiry)
167+
168+
if new_id_token:
169+
return BEARER_TOKEN_PREFIX + new_id_token
170+
else:
171+
raise Exception("Failed to fetch async Google ID token.")

0 commit comments

Comments
 (0)