11"""Google ADK-specific STS integration."""
22
3+ import inspect
34import logging
4- from typing import Any , Dict , Optional
5+ import time
6+ from typing import Any , Awaitable , Callable , Dict , Optional , Union
57
8+ import jwt
69from google .adk .agents import BaseAgent , LlmAgent
710from google .adk .agents .invocation_context import InvocationContext
811from google .adk .agents .readonly_context import ReadonlyContext
9- from google .adk .auth .auth_credential import AuthCredential , AuthCredentialTypes , HttpAuth , HttpCredentials
12+ from google .adk .auth .auth_credential import (
13+ AuthCredential ,
14+ AuthCredentialTypes ,
15+ HttpAuth ,
16+ HttpCredentials ,
17+ )
1018from google .adk .events .event import Event
1119from google .adk .plugins .base_plugin import BasePlugin
1220from google .adk .runners import Runner
@@ -32,11 +40,43 @@ def __init__(
3240 self ,
3341 well_known_uri : str ,
3442 service_account_token_path : Optional [str ] = None ,
43+ fetch_actor_token : Optional [Union [Callable [[], str ], Callable [[], Awaitable [str ]]]] = None ,
3544 timeout : int = 5 ,
3645 verify_ssl : bool = True ,
3746 additional_config : Optional [Dict [str , Any ]] = None ,
3847 ):
39- super ().__init__ (well_known_uri , service_account_token_path , timeout , verify_ssl , additional_config )
48+ """Initialize the ADK STS integration.
49+
50+ Args:
51+ well_known_uri: Well-known configuration URI for the STS server
52+ service_account_token_path: Path to service account token file (ignored if fetch_actor_token is set)
53+ fetch_actor_token: Optional callable (sync or async) that returns an actor token
54+ timeout: Request timeout in seconds
55+ verify_ssl: Whether to verify SSL certificates
56+ additional_config: Additional configuration
57+ """
58+ super ().__init__ (
59+ well_known_uri = well_known_uri ,
60+ service_account_token_path = service_account_token_path ,
61+ fetch_actor_token = fetch_actor_token ,
62+ timeout = timeout ,
63+ verify_ssl = verify_ssl ,
64+ additional_config = additional_config ,
65+ )
66+
67+
68+ class _TokenCacheEntry :
69+ """Cache entry for access tokens with metadata."""
70+
71+ def __init__ (self , token : str , expiry : Optional [int ] = None ):
72+ """Initialize token cache entry.
73+
74+ Args:
75+ token: The access token
76+ expiry: Token expiry timestamp (Unix epoch), if available
77+ """
78+ self .token = token
79+ self .expiry = expiry
4080
4181
4282class ADKTokenPropagationPlugin (BasePlugin ):
@@ -50,7 +90,8 @@ def __init__(self, sts_integration: Optional[STSIntegrationBase] = None):
5090 """
5191 super ().__init__ ("ADKTokenPropagationPlugin" )
5292 self .sts_integration = sts_integration
53- self .token_cache : Dict [str , str ] = {}
93+ self .token_cache : Dict [str , _TokenCacheEntry ] = {}
94+ self .actor_token_cache : Optional [_TokenCacheEntry ] = None
5495
5596 def add_to_agent (self , agent : BaseAgent ):
5697 """
@@ -70,13 +111,13 @@ def add_to_agent(self, agent: BaseAgent):
70111 logger .debug ("Updated tool connection params to include access token from STS server" )
71112
72113 def header_provider (self , readonly_context : Optional [ReadonlyContext ]) -> Dict [str , str ]:
73- # access save token
74- access_token = self .token_cache .get (self .cache_key (readonly_context ._invocation_context ), "" )
75- if not access_token :
114+ # access saved token
115+ cache_entry = self .token_cache .get (self .cache_key (readonly_context ._invocation_context ))
116+ if not cache_entry :
76117 return {}
77118
78119 return {
79- "Authorization" : f"Bearer { access_token } " ,
120+ "Authorization" : f"Bearer { cache_entry . token } " ,
80121 }
81122
82123 @override
@@ -86,41 +127,145 @@ async def before_run_callback(
86127 invocation_context : InvocationContext ,
87128 ) -> Optional [dict ]:
88129 """Propagate token to model before execution."""
130+ cache_key = self .cache_key (invocation_context )
131+
132+ # Check if we have a valid cached subject token
133+ cached_entry = self .token_cache .get (cache_key )
134+ if cached_entry and not _has_token_expired (cached_entry .expiry ):
135+ if cached_entry .expiry :
136+ current_time = int (time .time ())
137+ logger .debug (f"Using cached subject token (expires in { cached_entry .expiry - current_time } s)" )
138+ else :
139+ logger .debug ("Using cached subject token (no expiry)" )
140+ return None
141+
142+ # No valid cached token, need to get/exchange subject token
89143 headers = invocation_context .session .state .get (HEADERS_KEY , None )
90144 subject_token = _extract_jwt_from_headers (headers )
91145 if not subject_token :
92146 logger .debug ("No subject token found in headers for token propagation" )
93147 return None
148+
94149 if self .sts_integration :
150+ # Get actor token (from cache or fetch dynamically)
151+ actor_token = await self ._get_actor_token ()
152+ if actor_token is None and self .sts_integration .fetch_actor_token :
153+ # Dynamic fetch failed; already logged a warning in _get_actor_token
154+ return None
155+
95156 try :
96157 subject_token = await self .sts_integration .exchange_token (
97158 subject_token = subject_token ,
98159 subject_token_type = TokenType .JWT ,
99- actor_token = self . sts_integration . _actor_token ,
100- actor_token_type = TokenType .JWT if self . sts_integration . _actor_token else None ,
160+ actor_token = actor_token ,
161+ actor_token_type = TokenType .JWT if actor_token else None ,
101162 )
102163 except Exception as e :
103164 logger .warning (f"STS token exchange failed: { e } " )
104165 return None
105- # no sts, just propagate the subject token upstream
106- self .token_cache [self .cache_key (invocation_context )] = subject_token
166+
167+ # Extract expiry from the token
168+ expiry = _extract_jwt_expiry (subject_token )
169+
170+ # Cache the token with metadata
171+ self .token_cache [cache_key ] = _TokenCacheEntry (
172+ token = subject_token ,
173+ expiry = expiry ,
174+ )
175+ logger .debug ("Cached new subject token" )
107176 return None
108177
109178 def cache_key (self , invocation_context : InvocationContext ) -> str :
110179 """Generate a cache key based on the session ID."""
111180 return invocation_context .session .id
112181
182+ async def _get_actor_token (self ) -> Optional [str ]:
183+ """Get actor token from cache or fetch dynamically.
184+
185+ Returns:
186+ Actor token string if available, None otherwise
187+ """
188+ if not self .sts_integration :
189+ return None
190+
191+ # Use static token if no dynamic fetch function
192+ if not self .sts_integration .fetch_actor_token :
193+ return self .sts_integration ._actor_token
194+
195+ # Check cache for unexpired dynamic token
196+ if self .actor_token_cache :
197+ if not _has_token_expired (self .actor_token_cache .expiry ):
198+ # Token is still valid
199+ if self .actor_token_cache .expiry :
200+ current_time = int (time .time ())
201+ logger .debug (
202+ f"Using cached actor token (expires in { self .actor_token_cache .expiry - current_time } s)"
203+ )
204+ else :
205+ logger .debug ("Using cached actor token (no expiry)" )
206+ return self .actor_token_cache .token
207+ else :
208+ logger .debug ("Cached actor token expired, fetching new one" )
209+
210+ # Fetch new actor token
211+ try :
212+ if inspect .iscoroutinefunction (self .sts_integration .fetch_actor_token ):
213+ actor_token = await self .sts_integration .fetch_actor_token ()
214+ else :
215+ actor_token = self .sts_integration .fetch_actor_token ()
216+
217+ # Extract expiry and cache the token
218+ expiry = _extract_jwt_expiry (actor_token )
219+ self .actor_token_cache = _TokenCacheEntry (token = actor_token , expiry = expiry )
220+ logger .debug ("Fetched and cached new actor token" )
221+ return actor_token
222+
223+ except Exception as e :
224+ logger .warning (f"Failed to fetch actor token dynamically: { e } " )
225+ return None
226+
113227 @override
114228 async def after_run_callback (
115229 self ,
116230 * ,
117231 invocation_context : InvocationContext ,
118232 ) -> Optional [dict ]:
119- # delete token after run
120- self .token_cache .pop (self .cache_key (invocation_context ), None )
233+ """Clean up expired tokens after run, preserving valid tokens."""
234+ cache_key = self .cache_key (invocation_context )
235+ cache_entry = self .token_cache .get (cache_key )
236+
237+ # Clean up subject token cache - only remove if expired
238+ if cache_entry and _has_token_expired (cache_entry .expiry ):
239+ logger .debug ("Removing expired subject token from cache" )
240+ self .token_cache .pop (cache_key , None )
241+
242+ # Clean up expired actor token cache
243+ if self .actor_token_cache and _has_token_expired (self .actor_token_cache .expiry ):
244+ logger .debug ("Removing expired actor token from cache" )
245+ self .actor_token_cache = None
246+
121247 return None
122248
123249
250+ def _has_token_expired (expiry : Optional [int ], buffer_seconds : int = 5 ) -> bool :
251+ """Check if a token has expired or will expire soon.
252+
253+ Args:
254+ expiry: Token expiry timestamp (Unix epoch), or None if no expiry
255+ buffer_seconds: Additional buffer time in seconds to treat tokens
256+ expiring soon as already expired (default: 5)
257+
258+ Returns:
259+ True if token has expired or will expire within buffer_seconds,
260+ False if still valid or no expiry
261+ """
262+ if expiry is None :
263+ return False # No expiry means never expires
264+
265+ current_time = int (time .time ())
266+ return expiry <= (current_time + buffer_seconds )
267+
268+
124269def _extract_jwt_from_headers (headers : dict [str , str ]) -> Optional [str ]:
125270 """Extract JWT from request headers for STS token exchange.
126271
@@ -150,3 +295,31 @@ def _extract_jwt_from_headers(headers: dict[str, str]) -> Optional[str]:
150295
151296 logger .debug (f"Successfully extracted JWT token (length: { len (jwt_token )} )" )
152297 return jwt_token
298+
299+
300+ def _extract_jwt_expiry (token : str ) -> Optional [int ]:
301+ """Extract expiry timestamp from JWT token.
302+
303+ NOTE: This function does NOT validate the token signature.
304+ It is only used for cache management, not security decisions.
305+ Token validation happens in the STS server during exchange.
306+
307+ Args:
308+ token: JWT token string
309+
310+ Returns:
311+ Expiry timestamp (Unix epoch) if found, None otherwise
312+ """
313+ try :
314+ # Decode without verification (we only need the expiry claim)
315+ decoded = jwt .decode (token , options = {"verify_signature" : False })
316+ expiry = decoded .get ("exp" )
317+ if expiry :
318+ logger .debug (f"Extracted JWT expiry: { expiry } " )
319+ return int (expiry )
320+
321+ logger .debug ("No 'exp' claim found in JWT" )
322+ return None
323+ except Exception as e :
324+ logger .warning (f"Failed to extract JWT expiry: { e } " )
325+ return None
0 commit comments