Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
Expand Down
96 changes: 81 additions & 15 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,22 +204,19 @@ def __init__(
)
self._initialized = False

def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None:
def _extract_field_from_www_auth(self, init_response: httpx.Response, field_name: str) -> str | None:
"""
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
Extract field from WWW-Authenticate header.

Returns:
Resource metadata URL if found in WWW-Authenticate header, None otherwise
Field value if found in WWW-Authenticate header, None otherwise
"""
if not init_response or init_response.status_code != 401:
return None

www_auth_header = init_response.headers.get("WWW-Authenticate")
if not www_auth_header:
return None

# Pattern matches: resource_metadata="url" or resource_metadata=url (unquoted)
pattern = r'resource_metadata=(?:"([^"]+)"|([^\s,]+))'
# Pattern matches: field_name="value" or field_name=value (unquoted)
pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))'
match = re.search(pattern, www_auth_header)

if match:
Expand All @@ -228,6 +225,27 @@ def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response

return None

def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None:
"""
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.

Returns:
Resource metadata URL if found in WWW-Authenticate header, None otherwise
"""
if not init_response or init_response.status_code != 401:
return None

return self._extract_field_from_www_auth(init_response, "resource_metadata")

def _extract_scope_from_www_auth(self, init_response: httpx.Response) -> str | None:
"""
Extract scope parameter from WWW-Authenticate header as per RFC6750.

Returns:
Scope string if found in WWW-Authenticate header, None otherwise
"""
return self._extract_field_from_www_auth(init_response, "scope")

async def _discover_protected_resource(self, init_response: httpx.Response) -> httpx.Request:
# RFC9728: Try to extract resource_metadata URL from WWW-Authenticate header of the initial response
url = self._extract_resource_metadata_from_www_auth(init_response)
Expand All @@ -248,8 +266,32 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
self.context.protected_resource_metadata = metadata
if metadata.authorization_servers:
self.context.auth_server_url = str(metadata.authorization_servers[0])

except ValidationError:
pass
else:
raise OAuthFlowError(f"Protected Resource Metadata request failed: {response.status_code}")

def _select_scopes(self, init_response: httpx.Response) -> None:
"""Select scopes as outlined in the 'Scope Selection Strategy in the MCP spec."""
# Per MCP spec, scope selection priority order:
# 1. Use scope from WWW-Authenticate header (if provided)
# 2. Use all scopes from PRM scopes_supported (if available)
# 3. Omit scope parameter if neither is available
#
www_authenticate_scope = self._extract_scope_from_www_auth(init_response)
if www_authenticate_scope is not None:
# Priority 1: WWW-Authenticate header scope
self.context.client_metadata.scope = www_authenticate_scope
elif (
self.context.protected_resource_metadata is not None
and self.context.protected_resource_metadata.scopes_supported is not None
):
# Priority 2: PRM scopes_supported
self.context.client_metadata.scope = " ".join(self.context.protected_resource_metadata.scopes_supported)
else:
# Priority 3: Omit scope parameter
self.context.client_metadata.scope = None

def _get_discovery_urls(self) -> list[str]:
"""Generate ordered list of (url, type) tuples for discovery attempts."""
Expand Down Expand Up @@ -478,9 +520,6 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response) -> Non
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
self.context.oauth_metadata = metadata
# Apply default scope if needed
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)

async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
Expand Down Expand Up @@ -514,7 +553,10 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
discovery_response = yield discovery_request
await self._handle_protected_resource_response(discovery_response)

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
# Step 2: Apply scope selection strategy
self._select_scopes(response)

# Step 3: Discover OAuth metadata (with fallback for legacy servers)
discovery_urls = self._get_discovery_urls()
for url in discovery_urls:
oauth_metadata_request = self._create_oauth_metadata_request(url)
Expand All @@ -529,16 +571,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500:
break # Non-4XX error, stop trying

# Step 3: Register client if needed
# Step 4: Register client if needed
registration_request = await self._register_client()
if registration_request:
registration_response = yield registration_request
await self._handle_registration_response(registration_response)

# Step 4: Perform authorization
# Step 5: Perform authorization
auth_code, code_verifier = await self._perform_authorization()

# Step 5: Exchange authorization code for tokens
# Step 6: Exchange authorization code for tokens
token_request = await self._exchange_token(auth_code, code_verifier)
token_response = yield token_request
await self._handle_token_response(token_response)
Expand All @@ -549,3 +591,27 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
# Retry with new tokens
self._add_auth_header(request)
yield request
elif response.status_code == 403:
# Step 1: Extract error field from WWW-Authenticate header
error = self._extract_field_from_www_auth(response, "error")

# Step 2: Check if we need to step-up authorization
if error == "insufficient_scope":
try:
# Step 2a: Update the required scopes
self._select_scopes(response)

# Step 2b: Perform (re-)authorization
auth_code, code_verifier = await self._perform_authorization()

# Step 2c: Exchange authorization code for tokens
token_request = await self._exchange_token(auth_code, code_verifier)
token_response = yield token_request
await self._handle_token_response(token_response)
except Exception:
logger.exception("OAuth flow error")
raise

# Retry with new tokens
self._add_auth_header(request)
yield request
Loading
Loading