Skip to content

Commit 7539808

Browse files
committed
feat: Cache google id tokens
1 parent ca9ff14 commit 7539808

File tree

4 files changed

+506
-202
lines changed

4 files changed

+506
-202
lines changed

packages/toolbox-core/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ authors = [
1212
dependencies = [
1313
"pydantic>=2.7.0,<3.0.0",
1414
"aiohttp>=3.8.6,<4.0.0",
15+
"PyJWT>=2.0.0,<3.0.0",
1516
]
1617

1718
classifiers = [
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
aiohttp==3.11.18
22
pydantic==2.11.4
3+
PyJWT==2.10.1

packages/toolbox-core/src/toolbox_core/auth_methods.py

Lines changed: 132 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,48 +12,113 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# The tokens obtained by these functions are formatted as "Bearer" tokens
16-
# and are intended to be passed in the "Authorization" header of HTTP requests.
17-
#
18-
# Example User Experience:
19-
# from toolbox_core import auth_methods
20-
#
21-
# auth_token_provider = auth_methods.aget_google_id_token
22-
# toolbox = ToolboxClient(
23-
# URL,
24-
# client_headers={"Authorization": auth_token_provider},
25-
# )
26-
# tools = await toolbox.load_toolset()
15+
"""
16+
This module provides functions to obtain Google ID tokens, formatted as "Bearer" tokens,
17+
for use in the "Authorization" header of HTTP requests.
18+
19+
Example User Experience:
20+
from toolbox_core import auth_methods
2721
22+
auth_token_provider = auth_methods.aget_google_id_token
23+
toolbox = ToolboxClient(
24+
URL,
25+
client_headers={"Authorization": auth_token_provider},
26+
)
27+
tools = await toolbox.load_toolset()
28+
"""
2829

30+
import time
2931
from functools import partial
32+
from typing import Optional, Dict, Any
3033

3134
import google.auth
3235
from google.auth._credentials_async import Credentials
36+
import jwt
3337
from google.auth._default_async import default_async
3438
from google.auth.transport import _aiohttp_requests
3539
from google.auth.transport.requests import AuthorizedSession, Request
3640

3741

38-
async def aget_google_id_token():
42+
# --- Constants and Configuration ---
43+
# Prefix for Authorization header tokens
44+
BEARER_TOKEN_PREFIX = "Bearer "
45+
# Margin in seconds to refresh token before its actual expiry
46+
CACHE_REFRESH_MARGIN_SECONDS = 60
47+
48+
49+
# --- Global Cache Storage ---
50+
# Stores the cached Google ID token and its expiry timestamp
51+
_cached_google_id_token: Dict[str, Any] = {"token": None, "expires_at": 0}
52+
53+
54+
# --- Helper Functions ---
55+
def _decode_jwt_and_get_expiry(id_token: str) -> Optional[float]:
3956
"""
40-
Asynchronously fetches a Google ID token.
57+
Decodes a JWT and extracts the 'exp' (expiration) claim.
4158
42-
The token is formatted as a 'Bearer' token string and is suitable for use
43-
in an HTTP Authorization header. This function uses Application Default
44-
Credentials.
59+
Args:
60+
id_token: The JWT string to decode.
4561
4662
Returns:
47-
A string in the format "Bearer <google_id_token>".
63+
The 'exp' timestamp as a float if present and decoding is successful,
64+
otherwise None.
4865
"""
49-
creds, _ = default_async()
50-
await creds.refresh(_aiohttp_requests.Request())
51-
creds.before_request = partial(Credentials.before_request, creds)
52-
token = creds.id_token
53-
return f"Bearer {token}"
66+
try:
67+
decoded_token = jwt.decode(
68+
id_token, options={"verify_signature": False, "verify_aud": False}
69+
)
70+
return decoded_token.get("exp")
71+
except jwt.PyJWTError:
72+
return None
5473

5574

56-
def get_google_id_token():
75+
def _is_cached_token_valid(
76+
cache: Dict[str, Any], margin_seconds: int = CACHE_REFRESH_MARGIN_SECONDS
77+
) -> bool:
78+
"""
79+
Checks if a token in the cache is valid (exists and not expired).
80+
81+
Args:
82+
cache: The dictionary containing 'token' and 'expires_at'.
83+
margin_seconds: The time in seconds before expiry to consider the token invalid.
84+
85+
Returns:
86+
True if the token is valid, False otherwise.
87+
"""
88+
if not cache.get("token"):
89+
return False
90+
91+
expires_at = cache.get("expires_at")
92+
if not isinstance(expires_at, (int, float)) or expires_at <= 0:
93+
return False
94+
95+
return time.time() < (expires_at - margin_seconds)
96+
97+
98+
def _update_token_cache(cache: Dict[str, Any], new_id_token: Optional[str]):
99+
"""
100+
Updates the global token cache with a new token and its expiry.
101+
102+
Args:
103+
cache: The dictionary containing 'token' and 'expires_at'.
104+
new_id_token: The new ID token string to cache.
105+
"""
106+
if new_id_token:
107+
cache["token"] = new_id_token
108+
expiry_timestamp = _decode_jwt_and_get_expiry(new_id_token)
109+
if expiry_timestamp:
110+
cache["expires_at"] = expiry_timestamp
111+
else:
112+
# If expiry can't be determined, treat as immediately expired to force refresh
113+
cache["expires_at"] = 0
114+
else:
115+
# Clear cache if no new token is provided
116+
cache["token"] = None
117+
cache["expires_at"] = 0
118+
119+
120+
# --- Public API Functions ---
121+
def get_google_id_token() -> str:
57122
"""
58123
Synchronously fetches a Google ID token.
59124
@@ -63,10 +128,51 @@ def get_google_id_token():
63128
64129
Returns:
65130
A string in the format "Bearer <google_id_token>".
131+
132+
Raises:
133+
Exception: If fetching the Google ID token fails.
66134
"""
135+
if _is_cached_token_valid(_cached_google_id_token):
136+
return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"]
137+
67138
credentials, _ = google.auth.default()
68139
session = AuthorizedSession(credentials)
69140
request = Request(session)
70141
credentials.refresh(request)
71-
token = credentials.id_token
72-
return f"Bearer {token}"
142+
new_id_token = getattr(credentials, "id_token", None)
143+
144+
_update_token_cache(_cached_google_id_token, new_id_token)
145+
if new_id_token:
146+
return BEARER_TOKEN_PREFIX + new_id_token
147+
else:
148+
raise Exception("Failed to fetch Google ID token.")
149+
150+
151+
async def aget_google_id_token() -> str:
152+
"""
153+
Asynchronously fetches a Google ID token.
154+
155+
The token is formatted as a 'Bearer' token string and is suitable for use
156+
in an HTTP Authorization header. This function uses Application Default
157+
Credentials.
158+
159+
Returns:
160+
A string in the format "Bearer <google_id_token>".
161+
162+
Raises:
163+
Exception: If fetching the Google ID token fails.
164+
"""
165+
if _is_cached_token_valid(_cached_google_id_token):
166+
return BEARER_TOKEN_PREFIX + _cached_google_id_token["token"]
167+
168+
credentials, _ = default_async()
169+
await credentials.refresh(_aiohttp_requests.Request())
170+
credentials.before_request = partial(Credentials.before_request, credentials)
171+
new_id_token = getattr(credentials, "id_token", None)
172+
173+
_update_token_cache(_cached_google_id_token, new_id_token)
174+
175+
if new_id_token:
176+
return BEARER_TOKEN_PREFIX + new_id_token
177+
else:
178+
raise Exception("Failed to fetch async Google ID token.")

0 commit comments

Comments
 (0)