Skip to content

Commit 9af1131

Browse files
Merge pull request #9 from keycardai/feat/starlette-middleware
Feat/starlette middleware
2 parents c74eeb3 + dbbf7f9 commit 9af1131

File tree

20 files changed

+1466
-66
lines changed

20 files changed

+1466
-66
lines changed

packages/mcp/pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ requires-python = ">=3.10"
77
license = { text = "MIT" }
88
authors = [{ name = "KeyCard AI", email = "support@keycard.ai" }]
99
dependencies = [
10-
"keycardai-oauth==0.3.0",
11-
"mcp==1.14.0",
10+
"keycardai-oauth>=0.4.0,<1.0.0",
11+
"mcp>=1.13.1",
1212
"pydantic>=2.11.7",
13+
"httpx>=0.27.2",
14+
"starlette>=0.47.3",
1315
]
1416
keywords = ["mcp", "model-context-protocol", "authentication", "authorization", "ai", "llm"]
1517
classifiers = [

packages/mcp/src/keycardai/mcp/server/auth/provider.py

Lines changed: 250 additions & 29 deletions
Large diffs are not rendered by default.

packages/mcp/src/keycardai/mcp/server/auth/verifier.py

Lines changed: 120 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
from typing import Any
33

44
from mcp.server.auth.provider import AccessToken
5+
from pydantic import AnyHttpUrl
56

7+
from keycardai.oauth import Client
68
from keycardai.oauth.utils.jwt import (
79
get_header,
8-
get_verification_key,
10+
get_jwks_key,
911
parse_jwt_access_token,
1012
)
1113

@@ -22,15 +24,19 @@ def __init__(
2224
jwks_uri: str | None = None,
2325
allowed_algorithms: list[str] = None,
2426
cache_ttl: int = 300, # 5 minutes default
27+
enable_multi_zone: bool = False,
2528
):
2629
"""Initialize the KeyCard token verifier.
2730
2831
Args:
29-
issuer: Expected token issuer (required)
32+
issuer: Expected token issuer (required). When enable_multi_zone=True,
33+
this should be the top-level domain URL that will be used as base
34+
for zone-specific issuer construction.
3035
required_scopes: Required scopes for token validation
31-
jwks_uri: JWKS endpoint URL for key fetching
36+
jwks_uri: JWKS endpoint URL for key fetching (deprecated, use issuer)
3237
allowed_algorithms: JWT algorithms (default RS256)
3338
cache_ttl: JWKS cache TTL in seconds (default 300 = 5 minutes)
39+
enable_multi_zone: Enable multi-zone support where issuer is top-level domain
3440
"""
3541
if not issuer:
3642
raise ValueError("Issuer is required for token verification")
@@ -43,30 +49,99 @@ def __init__(
4349
self.cache_ttl = cache_ttl
4450

4551
self._jwks_cache = JWKSCache(ttl=cache_ttl, max_size=10)
52+
self._discovered_jwks_uri: str | None = None
53+
self._discovered_jwks_uris: dict[str, str] = {} # Initialize the cache dict
4654

47-
async def _get_verification_key(self, token: str) -> JWKSKey:
48-
"""Get the verification key for the token with caching."""
49-
if not self.jwks_uri:
50-
raise ValueError("JWKS URI not configured")
55+
self.enable_multi_zone = enable_multi_zone
56+
57+
def _discover_jwks_uri(self, zone_id: str | None = None) -> str:
58+
"""Discover JWKS URI from issuer lazily.
59+
60+
Args:
61+
zone_id: Zone ID for multi-zone scenarios. When provided with enable_multi_zone=True,
62+
constructs zone-specific issuer URL for discovery.
63+
"""
64+
cache_key = f"{zone_id or 'default'}"
65+
cached_uri = self._discovered_jwks_uris.get(cache_key)
66+
if cached_uri is not None:
67+
return cached_uri
68+
69+
if self.jwks_uri:
70+
self._discovered_jwks_uris[cache_key] = self.jwks_uri
71+
return self.jwks_uri
72+
73+
discovery_issuer = self.issuer
74+
if self.enable_multi_zone and zone_id:
75+
discovery_issuer = self._create_zone_scoped_url(self.issuer, zone_id)
76+
77+
try:
78+
with Client(discovery_issuer) as client:
79+
server_metadata = client.discover_server_metadata()
80+
discovered_uri = server_metadata.jwks_uri
5181

82+
if not discovered_uri:
83+
raise ValueError(f"Could not discover JWKS URI from issuer: {discovery_issuer}")
84+
85+
# Cache the successful discovery
86+
self._discovered_jwks_uris[cache_key] = discovered_uri
87+
return discovered_uri
88+
89+
except Exception as e:
90+
# Don't cache failures, let them retry
91+
raise ValueError(f"Could not discover JWKS URI from issuer {discovery_issuer}: {e}") from e
92+
93+
def _create_zone_scoped_url(self, base_url: str, zone_id: str) -> str:
94+
"""Create zone-scoped URL by prepending zone_id to the host."""
95+
base_url_obj = AnyHttpUrl(base_url)
96+
97+
port_part = ""
98+
if base_url_obj.port and not (
99+
(base_url_obj.scheme == "https" and base_url_obj.port == 443) or
100+
(base_url_obj.scheme == "http" and base_url_obj.port == 80)
101+
):
102+
port_part = f":{base_url_obj.port}"
103+
104+
zone_url = f"{base_url_obj.scheme}://{zone_id}.{base_url_obj.host}{port_part}"
105+
return zone_url
106+
107+
def _get_kid_and_algorithm(self, token: str) -> tuple[str, str]:
52108
header = get_header(token)
53109
kid = header.get("kid")
54110
algorithm = header.get("alg")
55111
if algorithm not in self.allowed_algorithms:
56112
raise ValueError(f"Unsupported algorithm: {algorithm}")
113+
return [kid, algorithm]
114+
115+
def _get_zone_jwks_uri(self, jwks_uri: str, zone_id: str) -> str:
116+
jwks_url = AnyHttpUrl(jwks_uri)
117+
jwks_zone_host = jwks_url.host.replace(jwks_url.host, f"{zone_id}.{jwks_url.host}")
118+
jwks_url.host = jwks_zone_host
119+
return jwks_url.to_string()
120+
121+
async def _get_verification_key(self, token: str, zone_id: str | None = None) -> JWKSKey:
122+
"""Get the verification key for the token with caching."""
123+
kid, algorithm = self._get_kid_and_algorithm(token)
57124

58125
cached_key = self._jwks_cache.get_key(kid)
59126
if cached_key is not None:
60127
return cached_key
61128

62-
verification_key = await get_verification_key(token, self.jwks_uri)
129+
if self.enable_multi_zone and zone_id:
130+
jwks_uri = self._discover_jwks_uri(zone_id)
131+
else:
132+
jwks_uri = self._discover_jwks_uri()
133+
if zone_id:
134+
jwks_uri = self._get_zone_jwks_uri(jwks_uri, zone_id)
135+
136+
verification_key = await get_jwks_key(kid, jwks_uri)
63137

64138
self._jwks_cache.set_key(kid, verification_key, algorithm)
65139
cached_key = self._jwks_cache.get_key(kid)
66140
if cached_key is None:
67141
raise ValueError("Failed to cache verification key")
68142
return cached_key
69143

144+
70145
def clear_cache(self) -> None:
71146
"""Clear the JWKS key cache."""
72147
self._jwks_cache.clear()
@@ -79,37 +154,25 @@ def get_cache_stats(self) -> dict[str, Any]:
79154
"""
80155
return self._jwks_cache.get_stats()
81156

82-
async def verify_token(self, token: str) -> AccessToken | None:
83-
"""Verify a JWT token and return AccessToken if valid.
84-
85-
Performs JWT verification including:
86-
- Parse token into structured JWTAccessToken model internally
87-
- Validate token expiration
88-
- Validate issuer if configured
89-
- Validate required scopes if configured
90-
- Convert to AccessToken format for return
91-
92-
Note: This is a simplified implementation that does not perform
93-
cryptographic signature verification. For production use, proper
94-
signature verification should be implemented.
95-
96-
Args:
97-
token: JWT token string to verify
98-
99-
Returns:
100-
AccessToken object if valid, None if invalid
101-
"""
102-
try:
103-
verification_key = await self._get_verification_key(token)
157+
async def verify_token_for_zone(self, token: str, zone_id: str) -> AccessToken | None:
158+
"""Verify a JWT token for a specific zone and return AccessToken if valid."""
159+
key = await self._get_verification_key(token, zone_id)
160+
return self._verify_token(token, key, zone_id)
104161

162+
def _verify_token(self, token: str, key: JWKSKey, zone_id: str | None = None) -> AccessToken | None:
105163
jwt_access_token = parse_jwt_access_token(
106-
token, verification_key.key, verification_key.algorithm
164+
token, key.key, key.algorithm
107165
)
108166

109167
if jwt_access_token.exp < time.time():
110168
return None
111169

112-
if jwt_access_token.iss != self.issuer:
170+
# Validate issuer, handling multi-zone scenarios
171+
expected_issuer = self.issuer
172+
if self.enable_multi_zone and zone_id:
173+
expected_issuer = self._create_zone_scoped_url(self.issuer, zone_id)
174+
175+
if jwt_access_token.iss != expected_issuer:
113176
return None
114177

115178
if self.required_scopes:
@@ -133,6 +196,31 @@ async def verify_token(self, token: str) -> AccessToken | None:
133196
resource=jwt_access_token.get_custom_claim("resource"),
134197
)
135198

199+
200+
async def verify_token(self, token: str) -> AccessToken | None:
201+
"""Verify a JWT token and return AccessToken if valid.
202+
203+
Performs JWT verification including:
204+
- Parse token into structured JWTAccessToken model internally
205+
- Validate token expiration
206+
- Validate issuer if configured
207+
- Validate required scopes if configured
208+
- Convert to AccessToken format for return
209+
210+
Note: This is a simplified implementation that does not perform
211+
cryptographic signature verification. For production use, proper
212+
signature verification should be implemented.
213+
214+
Args:
215+
token: JWT token string to verify
216+
217+
Returns:
218+
AccessToken object if valid, None if invalid
219+
"""
220+
try:
221+
key = await self._get_verification_key(token)
222+
return self._verify_token(token, key)
223+
136224
except Exception:
137225
return None
138226

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from .metadata import (
2+
InferredProtectedResourceMetadata,
3+
authorization_server_metadata,
4+
protected_resource_metadata,
5+
)
6+
7+
__all__ = [
8+
"protected_resource_metadata",
9+
"authorization_server_metadata",
10+
"InferredProtectedResourceMetadata",
11+
]
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import json
2+
from collections.abc import Callable
3+
from dataclasses import dataclass
4+
5+
import httpx
6+
from mcp.shared.auth import ProtectedResourceMetadata
7+
from pydantic import AnyHttpUrl, Field
8+
from starlette.requests import Request
9+
from starlette.responses import Response
10+
11+
12+
class InferredProtectedResourceMetadata(ProtectedResourceMetadata):
13+
"""Extended ProtectedResourceMetadata that allows resource to be inferred from request."""
14+
resource: AnyHttpUrl | None = Field(default=None) # Override to make it optional
15+
16+
@dataclass
17+
class AuthorizationServerMetadata:
18+
base_url: str
19+
20+
21+
def _is_authorization_server_zone_scoped(authorization_server_urls: AnyHttpUrl) -> bool:
22+
if len(authorization_server_urls) != 1:
23+
return False
24+
return len(authorization_server_urls[0].host.split(".")) == 3
25+
26+
def _get_zone_id_from_path(path: str) -> str | None:
27+
path = path.lstrip("/").rstrip("/")
28+
zone_id = path.split("/")[0]
29+
if zone_id == "" or zone_id == "/":
30+
return None
31+
return zone_id
32+
33+
def _remove_well_known_prefix(path: str) -> str:
34+
prefix = ".well-known/oauth-protected-resource"
35+
path = path.lstrip("/").rstrip("/")
36+
if path.startswith(prefix):
37+
return path[len(prefix):]
38+
return path
39+
40+
def _create_zone_scoped_authorization_server_url(zone_id: str, authorization_server_url: AnyHttpUrl) -> AnyHttpUrl:
41+
port_part = f":{authorization_server_url.port}" if authorization_server_url.port else ""
42+
url = f"{authorization_server_url.scheme}://{zone_id}.{authorization_server_url.host}{port_part}"
43+
return AnyHttpUrl(url)
44+
45+
def _strip_zone_id_from_path(zone_id: str, path: str) -> str:
46+
path = path.lstrip("/").rstrip("/")
47+
if path.startswith(zone_id):
48+
return path[len(zone_id):]
49+
return path
50+
51+
def _create_resource_url(base_url: str | AnyHttpUrl, path: str) -> AnyHttpUrl:
52+
base_url_str = str(base_url).rstrip("/")
53+
if path and not path.startswith("/"):
54+
path = "/" + path
55+
url = f"{base_url_str}{path}".rstrip("/")
56+
if url.endswith("://") or (path == "/" and not url.endswith("/")):
57+
url += "/"
58+
return AnyHttpUrl(url)
59+
60+
def _remove_authorization_server_prefix(path: str) -> str:
61+
"""Remove the /.well-known/oauth-authorization-server prefix from the path."""
62+
auth_server_prefix = "/.well-known/oauth-authorization-server"
63+
if path.startswith(auth_server_prefix):
64+
return path[len(auth_server_prefix):]
65+
return path
66+
67+
def protected_resource_metadata(metadata: InferredProtectedResourceMetadata, enable_multi_zone: bool = False) -> Callable:
68+
def wrapper(request: Request) -> Response:
69+
# Create a copy of the metadata to avoid mutating the original
70+
request_metadata = metadata.model_copy(deep=True)
71+
path = _remove_well_known_prefix(request.url.path)
72+
if enable_multi_zone or not _is_authorization_server_zone_scoped(request_metadata.authorization_servers):
73+
zone_id = _get_zone_id_from_path(path)
74+
if zone_id:
75+
request_metadata.authorization_servers = [ _create_zone_scoped_authorization_server_url(zone_id, request_metadata.authorization_servers[0]) ]
76+
77+
resource = _create_resource_url(request.base_url, path)
78+
mcp_version = request.headers.get("mcp-protocol-version")
79+
request_metadata.resource = resource
80+
# TODO: what is the reason for this?
81+
if mcp_version == "2025-03-26":
82+
json["authorization_servers"] = [ request.base_url ]
83+
return Response(content=request_metadata.model_dump_json(exclude_none=True), status_code=200)
84+
return wrapper
85+
86+
def authorization_server_metadata(issuer: str, enable_multi_zone: bool = False) -> Callable:
87+
def wrapper(request: Request) -> Response:
88+
try:
89+
actual_issuer = issuer
90+
path = _remove_authorization_server_prefix(request.url.path)
91+
92+
if enable_multi_zone or not _is_authorization_server_zone_scoped([AnyHttpUrl(issuer)]):
93+
zone_id = _get_zone_id_from_path(path)
94+
if zone_id:
95+
actual_issuer = str(_create_zone_scoped_authorization_server_url(zone_id, AnyHttpUrl(issuer)))
96+
97+
with httpx.Client() as client:
98+
resp = client.get(f"{actual_issuer}/.well-known/oauth-authorization-server")
99+
resp.raise_for_status()
100+
authorization_server_metadata = resp.json()
101+
authorization_server_metadata["authorization_endpoint"] = f"{request.base_url}{authorization_server_metadata['authorization_endpoint']}"
102+
return Response(content=json.dumps(authorization_server_metadata), status_code=200)
103+
except Exception as e:
104+
error_message = {"error": str(e), "type": type(e).__name__}
105+
return Response(content=json.dumps(error_message), status_code=500)
106+
return wrapper
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .bearer import BearerAuthMiddleware
2+
3+
__all__ = [
4+
"BearerAuthMiddleware",
5+
]

0 commit comments

Comments
 (0)