Skip to content
117 changes: 46 additions & 71 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,72 +251,32 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
except ValidationError:
pass

def _build_well_known_path(self, pathname: str) -> str:
"""Construct well-known path for OAuth metadata discovery."""
well_known_path = f"/.well-known/oauth-authorization-server{pathname}"
if pathname.endswith("/"):
# Strip trailing slash from pathname to avoid double slashes
well_known_path = well_known_path[:-1]
return well_known_path

def _should_attempt_fallback(self, response_status: int, pathname: str) -> bool:
"""Determine if fallback to root discovery should be attempted."""
return response_status == 404 and pathname != "/"

async def _try_metadata_discovery(self, url: str) -> httpx.Request:
"""Build metadata discovery request for a specific URL."""
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})

async def _discover_oauth_metadata(self) -> httpx.Request:
"""Build OAuth metadata discovery request with fallback support."""
if self.context.auth_server_url:
auth_server_url = self.context.auth_server_url
else:
auth_server_url = self.context.server_url

# Per RFC 8414, try path-aware discovery first
def _get_discovery_urls(self) -> list[str]:
"""Generate ordered list of (url, type) tuples for discovery attempts."""
urls: list[str] = []
auth_server_url = self.context.auth_server_url or self.context.server_url
parsed = urlparse(auth_server_url)
well_known_path = self._build_well_known_path(parsed.path)
base_url = f"{parsed.scheme}://{parsed.netloc}"
url = urljoin(base_url, well_known_path)

# Store fallback info for use in response handler
self.context.discovery_base_url = base_url
self.context.discovery_pathname = parsed.path
# RFC 8414: Path-aware OAuth discovery
if parsed.path and parsed.path != "/":
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
urls.append(urljoin(base_url, oauth_path))

return await self._try_metadata_discovery(url)
# OAuth root fallback
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))

async def _discover_oauth_metadata_fallback(self) -> httpx.Request:
"""Build fallback OAuth metadata discovery request for legacy servers."""
base_url = getattr(self.context, "discovery_base_url", "")
if not base_url:
raise OAuthFlowError("No base URL available for fallback discovery")

# Fallback to root discovery for legacy servers
url = urljoin(base_url, "/.well-known/oauth-authorization-server")
return await self._try_metadata_discovery(url)

async def _handle_oauth_metadata_response(self, response: httpx.Response, is_fallback: bool = False) -> bool:
"""Handle OAuth metadata response. Returns True if handled successfully."""
if response.status_code == 200:
try:
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
self.context.oauth_metadata = metadata
# Apply default scope if none specified
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)
return True
except ValidationError:
pass
# RFC 8414 section 5: Path-aware OIDC discovery
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
if parsed.path and parsed.path != "/":
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
urls.append(urljoin(base_url, oidc_path))

# Check if we should attempt fallback (404 on path-aware discovery)
if not is_fallback and self._should_attempt_fallback(
response.status_code, getattr(self.context, "discovery_pathname", "/")
):
return False # Signal that fallback should be attempted
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration"
urls.append(oidc_fallback)

return True # Signal no fallback needed (either success or non-404 error)
return urls

async def _register_client(self) -> httpx.Request | None:
"""Build registration request or skip if already registered."""
Expand Down Expand Up @@ -511,6 +471,17 @@ def _add_auth_header(self, request: httpx.Request) -> None:
if self.context.current_tokens and self.context.current_tokens.access_token:
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"

def _create_oauth_metadata_request(self, url: str) -> httpx.Request:
return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})

async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
self.context.oauth_metadata = metadata
# Apply default scope if needed
if self.context.client_metadata.scope is None and metadata.scopes_supported is not None:
self.context.client_metadata.scope = " ".join(metadata.scopes_supported)

async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
async with self.context.lock:
Expand Down Expand Up @@ -544,15 +515,19 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
await self._handle_protected_resource_response(discovery_response)

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
oauth_request = await self._discover_oauth_metadata()
oauth_response = yield oauth_request
handled = await self._handle_oauth_metadata_response(oauth_response, is_fallback=False)

# If path-aware discovery failed with 404, try fallback to root
if not handled:
fallback_request = await self._discover_oauth_metadata_fallback()
fallback_response = yield fallback_request
await self._handle_oauth_metadata_response(fallback_response, is_fallback=True)
discovery_urls = self._get_discovery_urls()
for url in discovery_urls:
request = self._create_oauth_metadata_request(url)
response = yield request

if response.status_code == 200:
try:
await self._handle_oauth_metadata_response(response)
break
except ValidationError:
continue
elif response.status_code != 404:
break # Non-404 error, stop trying

# Step 3: Register client if needed
registration_request = await self._register_client()
Expand All @@ -571,6 +546,6 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
logger.exception("OAuth flow error")
raise

# Retry with new tokens
self._add_auth_header(request)
yield request
# Retry with new tokens
self._add_auth_header(request)
yield request
146 changes: 15 additions & 131 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,107 +235,30 @@ async def callback_handler() -> tuple[str, str | None]:
assert "mcp-protocol-version" in request.headers

@pytest.mark.anyio
async def test_discover_oauth_metadata_request(self, oauth_provider):
def test_create_oauth_metadata_request(self, oauth_provider):
"""Test OAuth metadata discovery request building."""
request = await oauth_provider._discover_oauth_metadata()
request = oauth_provider._create_oauth_metadata_request("https://example.com")

# Ensure correct method and headers, and that the URL is unmodified
assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
assert "mcp-protocol-version" in request.headers

@pytest.mark.anyio
async def test_discover_oauth_metadata_request_no_path(self, client_metadata, mock_storage):
"""Test OAuth metadata discovery request building when server has no path."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

request = await provider._discover_oauth_metadata()

assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server"
assert "mcp-protocol-version" in request.headers

@pytest.mark.anyio
async def test_discover_oauth_metadata_request_trailing_slash(self, client_metadata, mock_storage):
"""Test OAuth metadata discovery request building when server path has trailing slash."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp/",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

request = await provider._discover_oauth_metadata()

assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp"
assert str(request.url) == "https://example.com"
assert "mcp-protocol-version" in request.headers


class TestOAuthFallback:
"""Test OAuth discovery fallback behavior for legacy (act as AS not RS) servers."""

@pytest.mark.anyio
async def test_fallback_discovery_request(self, client_metadata, mock_storage):
"""Test fallback discovery request building."""

async def redirect_handler(url: str) -> None:
pass

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state"

provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

# Set up discovery state manually as if path-aware discovery was attempted
provider.context.discovery_base_url = "https://api.example.com"
provider.context.discovery_pathname = "/v1/mcp"
async def test_oauth_discovery_fallback_order(self, oauth_provider):
"""Test fallback URL construction order."""
discovery_urls = oauth_provider._get_discovery_urls()

# Test fallback request building
request = await provider._discover_oauth_metadata_fallback()

assert request.method == "GET"
assert str(request.url) == "https://api.example.com/.well-known/oauth-authorization-server"
assert "mcp-protocol-version" in request.headers

@pytest.mark.anyio
async def test_should_attempt_fallback(self, oauth_provider):
"""Test fallback decision logic."""
# Should attempt fallback on 404 with non-root path
assert oauth_provider._should_attempt_fallback(404, "/v1/mcp")

# Should NOT attempt fallback on 404 with root path
assert not oauth_provider._should_attempt_fallback(404, "/")

# Should NOT attempt fallback on other status codes
assert not oauth_provider._should_attempt_fallback(200, "/v1/mcp")
assert not oauth_provider._should_attempt_fallback(500, "/v1/mcp")
assert discovery_urls == [
"https://api.example.com/.well-known/oauth-authorization-server/v1/mcp",
"https://api.example.com/.well-known/oauth-authorization-server",
"https://api.example.com/.well-known/openid-configuration/v1/mcp",
"https://api.example.com/v1/mcp/.well-known/openid-configuration",
]

@pytest.mark.anyio
async def test_handle_metadata_response_success(self, oauth_provider):
Expand All @@ -348,50 +271,11 @@ async def test_handle_metadata_response_success(self, oauth_provider):
}"""
response = httpx.Response(200, content=content)

# Should return True (success) and set metadata
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False)
assert result is True
# Should set metadata
await oauth_provider._handle_oauth_metadata_response(response)
assert oauth_provider.context.oauth_metadata is not None
assert str(oauth_provider.context.oauth_metadata.issuer) == "https://auth.example.com/"

@pytest.mark.anyio
async def test_handle_metadata_response_404_needs_fallback(self, oauth_provider):
"""Test 404 response handling that should trigger fallback."""
# Set up discovery state for non-root path
oauth_provider.context.discovery_base_url = "https://api.example.com"
oauth_provider.context.discovery_pathname = "/v1/mcp"

# Mock 404 response
response = httpx.Response(404)

# Should return False (needs fallback)
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False)
assert result is False

@pytest.mark.anyio
async def test_handle_metadata_response_404_no_fallback_needed(self, oauth_provider):
"""Test 404 response handling when no fallback is needed."""
# Set up discovery state for root path
oauth_provider.context.discovery_base_url = "https://api.example.com"
oauth_provider.context.discovery_pathname = "/"

# Mock 404 response
response = httpx.Response(404)

# Should return True (no fallback needed)
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=False)
assert result is True

@pytest.mark.anyio
async def test_handle_metadata_response_404_fallback_attempt(self, oauth_provider):
"""Test 404 response handling during fallback attempt."""
# Mock 404 response during fallback
response = httpx.Response(404)

# Should return True (fallback attempt complete, no further action needed)
result = await oauth_provider._handle_oauth_metadata_response(response, is_fallback=True)
assert result is True

@pytest.mark.anyio
async def test_register_client_request(self, oauth_provider):
"""Test client registration request building."""
Expand Down
39 changes: 39 additions & 0 deletions tests/shared/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Tests for OAuth 2.0 shared code."""

from mcp.shared.auth import OAuthMetadata


class TestOAuthMetadata:
"""Tests for OAuthMetadata parsing."""

def test_oauth(self):
"""Should not throw when parsing OAuth metadata."""
OAuthMetadata.model_validate(
{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/oauth2/authorize",
"token_endpoint": "https://example.com/oauth2/token",
"scopes_supported": ["read", "write"],
"response_types_supported": ["code", "token"],
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
}
)

def test_oidc(self):
"""Should not throw when parsing OIDC metadata."""
OAuthMetadata.model_validate(
{
"issuer": "https://example.com",
"authorization_endpoint": "https://example.com/oauth2/authorize",
"token_endpoint": "https://example.com/oauth2/token",
"end_session_endpoint": "https://example.com/logout",
"id_token_signing_alg_values_supported": ["RS256"],
"jwks_uri": "https://example.com/.well-known/jwks.json",
"response_types_supported": ["code", "token"],
"revocation_endpoint": "https://example.com/oauth2/revoke",
"scopes_supported": ["openid", "read", "write"],
"subject_types_supported": ["public"],
"token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
"userinfo_endpoint": "https://example.com/oauth2/userInfo",
}
)
Loading