diff --git a/src/mcp/client/auth.py b/src/mcp/client/auth.py index 8bafe18eb..7016021b2 100644 --- a/src/mcp/client/auth.py +++ b/src/mcp/client/auth.py @@ -203,13 +203,24 @@ def __init__( ) self._initialized = False - async def _discover_protected_resource(self) -> httpx.Request: + def _build_well_known_path_protected_resource(self, pathname: str) -> str: + """Construct well-known path for OAuth protected resource metadata discovery.""" + well_known_path = f"/.well-known/oauth-protected-resource{pathname}" + if pathname.endswith("/"): + # Strip trailing slash from pathname to avoid double slashes + well_known_path = well_known_path[:-1] + return well_known_path + + async def _discover_protected_resource(self, is_fallback: bool = False) -> httpx.Request: """Build discovery request for protected resource metadata.""" auth_base_url = self.context.get_authorization_base_url(self.context.server_url) - url = urljoin(auth_base_url, "/.well-known/oauth-protected-resource") + auth_url_parsed = urlparse(self.context.server_url) + pathname = auth_url_parsed.path if not is_fallback else "/" + well_known_path = self._build_well_known_path_protected_resource(pathname) + url = urljoin(auth_base_url, well_known_path) return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - async def _handle_protected_resource_response(self, response: httpx.Response) -> None: + async def _handle_protected_resource_response(self, response: httpx.Response) -> bool: """Handle discovery response.""" if response.status_code == 200: try: @@ -218,8 +229,10 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> self.context.protected_resource_metadata = metadata if metadata.authorization_servers: self.context.auth_server_url = str(metadata.authorization_servers[0]) + return True except ValidationError: pass + return False def _build_well_known_path(self, pathname: str) -> str: """Construct well-known path for OAuth metadata discovery.""" @@ -497,7 +510,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Step 1: Discover protected resource metadata (spec revision 2025-06-18) discovery_request = await self._discover_protected_resource() discovery_response = yield discovery_request - await self._handle_protected_resource_response(discovery_response) + discovery_handled = await self._handle_protected_resource_response(discovery_response) + + # If path-aware discovery failed, try fallback to root + if not discovery_handled: + discovery_request = await self._discover_protected_resource(is_fallback=True) + discovery_response = yield discovery_request + await self._handle_protected_resource_response(discovery_response) # Step 2: Discover OAuth metadata (with fallback for legacy servers) oauth_request = await self._discover_oauth_metadata() @@ -549,7 +568,13 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Step 1: Discover protected resource metadata (spec revision 2025-06-18) discovery_request = await self._discover_protected_resource() discovery_response = yield discovery_request - await self._handle_protected_resource_response(discovery_response) + discovery_handled = await self._handle_protected_resource_response(discovery_response) + + # If path-aware discovery failed, try fallback to root + if not discovery_handled: + discovery_request = await self._discover_protected_resource(is_fallback=True) + discovery_response = yield discovery_request + await self._handle_protected_resource_response(discovery_response) # Step 2: Discover OAuth metadata (with fallback for legacy servers) oauth_request = await self._discover_oauth_metadata() diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index d64687ff8..289fa6460 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -202,6 +202,17 @@ async def test_discover_protected_resource_request(self, oauth_provider): request = await oauth_provider._discover_protected_resource() assert request.method == "GET" + assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource/v1/mcp" + assert "mcp-protocol-version" in request.headers + + @pytest.mark.anyio + async def test_discover_protected_resource_request_fallback(self, oauth_provider): + """Test protected resource discovery request building after a failure to discover metadata at the + standard endpoint.""" + request = await oauth_provider._discover_protected_resource(is_fallback=True) + + assert request.method == "GET" + # Falls back to the root assert str(request.url) == "https://api.example.com/.well-known/oauth-protected-resource" assert "mcp-protocol-version" in request.headers