Skip to content

Commit dcd0310

Browse files
committed
add secure annotation
1 parent c77dd2c commit dcd0310

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

105 files changed

+297
-13662
lines changed

src/mcp/client/auth.py

Lines changed: 95 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import base64
88
import hashlib
99
import logging
10+
import re
1011
import secrets
1112
import string
1213
import time
@@ -203,10 +204,39 @@ def __init__(
203204
)
204205
self._initialized = False
205206

206-
async def _discover_protected_resource(self) -> httpx.Request:
207-
"""Build discovery request for protected resource metadata."""
208-
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
209-
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")
207+
def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None:
208+
"""
209+
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
210+
211+
Returns:
212+
Resource metadata URL if found in WWW-Authenticate header, None otherwise
213+
"""
214+
if not init_response or init_response.status_code != 401:
215+
return None
216+
217+
www_auth_header = init_response.headers.get("WWW-Authenticate")
218+
if not www_auth_header:
219+
return None
220+
221+
# Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted)
222+
pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))'
223+
match = re.search(pattern, www_auth_header)
224+
225+
if match:
226+
# Return quoted value if present, otherwise unquoted value
227+
return match.group(1) or match.group(2)
228+
229+
return None
230+
231+
async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request:
232+
# RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response
233+
url = self._extract_resource_metadata_from_www_auth(init_response)
234+
235+
if not url:
236+
# Fallback to well-known discovery
237+
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
238+
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")
239+
210240
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
211241

212242
async def _handle_protected_resource_response(self, response: httpx.Response) -> None:
@@ -221,72 +251,32 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
221251
except ValidationError:
222252
pass
223253

224-
def _build_well_known_path(self, pathname: str) -> str:
225-
"""Construct well-known path for OAuth metadata discovery."""
226-
well_known_path = f"/.well-known/oauth-authorization-server{pathname}"
227-
if pathname.endswith("/"):
228-
# Strip trailing slash from pathname to avoid double slashes
229-
well_known_path = well_known_path[:-1]
230-
return well_known_path
231-
232-
def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool:
233-
"""Determine if fallback to root discovery should be attempted."""
234-
return response_status == 404 and pathname != "/"
235-
236-
async def _try_metadata_discovery(self, url: str) -> httpx.Request:
237-
"""Build metadata discovery request for a specific URL."""
238-
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
239-
240-
async def _discover_oauth_metadata(self) -> httpx.Request:
241-
"""Build OAuth metadata discovery request with fallback support."""
242-
if self.context.auth_server_url:
243-
auth_server_url = self.context.auth_server_url
244-
else:
245-
auth_server_url = self.context.server_url
246-
247-
# Per RFC 8414, try path-aware discovery first
254+
def _get_discovery_urls(self) -> list[str]:
255+
"""Generate ordered list of (url, type) tuples for discovery attempts."""
256+
urls: list[str] = []
257+
auth_server_url = self.context.auth_server_url or self.context.server_url
248258
parsed = urlparse(auth_server_url)
249-
well_known_path = self._build_well_known_path(parsed.path)
250259
base_url = f"{parsed.scheme}://{parsed.netloc}"
251-
url = urljoin(base_url, well_known_path)
252-
253-
# Store fallback info for use in response handler
254-
self.context.discovery_base_url = base_url
255-
self.context.discovery_pathname = parsed.path
256260

257-
return await self._try_metadata_discovery(url)
261+
# RFC 8414: Path-aware OAuth discovery
262+
if parsed.path and parsed.path != "/":
263+
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
264+
urls.append(urljoin(base_url, oauth_path))
258265

259-
async def _discover_oauth_metadata_fallback(self) -> httpx.Request:
260-
"""Build fallback OAuth metadata discovery request for legacy servers."""
261-
base_url = getattr(self.context, "discovery_base_url", "")
262-
if not base_url:
263-
raise OAuthFlowError("No base URL available for fallback discovery")
266+
# OAuth root fallback
267+
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
264268

265-
# Fallback to root discovery for legacy servers
266-
url = urljoin(base_url, "/.well-known/oauth-authorization-server")
267-
return await self._try_metadata_discovery(url)
268-
269-
async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fallback: bool = False) -> bool:
270-
"""Handle OAuth metadata response. Returns True if handled successfully."""
271-
if response.status_code == 200:
272-
try:
273-
content = await response.aread()
274-
metadata = OAuthMetadata.model_validate_json(content)
275-
self.context.oauth_metadata = metadata
276-
# Apply default scope if none specified
277-
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
278-
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)
279-
return True
280-
except ValidationError:
281-
pass
269+
# RFC 8414 section 5: Path-aware OIDC discovery
270+
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
271+
if parsed.path and parsed.path != "/":
272+
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
273+
urls.append(urljoin(base_url, oidc_path))
282274

283-
# Check if we should attempt fallback (404 on path-aware discovery)
284-
if not is_fallback and self._should_attempt_fallback(
285-
response.status_code, getattr(self.context, "discovery_pathname", "/")
286-
):
287-
return False # Signal that fallback should be attempted
275+
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
276+
oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration"
277+
urls.append(oidc_fallback)
288278

289-
return True # Signal no fallback needed (either success or non-404 error)
279+
return urls
290280

291281
async def _register_client(self) -> httpx.Request | None:
292282
"""Build registration request or skip if already registered."""
@@ -481,6 +471,17 @@ def _add_auth_header(self, request: httpx.Request) -> None:
481471
if self.context.current_tokens and self.context.current_tokens.access_token:
482472
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
483473

474+
def _create_oauth_metadata_request(self, url: str) -> httpx.Request:
475+
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
476+
477+
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
478+
content = await response.aread()
479+
metadata = OAuthMetadata.model_validate_json(content)
480+
self.context.oauth_metadata = metadata
481+
# Apply default scope if needed
482+
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
483+
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)
484+
484485
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
485486
"""HTTPX auth flow integration."""
486487
async with self.context.lock:
@@ -490,77 +491,43 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
490491
# Capture protocol version from request headers
491492
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
492493

493-
# Perform OAuth flow if not authenticated
494-
if not self.context.is_token_valid():
495-
try:
496-
# OAuth flow must be inline due to generator constraints
497-
# Step 1: Discover protected resource metadata (spec revision 2025-06-18)
498-
discovery_request = await self._discover_protected_resource()
499-
discovery_response = yield discovery_request
500-
await self._handle_protected_resource_response(discovery_response)
501-
502-
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
503-
oauth_request = await self._discover_oauth_metadata()
504-
oauth_response = yield oauth_request
505-
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)
506-
507-
# If path-aware discovery failed with 404, try fallback to root
508-
if not handled:
509-
fallback_request = await self._discover_oauth_metadata_fallback()
510-
fallback_response = yield fallback_request
511-
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
512-
513-
# Step 3: Register client if needed
514-
registration_request = await self._register_client()
515-
if registration_request:
516-
registration_response = yield registration_request
517-
await self._handle_registration_response(registration_response)
518-
519-
# Step 4: Perform authorization
520-
auth_code, code_verifier = await self._perform_authorization()
521-
522-
# Step 5: Exchange authorization code for tokens
523-
token_request = await self._exchange_token(auth_code, code_verifier)
524-
token_response = yield token_request
525-
await self._handle_token_response(token_response)
526-
except Exception:
527-
logger.exception("OAuth flow error")
528-
raise
529-
530-
# Add authorization header and make request
531-
self._add_auth_header(request)
532-
response = yield request
533-
534-
# Handle 401 responses
535-
if response.status_code == 401 and self.context.can_refresh_token():
494+
if not self.context.is_token_valid() and self.context.can_refresh_token():
536495
# Try to refresh token
537496
refresh_request = await self._refresh_token()
538497
refresh_response = yield refresh_request
539498

540-
if await self._handle_refresh_response(refresh_response):
541-
# Retry original request with new token
542-
self._add_auth_header(request)
543-
yield request
544-
else:
499+
if not await self._handle_refresh_response(refresh_response):
545500
# Refresh failed, need full re-authentication
546501
self._initialized = False
547502

503+
if self.context.is_token_valid():
504+
self._add_auth_header(request)
505+
506+
response = yield request
507+
508+
if response.status_code == 401:
509+
# Perform full OAuth flow
510+
try:
548511
# OAuth flow must be inline due to generator constraints
549-
# Step 1: Discover protected resource metadata (spec revision 2025-06-18)
550-
discovery_request = await self._discover_protected_resource()
512+
# Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support)
513+
discovery_request = await self._discover_protected_resource(response)
551514
discovery_response = yield discovery_request
552515
await self._handle_protected_resource_response(discovery_response)
553516

554517
# Step 2: Discover OAuth metadata (with fallback for legacy servers)
555-
oauth_request = await self._discover_oauth_metadata()
556-
oauth_response = yield oauth_request
557-
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)
558-
559-
# If path-aware discovery failed with 404, try fallback to root
560-
if not handled:
561-
fallback_request = await self._discover_oauth_metadata_fallback()
562-
fallback_response = yield fallback_request
563-
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
518+
discovery_urls = self._get_discovery_urls()
519+
for url in discovery_urls:
520+
oauth_metadata_request = self._create_oauth_metadata_request(url)
521+
oauth_metadata_response = yield oauth_metadata_request
522+
523+
if oauth_metadata_response.status_code == 200:
524+
try:
525+
await self._handle_oauth_metadata_response(oauth_metadata_response)
526+
break
527+
except ValidationError:
528+
continue
529+
elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500:
530+
break # Non-4XX error, stop trying
564531

565532
# Step 3: Register client if needed
566533
registration_request = await self._register_client()
@@ -575,7 +542,10 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
575542
token_request = await self._exchange_token(auth_code, code_verifier)
576543
token_response = yield token_request
577544
await self._handle_token_response(token_response)
545+
except Exception:
546+
logger.exception("OAuth flow error")
547+
raise
578548

579-
# Retry with new tokens
580-
self._add_auth_header(request)
581-
yield request
549+
# Retry with new tokens
550+
self._add_auth_header(request)
551+
yield request

src/mcp/client/stdio/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"HOMEPATH",
3333
"LOCALAPPDATA",
3434
"PATH",
35+
"PATHEXT",
3536
"PROCESSOR_ARCHITECTURE",
3637
"SYSTEMDRIVE",
3738
"SYSTEMROOT",

src/mcp/client/streamable_http.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
4444
LAST_EVENT_ID = "last-event-id"
4545
CONTENT_TYPE = "content-type"
46-
ACCEPT = "Accept"
46+
ACCEPT = "accept"
4747

4848

4949
JSON = "application/json"
@@ -248,6 +248,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:
248248
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
249249
)
250250
if is_complete:
251+
await event_source.response.aclose()
251252
break
252253

253254
async def _handle_post_request(self, ctx: RequestContext) -> None:
@@ -330,6 +331,7 @@ async def _handle_sse_response(
330331
# If the SSE event indicates completion, like returning respose/error
331332
# break the loop
332333
if is_complete:
334+
await response.aclose()
333335
break
334336
except Exception as e:
335337
logger.exception("Error reading SSE stream:")

0 commit comments

Comments
 (0)