Skip to content
131 changes: 98 additions & 33 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ def should_include_resource_param(self, protocol_version: str | None = None) ->
return protocol_version >= "2025-06-18"


OAuthDiscoveryStack = list[Callable[[], Awaitable[httpx.Request]]]


class OAuthClientProvider(httpx.Auth):
"""
OAuth2 authentication for httpx.
Expand Down Expand Up @@ -251,32 +254,60 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
except ValidationError:
pass

def _build_well_known_path(self, pathname: str) -> str:
def _build_well_known_path(self, pathname: str, well_known_endpoint: str) -> str:
"""Construct well-known path for OAuth metadata discovery."""
well_known_path = f"/.well-known/oauth-authorization-server{pathname}"
well_known_path = f"/.well-known/{well_known_endpoint}{pathname}"
if pathname.endswith("/"):
# Strip trailing slash from pathname to avoid double slashes
well_known_path = well_known_path[:-1]
return well_known_path

def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool:
"""Determine if fallback to root discovery should be attempted."""
return response_status == 404 and pathname != "/"
def _build_well_known_fallback_url(self, well_known_endpoint: str) -> str:
"""Construct fallback well-known URL for OAuth metadata discovery in legacy servers."""
base_url = getattr(self.context, "discovery_base_url", "")
if not base_url:
raise OAuthFlowError("No base URL available for fallback discovery")

# Fallback to root discovery for legacy servers
return urljoin(base_url, f"/.well-known/{well_known_endpoint}")

def _build_oidc_fallback_path(self, pathname: str, well_known_endpoint: str) -> str:
"""Construct fallback well-known path for OIDC metadata discovery in legacy servers."""
# Strip trailing slash from pathname to avoid double slashes
clean_pathname = pathname[:-1] if pathname.endswith("/") else pathname
# OIDC 1.0 appends the well-known path to the full AS URL
return f"{clean_pathname}/.well-known/{well_known_endpoint}"

def _build_oidc_fallback_url(self, well_known_endpoint: str) -> str:
"""Construct fallback well-known URL for OIDC metadata discovery in legacy servers."""
if self.context.auth_server_url:
auth_server_url = self.context.auth_server_url
else:
auth_server_url = self.context.server_url

parsed = urlparse(auth_server_url)
well_known_path = self._build_oidc_fallback_path(parsed.path, well_known_endpoint)
base_url = f"{parsed.scheme}://{parsed.netloc}"
return urljoin(base_url, well_known_path)

def _should_attempt_fallback(self, response_status: int, discovery_stack: OAuthDiscoveryStack) -> bool:
"""Determine if further fallback should be attempted."""
return response_status == 404 and len(discovery_stack) > 0

async def _try_metadata_discovery(self, url: str) -> httpx.Request:
"""Build metadata discovery request for a specific URL."""
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})

async def _discover_oauth_metadata(self) -> httpx.Request:
"""Build OAuth metadata discovery request with fallback support."""
async def _discover_well_known_metadata(self, well_known_endpoint: str) -> httpx.Request:
"""Build .well-known metadata discovery request with fallback support."""
if self.context.auth_server_url:
auth_server_url = self.context.auth_server_url
else:
auth_server_url = self.context.server_url

# Per RFC 8414, try path-aware discovery first
parsed = urlparse(auth_server_url)
well_known_path = self._build_well_known_path(parsed.path)
well_known_path = self._build_well_known_path(parsed.path, well_known_endpoint)
base_url = f"{parsed.scheme}://{parsed.netloc}"
url = urljoin(base_url, well_known_path)

Expand All @@ -286,17 +317,37 @@ async def _discover_oauth_metadata(self) -> httpx.Request:

return await self._try_metadata_discovery(url)

async def _discover_well_known_metadata_fallback(self, well_known_endpoint: str) -> httpx.Request:
"""Build fallback OAuth metadata discovery request for legacy servers."""
url = self._build_well_known_fallback_url(well_known_endpoint)
return await self._try_metadata_discovery(url)

async def _discover_oauth_metadata(self) -> httpx.Request:
"""Build OAuth metadata discovery request with fallback support."""
return await self._discover_well_known_metadata("oauth-authorization-server")

async def _discover_oauth_metadata_fallback(self) -> httpx.Request:
"""Build fallback OAuth metadata discovery request for legacy servers."""
base_url = getattr(self.context, "discovery_base_url", "")
if not base_url:
raise OAuthFlowError("No base URL available for fallback discovery")
return await self._discover_well_known_metadata_fallback("oauth-authorization-server")

# Fallback to root discovery for legacy servers
url = urljoin(base_url, "/.well-known/oauth-authorization-server")
async def _discover_oidc_metadata(self) -> httpx.Request:
"""
Build fallback OIDC metadata discovery request.
See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
"""
return await self._discover_well_known_metadata("openid-configuration")

async def _discover_oidc_metadata_fallback(self) -> httpx.Request:
"""
Build fallback OIDC metadata discovery request for legacy servers.
See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
"""
url = self._build_oidc_fallback_url("openid-configuration")
return await self._try_metadata_discovery(url)

async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fallback: bool = False) -> bool:
async def _handle_oauth_metadata_response(
self, response: httpx.Response, discovery_stack: OAuthDiscoveryStack
) -> bool:
"""Handle OAuth metadata response. Returns True if handled successfully."""
if response.status_code == 200:
try:
Expand All @@ -310,13 +361,10 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fal
except ValidationError:
pass

# Check if we should attempt fallback (404 on path-aware discovery)
if not is_fallback and self._should_attempt_fallback(
response.status_code, getattr(self.context, "discovery_pathname", "/")
):
return False # Signal that fallback should be attempted

return True # Signal no fallback needed (either success or non-404 error)
# Check if we should attempt fallback
# True: No fallback needed (either success or non-404 error)
# False: Signal that fallback should be attempted
return not self._should_attempt_fallback(response.status_code, discovery_stack)

async def _register_client(self) -> httpx.Request | None:
"""Build registration request or skip if already registered."""
Expand Down Expand Up @@ -511,6 +559,26 @@ def _add_auth_header(self, request: httpx.Request) -> None:
if self.context.current_tokens and self.context.current_tokens.access_token:
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"

def _create_oauth_discovery_stack(self) -> OAuthDiscoveryStack:
"""Create a stack of attempts to discover OAuth metadata."""
discovery_attempts: OAuthDiscoveryStack = [
# Start with path-aware OAuth discovery
self._discover_oauth_metadata,
# If path-aware discovery fails with 404, try fallback to root
self._discover_oauth_metadata_fallback,
# If root discovery fails with 404, fall back to OIDC 1.0 following
# RFC 8414 path-aware semantics (see RFC 8414 section 5)
self._discover_oidc_metadata,
# If path-aware OIDC discovery failed with 404, fall back to OIDC 1.0
# following OIDC 1.0 semantics (see RFC 8414 section 5)
self._discover_oidc_metadata_fallback,
]

# Reverse the list so we can call pop() without remembering we declared
# this stack backwards for readability
discovery_attempts.reverse()
return discovery_attempts

async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
async with self.context.lock:
Expand Down Expand Up @@ -544,15 +612,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
await self._handle_protected_resource_response(discovery_response)

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
oauth_request = await self._discover_oauth_metadata()
oauth_response = yield oauth_request
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)

# If path-aware discovery failed with 404, try fallback to root
if not handled:
fallback_request = await self._discover_oauth_metadata_fallback()
fallback_response = yield fallback_request
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
oauth_discovery_stack = self._create_oauth_discovery_stack()
while len(oauth_discovery_stack) > 0:
oauth_discovery = oauth_discovery_stack.pop()
oauth_request = await oauth_discovery()
oauth_response = yield oauth_request
await self._handle_oauth_metadata_response(oauth_response, oauth_discovery_stack)

# Step 3: Register client if needed
registration_request = await self._register_client()
Expand All @@ -571,6 +636,6 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
logger.exception("OAuth flow error")
raise

# Retry with new tokens
self._add_auth_header(request)
yield request
# Retry with new tokens
self._add_auth_header(request)
yield request
Loading
Loading