Skip to content

Commit 681296f

Browse files
shashankramEItanya
andauthored
agentsts: allow dynamic fetching of actor token and enable caching (#1443)
- Allows fetching the actor token using a dynamic callback. - Enables using cached actor and subject tokens if they haven't expired. This allows skipping the token exchange if the cached token has not expired. --------- Signed-off-by: Shashank Ram <shashank.ram@solo.io> Co-authored-by: Eitan Yarmush <eitan.yarmush@solo.io>
1 parent e61fd48 commit 681296f

File tree

5 files changed

+945
-27
lines changed

5 files changed

+945
-27
lines changed

python/packages/agentsts-adk/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dependencies = [
1717
"typing-extensions>=4.8.0",
1818
"aiofiles>=24.1.0",
1919
"anyio>=4.9.0",
20+
"PyJWT>=2.8.0",
2021
]
2122

2223
[tool.uv.sources]

python/packages/agentsts-adk/src/agentsts/adk/_base.py

Lines changed: 187 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
"""Google ADK-specific STS integration."""
22

3+
import inspect
34
import logging
4-
from typing import Any, Dict, Optional
5+
import time
6+
from typing import Any, Awaitable, Callable, Dict, Optional, Union
57

8+
import jwt
69
from google.adk.agents import BaseAgent, LlmAgent
710
from google.adk.agents.invocation_context import InvocationContext
811
from google.adk.agents.readonly_context import ReadonlyContext
9-
from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes, HttpAuth, HttpCredentials
12+
from google.adk.auth.auth_credential import (
13+
AuthCredential,
14+
AuthCredentialTypes,
15+
HttpAuth,
16+
HttpCredentials,
17+
)
1018
from google.adk.events.event import Event
1119
from google.adk.plugins.base_plugin import BasePlugin
1220
from google.adk.runners import Runner
@@ -32,11 +40,43 @@ def __init__(
3240
self,
3341
well_known_uri: str,
3442
service_account_token_path: Optional[str] = None,
43+
fetch_actor_token: Optional[Union[Callable[[], str], Callable[[], Awaitable[str]]]] = None,
3544
timeout: int = 5,
3645
verify_ssl: bool = True,
3746
additional_config: Optional[Dict[str, Any]] = None,
3847
):
39-
super().__init__(well_known_uri, service_account_token_path, timeout, verify_ssl, additional_config)
48+
"""Initialize the ADK STS integration.
49+
50+
Args:
51+
well_known_uri: Well-known configuration URI for the STS server
52+
service_account_token_path: Path to service account token file (ignored if fetch_actor_token is set)
53+
fetch_actor_token: Optional callable (sync or async) that returns an actor token
54+
timeout: Request timeout in seconds
55+
verify_ssl: Whether to verify SSL certificates
56+
additional_config: Additional configuration
57+
"""
58+
super().__init__(
59+
well_known_uri=well_known_uri,
60+
service_account_token_path=service_account_token_path,
61+
fetch_actor_token=fetch_actor_token,
62+
timeout=timeout,
63+
verify_ssl=verify_ssl,
64+
additional_config=additional_config,
65+
)
66+
67+
68+
class _TokenCacheEntry:
69+
"""Cache entry for access tokens with metadata."""
70+
71+
def __init__(self, token: str, expiry: Optional[int] = None):
72+
"""Initialize token cache entry.
73+
74+
Args:
75+
token: The access token
76+
expiry: Token expiry timestamp (Unix epoch), if available
77+
"""
78+
self.token = token
79+
self.expiry = expiry
4080

4181

4282
class ADKTokenPropagationPlugin(BasePlugin):
@@ -50,7 +90,8 @@ def __init__(self, sts_integration: Optional[STSIntegrationBase] = None):
5090
"""
5191
super().__init__("ADKTokenPropagationPlugin")
5292
self.sts_integration = sts_integration
53-
self.token_cache: Dict[str, str] = {}
93+
self.token_cache: Dict[str, _TokenCacheEntry] = {}
94+
self.actor_token_cache: Optional[_TokenCacheEntry] = None
5495

5596
def add_to_agent(self, agent: BaseAgent):
5697
"""
@@ -70,13 +111,13 @@ def add_to_agent(self, agent: BaseAgent):
70111
logger.debug("Updated tool connection params to include access token from STS server")
71112

72113
def header_provider(self, readonly_context: Optional[ReadonlyContext]) -> Dict[str, str]:
73-
# access save token
74-
access_token = self.token_cache.get(self.cache_key(readonly_context._invocation_context), "")
75-
if not access_token:
114+
# access saved token
115+
cache_entry = self.token_cache.get(self.cache_key(readonly_context._invocation_context))
116+
if not cache_entry:
76117
return {}
77118

78119
return {
79-
"Authorization": f"Bearer {access_token}",
120+
"Authorization": f"Bearer {cache_entry.token}",
80121
}
81122

82123
@override
@@ -86,41 +127,145 @@ async def before_run_callback(
86127
invocation_context: InvocationContext,
87128
) -> Optional[dict]:
88129
"""Propagate token to model before execution."""
130+
cache_key = self.cache_key(invocation_context)
131+
132+
# Check if we have a valid cached subject token
133+
cached_entry = self.token_cache.get(cache_key)
134+
if cached_entry and not _has_token_expired(cached_entry.expiry):
135+
if cached_entry.expiry:
136+
current_time = int(time.time())
137+
logger.debug(f"Using cached subject token (expires in {cached_entry.expiry - current_time}s)")
138+
else:
139+
logger.debug("Using cached subject token (no expiry)")
140+
return None
141+
142+
# No valid cached token, need to get/exchange subject token
89143
headers = invocation_context.session.state.get(HEADERS_KEY, None)
90144
subject_token = _extract_jwt_from_headers(headers)
91145
if not subject_token:
92146
logger.debug("No subject token found in headers for token propagation")
93147
return None
148+
94149
if self.sts_integration:
150+
# Get actor token (from cache or fetch dynamically)
151+
actor_token = await self._get_actor_token()
152+
if actor_token is None and self.sts_integration.fetch_actor_token:
153+
# Dynamic fetch failed; already logged a warning in _get_actor_token
154+
return None
155+
95156
try:
96157
subject_token = await self.sts_integration.exchange_token(
97158
subject_token=subject_token,
98159
subject_token_type=TokenType.JWT,
99-
actor_token=self.sts_integration._actor_token,
100-
actor_token_type=TokenType.JWT if self.sts_integration._actor_token else None,
160+
actor_token=actor_token,
161+
actor_token_type=TokenType.JWT if actor_token else None,
101162
)
102163
except Exception as e:
103164
logger.warning(f"STS token exchange failed: {e}")
104165
return None
105-
# no sts, just propagate the subject token upstream
106-
self.token_cache[self.cache_key(invocation_context)] = subject_token
166+
167+
# Extract expiry from the token
168+
expiry = _extract_jwt_expiry(subject_token)
169+
170+
# Cache the token with metadata
171+
self.token_cache[cache_key] = _TokenCacheEntry(
172+
token=subject_token,
173+
expiry=expiry,
174+
)
175+
logger.debug("Cached new subject token")
107176
return None
108177

109178
def cache_key(self, invocation_context: InvocationContext) -> str:
110179
"""Generate a cache key based on the session ID."""
111180
return invocation_context.session.id
112181

182+
async def _get_actor_token(self) -> Optional[str]:
183+
"""Get actor token from cache or fetch dynamically.
184+
185+
Returns:
186+
Actor token string if available, None otherwise
187+
"""
188+
if not self.sts_integration:
189+
return None
190+
191+
# Use static token if no dynamic fetch function
192+
if not self.sts_integration.fetch_actor_token:
193+
return self.sts_integration._actor_token
194+
195+
# Check cache for unexpired dynamic token
196+
if self.actor_token_cache:
197+
if not _has_token_expired(self.actor_token_cache.expiry):
198+
# Token is still valid
199+
if self.actor_token_cache.expiry:
200+
current_time = int(time.time())
201+
logger.debug(
202+
f"Using cached actor token (expires in {self.actor_token_cache.expiry - current_time}s)"
203+
)
204+
else:
205+
logger.debug("Using cached actor token (no expiry)")
206+
return self.actor_token_cache.token
207+
else:
208+
logger.debug("Cached actor token expired, fetching new one")
209+
210+
# Fetch new actor token
211+
try:
212+
if inspect.iscoroutinefunction(self.sts_integration.fetch_actor_token):
213+
actor_token = await self.sts_integration.fetch_actor_token()
214+
else:
215+
actor_token = self.sts_integration.fetch_actor_token()
216+
217+
# Extract expiry and cache the token
218+
expiry = _extract_jwt_expiry(actor_token)
219+
self.actor_token_cache = _TokenCacheEntry(token=actor_token, expiry=expiry)
220+
logger.debug("Fetched and cached new actor token")
221+
return actor_token
222+
223+
except Exception as e:
224+
logger.warning(f"Failed to fetch actor token dynamically: {e}")
225+
return None
226+
113227
@override
114228
async def after_run_callback(
115229
self,
116230
*,
117231
invocation_context: InvocationContext,
118232
) -> Optional[dict]:
119-
# delete token after run
120-
self.token_cache.pop(self.cache_key(invocation_context), None)
233+
"""Clean up expired tokens after run, preserving valid tokens."""
234+
cache_key = self.cache_key(invocation_context)
235+
cache_entry = self.token_cache.get(cache_key)
236+
237+
# Clean up subject token cache - only remove if expired
238+
if cache_entry and _has_token_expired(cache_entry.expiry):
239+
logger.debug("Removing expired subject token from cache")
240+
self.token_cache.pop(cache_key, None)
241+
242+
# Clean up expired actor token cache
243+
if self.actor_token_cache and _has_token_expired(self.actor_token_cache.expiry):
244+
logger.debug("Removing expired actor token from cache")
245+
self.actor_token_cache = None
246+
121247
return None
122248

123249

250+
def _has_token_expired(expiry: Optional[int], buffer_seconds: int = 5) -> bool:
251+
"""Check if a token has expired or will expire soon.
252+
253+
Args:
254+
expiry: Token expiry timestamp (Unix epoch), or None if no expiry
255+
buffer_seconds: Additional buffer time in seconds to treat tokens
256+
expiring soon as already expired (default: 5)
257+
258+
Returns:
259+
True if token has expired or will expire within buffer_seconds,
260+
False if still valid or no expiry
261+
"""
262+
if expiry is None:
263+
return False # No expiry means never expires
264+
265+
current_time = int(time.time())
266+
return expiry <= (current_time + buffer_seconds)
267+
268+
124269
def _extract_jwt_from_headers(headers: dict[str, str]) -> Optional[str]:
125270
"""Extract JWT from request headers for STS token exchange.
126271
@@ -150,3 +295,31 @@ def _extract_jwt_from_headers(headers: dict[str, str]) -> Optional[str]:
150295

151296
logger.debug(f"Successfully extracted JWT token (length: {len(jwt_token)})")
152297
return jwt_token
298+
299+
300+
def _extract_jwt_expiry(token: str) -> Optional[int]:
301+
"""Extract expiry timestamp from JWT token.
302+
303+
NOTE: This function does NOT validate the token signature.
304+
It is only used for cache management, not security decisions.
305+
Token validation happens in the STS server during exchange.
306+
307+
Args:
308+
token: JWT token string
309+
310+
Returns:
311+
Expiry timestamp (Unix epoch) if found, None otherwise
312+
"""
313+
try:
314+
# Decode without verification (we only need the expiry claim)
315+
decoded = jwt.decode(token, options={"verify_signature": False})
316+
expiry = decoded.get("exp")
317+
if expiry:
318+
logger.debug(f"Extracted JWT expiry: {expiry}")
319+
return int(expiry)
320+
321+
logger.debug("No 'exp' claim found in JWT")
322+
return None
323+
except Exception as e:
324+
logger.warning(f"Failed to extract JWT expiry: {e}")
325+
return None

0 commit comments

Comments
 (0)