Skip to content
Merged
109 changes: 52 additions & 57 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import base64
import hashlib
import logging
import re
import secrets
import string
import time
Expand Down Expand Up @@ -203,10 +204,39 @@ def __init__(
)
self._initialized = False

async def _discover_protected_resource(self) -> httpx.Request:
"""Build discovery request for protected resource metadata."""
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")
def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None:
"""
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.

Returns:
Resource metadata URL if found in WWW-Authenticate header, None otherwise
"""
if not init_response or init_response.status_code != 401:
return None

www_auth_header = init_response.headers.get("WWW-Authenticate")
if not www_auth_header:
return None

# Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted)
pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))'
match = re.search(pattern, www_auth_header)

if match:
# Return quoted value if present, otherwise unquoted value
return match.group(1) or match.group(2)

return None

async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request:
# RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response
url = self._extract_resource_metadata_from_www_auth(init_response)

if not url:
# Fallback to well-known discovery
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource")

return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})

async def _handle_protected_resource_response(self, response: httpx.Response) -> None:
Expand Down Expand Up @@ -490,64 +520,26 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Capture protocol version from request headers
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)

# Perform OAuth flow if not authenticated
if not self.context.is_token_valid():
try:
# OAuth flow must be inline due to generator constraints
# Step 1: Discover protected resource metadata (spec revision 2025-06-18)
discovery_request = await self._discover_protected_resource()
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
oauth_request = await self._discover_oauth_metadata()
oauth_response = yield oauth_request
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)

# If path-aware discovery failed with 404, try fallback to root
if not handled:
fallback_request = await self._discover_oauth_metadata_fallback()
fallback_response = yield fallback_request
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)

# Step 3: Register client if needed
registration_request = await self._register_client()
if registration_request:
registration_response = yield registration_request
await self._handle_registration_response(registration_response)

# Step 4: Perform authorization
auth_code, code_verifier = await self._perform_authorization()

# Step 5: Exchange authorization code for tokens
token_request = await self._exchange_token(auth_code, code_verifier)
token_response = yield token_request
await self._handle_token_response(token_response)
except Exception:
logger.exception("OAuth flow error")
raise

# Add authorization header and make request
self._add_auth_header(request)
response = yield request

# Handle 401 responses
if response.status_code == 401 and self.context.can_refresh_token():
if not self.context.is_token_valid() and self.context.can_refresh_token():
# Try to refresh token
refresh_request = await self._refresh_token()
refresh_response = yield refresh_request

if await self._handle_refresh_response(refresh_response):
# Retry original request with new token
self._add_auth_header(request)
yield request
else:
if not await self._handle_refresh_response(refresh_response):
# Refresh failed, need full re-authentication
self._initialized = False

if self.context.is_token_valid():
self._add_auth_header(request)

response = yield request

if response.status_code == 401:
# Perform full OAuth flow
try:
# OAuth flow must be inline due to generator constraints
# Step 1: Discover protected resource metadata (spec revision 2025-06-18)
discovery_request = await self._discover_protected_resource()
# Step 1: Discover protected resource metadata (RFC9728 with WWW-Authenticate support)
discovery_request = await self._discover_protected_resource(response)
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)

Expand Down Expand Up @@ -575,7 +567,10 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
token_request = await self._exchange_token(auth_code, code_verifier)
token_response = yield token_request
await self._handle_token_response(token_response)
except Exception:
logger.exception("OAuth flow error")
raise

# Retry with new tokens
self._add_auth_header(request)
yield request
# Retry with new tokens
self._add_auth_header(request)
yield request
146 changes: 143 additions & 3 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,14 +197,43 @@ class TestOAuthFlow:
"""Test OAuth flow methods."""

@pytest.mark.anyio
async def test_discover_protected_resource_request(self, oauth_provider):
"""Test protected resource discovery request building."""
request = await oauth_provider._discover_protected_resource()
async def test_discover_protected_resource_request(self, client_metadata, mock_storage):
"""Test protected resource discovery request building maintains backward compatibility."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

# Test without WWW-Authenticate (fallback)
init_response = httpx.Response(
status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com")
)

request = await provider._discover_protected_resource(init_response)
assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
assert "mcp-protocol-version" in request.headers

# Test with WWW-Authenticate header
init_response.headers["WWW-Authenticate"] = (
'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"'
)

request = await provider._discover_protected_resource(init_response)
assert request.method == "GET"
assert str(request.url) == "https://prm.example.com/.well-known/oauth-protected-resource/path"
assert "mcp-protocol-version" in request.headers

@pytest.mark.anyio
async def test_discover_oauth_metadata_request(self, oauth_provider):
"""Test OAuth metadata discovery request building."""
Expand Down Expand Up @@ -660,3 +689,114 @@ def test_build_metadata(
"code_challenge_methods_supported": ["S256"],
}
)


class TestProtectedResourceWWWAuthenticate:
"""Test RFC9728 WWW-Authenticate header parsing functionality for protected resource."""

@pytest.mark.parametrize(
"www_auth_header,expected_url",
[
# Quoted URL
(
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"',
"https://api.example.com/.well-known/oauth-protected-resource",
),
# Unquoted URL
(
"Bearer resource_metadata=https://api.example.com/.well-known/oauth-protected-resource",
"https://api.example.com/.well-known/oauth-protected-resource",
),
# Complex header with multiple parameters
(
'Bearer realm="api", resource_metadata="https://api.example.com/.well-known/oauth-protected-resource", '
'error="insufficient_scope"',
"https://api.example.com/.well-known/oauth-protected-resource",
),
# Different URL format
('Bearer resource_metadata="https://custom.domain.com/metadata"', "https://custom.domain.com/metadata"),
# With path and query params
(
'Bearer resource_metadata="https://api.example.com/auth/metadata?version=1"',
"https://api.example.com/auth/metadata?version=1",
),
],
)
def test_extract_resource_metadata_from_www_auth_valid_cases(
self, client_metadata, mock_storage, www_auth_header, expected_url
):
"""Test extraction of resource_metadata URL from various valid WWW-Authenticate headers."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

init_response = httpx.Response(
status_code=401,
headers={"WWW-Authenticate": www_auth_header},
request=httpx.Request("GET", "https://api.example.com/test"),
)

result = provider._extract_resource_metadata_from_www_auth(init_response)
assert result == expected_url

@pytest.mark.parametrize(
"status_code,www_auth_header,description",
[
# No header
(401, None, "no WWW-Authenticate header"),
# Empty header
(401, "", "empty WWW-Authenticate header"),
# Header without resource_metadata
(401, 'Bearer realm="api", error="insufficient_scope"', "no resource_metadata parameter"),
# Malformed header
(401, "Bearer resource_metadata=", "malformed resource_metadata parameter"),
# Non-401 status code
(
200,
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"',
"200 OK response",
),
(
500,
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"',
"500 error response",
),
],
)
def test_extract_resource_metadata_from_www_auth_invalid_cases(
self, client_metadata, mock_storage, status_code, www_auth_header, description
):
"""Test extraction returns None for invalid cases."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {}
init_response = httpx.Response(
status_code=status_code, headers=headers, request=httpx.Request("GET", "https://api.example.com/test")
)

result = provider._extract_resource_metadata_from_www_auth(init_response)
assert result is None, f"Should return None for {description}"
Loading