Skip to content

Commit 603fdc1

Browse files
committed
Enhance OAuth metadata discovery in client
- Added a new method to discover OAuth Protected Resource Metadata, improving the handling of authorization server URLs. - Updated the OAuthClientProvider to utilize the discovered protected resource metadata when fetching OAuth metadata. - Refactored tests to validate the new discovery logic and ensure correct URL calls for protected resource and authorization server metadata. Signed-off-by: Xin Fu <[email protected]>
1 parent c9f5df7 commit 603fdc1

File tree

3 files changed

+123
-16
lines changed

3 files changed

+123
-16
lines changed

src/mcp/client/auth.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import time
1313
from collections.abc import AsyncGenerator, Awaitable, Callable
1414
from typing import Protocol
15-
from urllib.parse import urlencode, urljoin
15+
from urllib.parse import urlencode, urljoin, urlparse, urlunparse
1616

1717
import anyio
1818
import httpx
@@ -21,6 +21,7 @@
2121
OAuthClientInformationFull,
2222
OAuthClientMetadata,
2323
OAuthMetadata,
24+
OAuthProtectedResourceMetadata,
2425
OAuthToken,
2526
)
2627
from mcp.types import LATEST_PROTOCOL_VERSION
@@ -116,19 +117,59 @@ def _get_authorization_base_url(self, server_url: str) -> str:
116117
117118
Per MCP spec 2.3.2: https://api.example.com/v1/mcp -> https://api.example.com
118119
"""
119-
from urllib.parse import urlparse, urlunparse
120-
121120
parsed = urlparse(server_url)
122-
# Remove path component
123121
return urlunparse((parsed.scheme, parsed.netloc, "", "", "", ""))
124122

123+
async def _discover_protected_resource_metadata(
124+
self, resource_server_url: str
125+
) -> OAuthProtectedResourceMetadata | None:
126+
"""
127+
Looks up RFC 9728 OAuth 2.0 Protected Resource Metadata.
128+
129+
If the server returns a 404 for the well-known endpoint, returns None.
130+
"""
131+
async with httpx.AsyncClient() as client:
132+
response = await client.get(
133+
urljoin(resource_server_url, "/.well-known/oauth-protected-resource")
134+
)
135+
if response.status_code == 404:
136+
return None
137+
response.raise_for_status()
138+
metadata_json = response.json()
139+
logger.debug(
140+
f"OAuth protected resource metadata discovered: {metadata_json}"
141+
)
142+
return OAuthProtectedResourceMetadata.model_validate(metadata_json)
143+
125144
async def _discover_oauth_metadata(self, server_url: str) -> OAuthMetadata | None:
126145
"""
127146
Discover OAuth metadata from server's well-known endpoint.
147+
148+
First tries to discover protected resource metadata and use its authorization
149+
server URL if available, otherwise falls back to the server's own well-known.
128150
"""
129-
# Extract base URL per MCP spec
130-
auth_base_url = self._get_authorization_base_url(server_url)
131-
url = urljoin(auth_base_url, "/.well-known/oauth-authorization-server")
151+
auth_server_url = self._get_authorization_base_url(server_url)
152+
153+
try:
154+
protected_resource_metadata = (
155+
await self._discover_protected_resource_metadata(server_url)
156+
)
157+
158+
if (
159+
protected_resource_metadata
160+
and protected_resource_metadata.authorization_servers
161+
and len(protected_resource_metadata.authorization_servers) > 0
162+
):
163+
auth_server_url = str(
164+
protected_resource_metadata.authorization_servers[0]
165+
)
166+
except Exception as e:
167+
logger.warning(
168+
"Could not load OAuth Protected Resource metadata, "
169+
f"falling back to /.well-known/oauth-authorization-server: {e}"
170+
)
171+
172+
url = urljoin(auth_server_url, "/.well-known/oauth-authorization-server")
132173
headers = {"MCP-Protocol-Version": LATEST_PROTOCOL_VERSION}
133174

134175
async with httpx.AsyncClient() as client:

src/mcp/server/auth/routes.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,8 @@ def build_metadata(
172172
client_registration_options: ClientRegistrationOptions,
173173
revocation_options: RevocationOptions,
174174
) -> OAuthMetadata:
175-
authorization_url = AnyHttpUrl(
176-
str(issuer_url).rstrip("/") + AUTHORIZATION_PATH
177-
)
178-
token_url = AnyHttpUrl(
179-
str(issuer_url).rstrip("/") + TOKEN_PATH
180-
)
175+
authorization_url = AnyHttpUrl(str(issuer_url).rstrip("/") + AUTHORIZATION_PATH)
176+
token_url = AnyHttpUrl(str(issuer_url).rstrip("/") + TOKEN_PATH)
181177

182178
# Create metadata
183179
metadata = OAuthMetadata(

tests/client/test_auth.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
OAuthClientInformationFull,
2121
OAuthClientMetadata,
2222
OAuthMetadata,
23+
OAuthProtectedResourceMetadata,
2324
OAuthToken,
2425
)
2526

@@ -74,6 +75,16 @@ def oauth_metadata():
7475
)
7576

7677

78+
@pytest.fixture
79+
def oauth_protected_resource_metadata():
80+
return OAuthProtectedResourceMetadata(
81+
resource="https://api.example.com/v1/mcp",
82+
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
83+
scopes_supported=["read", "write"],
84+
bearer_methods_supported=["header"],
85+
)
86+
87+
7788
@pytest.fixture
7889
def oauth_client_info():
7990
return OAuthClientInformationFull(
@@ -210,10 +221,13 @@ async def test_discover_oauth_metadata_success(
210221
assert result.token_endpoint == oauth_metadata.token_endpoint
211222

212223
# Verify correct URL was called
213-
mock_client.get.assert_called_once()
214-
call_args = mock_client.get.call_args[0]
224+
assert mock_client.get.call_count == 2
215225
assert (
216-
call_args[0]
226+
mock_client.get.call_args_list[0][0][0]
227+
== "https://api.example.com/.well-known/oauth-protected-resource"
228+
)
229+
assert (
230+
mock_client.get.call_args_list[1][0][0]
217231
== "https://api.example.com/.well-known/oauth-authorization-server"
218232
)
219233

@@ -262,6 +276,62 @@ async def test_discover_oauth_metadata_cors_fallback(
262276
assert result is not None
263277
assert mock_client.get.call_count == 2
264278

279+
@pytest.mark.anyio
280+
async def test_discover_oauth_metadata_from_protected_resource(
281+
self, oauth_provider, oauth_metadata, oauth_protected_resource_metadata
282+
):
283+
"""Test OAuth metadata discovery using protected resource metadata."""
284+
protected_resource_response = oauth_protected_resource_metadata.model_dump(
285+
by_alias=True, mode="json"
286+
)
287+
oauth_metadata_response = oauth_metadata.model_dump(by_alias=True, mode="json")
288+
289+
with patch("httpx.AsyncClient") as mock_client_class:
290+
mock_client = AsyncMock()
291+
mock_client_class.return_value.__aenter__.return_value = mock_client
292+
293+
# First call returns protected resource metadata
294+
protected_resource_mock = Mock()
295+
protected_resource_mock.status_code = 200
296+
protected_resource_mock.json.return_value = protected_resource_response
297+
298+
# Second call returns OAuth metadata from authorization server
299+
oauth_metadata_mock = Mock()
300+
oauth_metadata_mock.status_code = 200
301+
oauth_metadata_mock.json.return_value = oauth_metadata_response
302+
303+
mock_client.get.side_effect = [
304+
protected_resource_mock, # Protected resource metadata call
305+
oauth_metadata_mock, # OAuth metadata from auth server call
306+
]
307+
308+
result = await oauth_provider._discover_oauth_metadata(
309+
"https://api.example.com/v1/mcp"
310+
)
311+
312+
assert result is not None
313+
assert (
314+
result.authorization_endpoint == oauth_metadata.authorization_endpoint
315+
)
316+
assert result.token_endpoint == oauth_metadata.token_endpoint
317+
318+
# Verify correct URLs were called in order
319+
assert mock_client.get.call_count == 2
320+
321+
# First call should be to protected resource metadata endpoint
322+
first_call_args = mock_client.get.call_args_list[0][0]
323+
assert (
324+
first_call_args[0]
325+
== "https://api.example.com/.well-known/oauth-protected-resource"
326+
)
327+
328+
# Second call should be to authorization server's OAuth metadata endpoint
329+
second_call_args = mock_client.get.call_args_list[1][0]
330+
assert (
331+
second_call_args[0]
332+
== "https://auth.example.com/.well-known/oauth-authorization-server"
333+
)
334+
265335
@pytest.mark.anyio
266336
async def test_register_oauth_client_success(
267337
self, oauth_provider, oauth_metadata, oauth_client_info

0 commit comments

Comments
 (0)