22from typing import Any
33
44from mcp .server .auth .provider import AccessToken
5+ from pydantic import AnyHttpUrl
56
7+ from keycardai .oauth import Client
68from keycardai .oauth .utils .jwt import (
79 get_header ,
8- get_verification_key ,
10+ get_jwks_key ,
911 parse_jwt_access_token ,
1012)
1113
@@ -22,15 +24,19 @@ def __init__(
2224 jwks_uri : str | None = None ,
2325 allowed_algorithms : list [str ] = None ,
2426 cache_ttl : int = 300 , # 5 minutes default
27+ enable_multi_zone : bool = False ,
2528 ):
2629 """Initialize the KeyCard token verifier.
2730
2831 Args:
29- issuer: Expected token issuer (required)
32+ issuer: Expected token issuer (required). When enable_multi_zone=True,
33+ this should be the top-level domain URL that will be used as base
34+ for zone-specific issuer construction.
3035 required_scopes: Required scopes for token validation
31- jwks_uri: JWKS endpoint URL for key fetching
36+ jwks_uri: JWKS endpoint URL for key fetching (deprecated, use issuer)
3237 allowed_algorithms: JWT algorithms (default RS256)
3338 cache_ttl: JWKS cache TTL in seconds (default 300 = 5 minutes)
39+ enable_multi_zone: Enable multi-zone support where issuer is top-level domain
3440 """
3541 if not issuer :
3642 raise ValueError ("Issuer is required for token verification" )
@@ -43,30 +49,99 @@ def __init__(
4349 self .cache_ttl = cache_ttl
4450
4551 self ._jwks_cache = JWKSCache (ttl = cache_ttl , max_size = 10 )
52+ self ._discovered_jwks_uri : str | None = None
53+ self ._discovered_jwks_uris : dict [str , str ] = {} # Initialize the cache dict
4654
47- async def _get_verification_key (self , token : str ) -> JWKSKey :
48- """Get the verification key for the token with caching."""
49- if not self .jwks_uri :
50- raise ValueError ("JWKS URI not configured" )
55+ self .enable_multi_zone = enable_multi_zone
56+
57+ def _discover_jwks_uri (self , zone_id : str | None = None ) -> str :
58+ """Discover JWKS URI from issuer lazily.
59+
60+ Args:
61+ zone_id: Zone ID for multi-zone scenarios. When provided with enable_multi_zone=True,
62+ constructs zone-specific issuer URL for discovery.
63+ """
64+ cache_key = f"{ zone_id or 'default' } "
65+ cached_uri = self ._discovered_jwks_uris .get (cache_key )
66+ if cached_uri is not None :
67+ return cached_uri
68+
69+ if self .jwks_uri :
70+ self ._discovered_jwks_uris [cache_key ] = self .jwks_uri
71+ return self .jwks_uri
72+
73+ discovery_issuer = self .issuer
74+ if self .enable_multi_zone and zone_id :
75+ discovery_issuer = self ._create_zone_scoped_url (self .issuer , zone_id )
76+
77+ try :
78+ with Client (discovery_issuer ) as client :
79+ server_metadata = client .discover_server_metadata ()
80+ discovered_uri = server_metadata .jwks_uri
5181
82+ if not discovered_uri :
83+ raise ValueError (f"Could not discover JWKS URI from issuer: { discovery_issuer } " )
84+
85+ # Cache the successful discovery
86+ self ._discovered_jwks_uris [cache_key ] = discovered_uri
87+ return discovered_uri
88+
89+ except Exception as e :
90+ # Don't cache failures, let them retry
91+ raise ValueError (f"Could not discover JWKS URI from issuer { discovery_issuer } : { e } " ) from e
92+
93+ def _create_zone_scoped_url (self , base_url : str , zone_id : str ) -> str :
94+ """Create zone-scoped URL by prepending zone_id to the host."""
95+ base_url_obj = AnyHttpUrl (base_url )
96+
97+ port_part = ""
98+ if base_url_obj .port and not (
99+ (base_url_obj .scheme == "https" and base_url_obj .port == 443 ) or
100+ (base_url_obj .scheme == "http" and base_url_obj .port == 80 )
101+ ):
102+ port_part = f":{ base_url_obj .port } "
103+
104+ zone_url = f"{ base_url_obj .scheme } ://{ zone_id } .{ base_url_obj .host } { port_part } "
105+ return zone_url
106+
107+ def _get_kid_and_algorithm (self , token : str ) -> tuple [str , str ]:
52108 header = get_header (token )
53109 kid = header .get ("kid" )
54110 algorithm = header .get ("alg" )
55111 if algorithm not in self .allowed_algorithms :
56112 raise ValueError (f"Unsupported algorithm: { algorithm } " )
113+ return [kid , algorithm ]
114+
115+ def _get_zone_jwks_uri (self , jwks_uri : str , zone_id : str ) -> str :
116+ jwks_url = AnyHttpUrl (jwks_uri )
117+ jwks_zone_host = jwks_url .host .replace (jwks_url .host , f"{ zone_id } .{ jwks_url .host } " )
118+ jwks_url .host = jwks_zone_host
119+ return jwks_url .to_string ()
120+
121+ async def _get_verification_key (self , token : str , zone_id : str | None = None ) -> JWKSKey :
122+ """Get the verification key for the token with caching."""
123+ kid , algorithm = self ._get_kid_and_algorithm (token )
57124
58125 cached_key = self ._jwks_cache .get_key (kid )
59126 if cached_key is not None :
60127 return cached_key
61128
62- verification_key = await get_verification_key (token , self .jwks_uri )
129+ if self .enable_multi_zone and zone_id :
130+ jwks_uri = self ._discover_jwks_uri (zone_id )
131+ else :
132+ jwks_uri = self ._discover_jwks_uri ()
133+ if zone_id :
134+ jwks_uri = self ._get_zone_jwks_uri (jwks_uri , zone_id )
135+
136+ verification_key = await get_jwks_key (kid , jwks_uri )
63137
64138 self ._jwks_cache .set_key (kid , verification_key , algorithm )
65139 cached_key = self ._jwks_cache .get_key (kid )
66140 if cached_key is None :
67141 raise ValueError ("Failed to cache verification key" )
68142 return cached_key
69143
144+
70145 def clear_cache (self ) -> None :
71146 """Clear the JWKS key cache."""
72147 self ._jwks_cache .clear ()
@@ -79,37 +154,25 @@ def get_cache_stats(self) -> dict[str, Any]:
79154 """
80155 return self ._jwks_cache .get_stats ()
81156
82- async def verify_token (self , token : str ) -> AccessToken | None :
83- """Verify a JWT token and return AccessToken if valid.
84-
85- Performs JWT verification including:
86- - Parse token into structured JWTAccessToken model internally
87- - Validate token expiration
88- - Validate issuer if configured
89- - Validate required scopes if configured
90- - Convert to AccessToken format for return
91-
92- Note: This is a simplified implementation that does not perform
93- cryptographic signature verification. For production use, proper
94- signature verification should be implemented.
95-
96- Args:
97- token: JWT token string to verify
98-
99- Returns:
100- AccessToken object if valid, None if invalid
101- """
102- try :
103- verification_key = await self ._get_verification_key (token )
157+ async def verify_token_for_zone (self , token : str , zone_id : str ) -> AccessToken | None :
158+ """Verify a JWT token for a specific zone and return AccessToken if valid."""
159+ key = await self ._get_verification_key (token , zone_id )
160+ return self ._verify_token (token , key , zone_id )
104161
162+ def _verify_token (self , token : str , key : JWKSKey , zone_id : str | None = None ) -> AccessToken | None :
105163 jwt_access_token = parse_jwt_access_token (
106- token , verification_key .key , verification_key .algorithm
164+ token , key .key , key .algorithm
107165 )
108166
109167 if jwt_access_token .exp < time .time ():
110168 return None
111169
112- if jwt_access_token .iss != self .issuer :
170+ # Validate issuer, handling multi-zone scenarios
171+ expected_issuer = self .issuer
172+ if self .enable_multi_zone and zone_id :
173+ expected_issuer = self ._create_zone_scoped_url (self .issuer , zone_id )
174+
175+ if jwt_access_token .iss != expected_issuer :
113176 return None
114177
115178 if self .required_scopes :
@@ -133,6 +196,31 @@ async def verify_token(self, token: str) -> AccessToken | None:
133196 resource = jwt_access_token .get_custom_claim ("resource" ),
134197 )
135198
199+
200+ async def verify_token (self , token : str ) -> AccessToken | None :
201+ """Verify a JWT token and return AccessToken if valid.
202+
203+ Performs JWT verification including:
204+ - Parse token into structured JWTAccessToken model internally
205+ - Validate token expiration
206+ - Validate issuer if configured
207+ - Validate required scopes if configured
208+ - Convert to AccessToken format for return
209+
210+ Note: This is a simplified implementation that does not perform
211+ cryptographic signature verification. For production use, proper
212+ signature verification should be implemented.
213+
214+ Args:
215+ token: JWT token string to verify
216+
217+ Returns:
218+ AccessToken object if valid, None if invalid
219+ """
220+ try :
221+ key = await self ._get_verification_key (token )
222+ return self ._verify_token (token , key )
223+
136224 except Exception :
137225 return None
138226
0 commit comments