Skip to content

Commit ee63405

Browse files
authored
Merge pull request #2507 from jlowin/sdk-auth-updates
[2.14] Update for MCP SDK auth changes
2 parents 54692c3 + e3b103d commit ee63405

File tree

9 files changed

+96
-83
lines changed

9 files changed

+96
-83
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ dependencies = [
77
"python-dotenv>=1.1.0",
88
"exceptiongroup>=1.2.2",
99
"httpx>=0.28.1",
10-
"mcp>=1.19.0,<2.0.0,!=1.21.1",
10+
"mcp>=1.23.1",
1111
"openapi-pydantic>=0.5.1",
1212
"platformdirs>=4.0.0",
1313
"rich>=13.9.4",

src/fastmcp/server/auth/oauth_proxy.py

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from key_value.aio.protocols import AsyncKeyValue
3737
from key_value.aio.stores.disk import DiskStore
3838
from key_value.aio.wrappers.encryption import FernetEncryptionWrapper
39-
from mcp.server.auth.handlers.token import TokenErrorResponse, TokenSuccessResponse
39+
from mcp.server.auth.handlers.token import TokenErrorResponse
4040
from mcp.server.auth.handlers.token import TokenHandler as _SDKTokenHandler
4141
from mcp.server.auth.json_response import PydanticJSONResponse
4242
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
@@ -522,50 +522,43 @@ def create_error_html(
522522
class TokenHandler(_SDKTokenHandler):
523523
"""TokenHandler that returns OAuth 2.1 compliant error responses.
524524
525-
The MCP SDK always returns HTTP 400 for all client authentication issues.
526-
However, OAuth 2.1 Section 5.3 and the MCP specification require that
527-
invalid or expired tokens MUST receive a HTTP 401 response.
525+
The MCP SDK returns `unauthorized_client` for client authentication failures.
526+
However, per RFC 6749 Section 5.2, authentication failures should return
527+
`invalid_client` with HTTP 401, not `unauthorized_client`.
528528
529-
This handler extends the base MCP SDK TokenHandler to transform client
530-
authentication failures into OAuth 2.1 compliant responses:
531-
- Changes 'unauthorized_client' to 'invalid_client' error code
532-
- Returns HTTP 401 status code instead of 400 for client auth failures
529+
This distinction matters: `unauthorized_client` means "client exists but
530+
can't do this", while `invalid_client` means "client doesn't exist or
531+
credentials are wrong". Claude's OAuth client uses this to decide whether
532+
to re-register.
533533
534-
Per OAuth 2.1 Section 5.3: "The authorization server MAY return an HTTP 401
535-
(Unauthorized) status code to indicate which HTTP authentication schemes
536-
are supported."
537-
538-
Per MCP spec: "Invalid or expired tokens MUST receive a HTTP 401 response."
534+
This handler transforms 401 responses with `unauthorized_client` to use
535+
`invalid_client` instead, making the error semantics correct per OAuth spec.
539536
"""
540537

541-
def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
542-
"""Override response method to provide OAuth 2.1 compliant error handling."""
543-
# Check if this is a client authentication failure (not just unauthorized for grant type)
544-
# unauthorized_client can mean two things:
545-
# 1. Client authentication failed (client_id not found or wrong credentials) -> invalid_client 401
546-
# 2. Client not authorized for this grant type -> unauthorized_client 400 (correct per spec)
547-
if (
548-
isinstance(obj, TokenErrorResponse)
549-
and obj.error == "unauthorized_client"
550-
and obj.error_description
551-
and "Invalid client_id" in obj.error_description
552-
):
553-
# Transform client auth failure to OAuth 2.1 compliant response
554-
return PydanticJSONResponse(
555-
content=TokenErrorResponse(
556-
error="invalid_client",
557-
error_description=obj.error_description,
558-
error_uri=obj.error_uri,
559-
),
560-
status_code=401,
561-
headers={
562-
"Cache-Control": "no-store",
563-
"Pragma": "no-cache",
564-
},
565-
)
538+
async def handle(self, request: Any):
539+
"""Wrap SDK handle() and transform auth error responses."""
540+
response = await super().handle(request)
566541

567-
# Otherwise use default behavior from parent class
568-
return super().response(obj)
542+
# Transform 401 unauthorized_client -> invalid_client
543+
if response.status_code == 401:
544+
try:
545+
body = json.loads(response.body)
546+
if body.get("error") == "unauthorized_client":
547+
return PydanticJSONResponse(
548+
content=TokenErrorResponse(
549+
error="invalid_client",
550+
error_description=body.get("error_description"),
551+
),
552+
status_code=401,
553+
headers={
554+
"Cache-Control": "no-store",
555+
"Pragma": "no-cache",
556+
},
557+
)
558+
except (json.JSONDecodeError, AttributeError):
559+
pass # Not JSON or unexpected format, return as-is
560+
561+
return response
569562

570563

571564
class OAuthProxy(OAuthProvider):
@@ -993,9 +986,13 @@ async def register_client(self, client_info: OAuthClientInformationFull) -> None
993986
# Create a ProxyDCRClient with configured redirect URI validation
994987
if client_info.client_id is None:
995988
raise ValueError("client_id is required for client registration")
989+
# We use token_endpoint_auth_method="none" because the proxy handles
990+
# all upstream authentication. The client_secret must also be None
991+
# because the SDK requires secrets to be provided if they're set,
992+
# regardless of auth method.
996993
proxy_client: ProxyDCRClient = ProxyDCRClient(
997994
client_id=client_info.client_id,
998-
client_secret=client_info.client_secret,
995+
client_secret=None,
999996
redirect_uris=client_info.redirect_uris or [AnyUrl("http://localhost")],
1000997
grant_types=client_info.grant_types
1001998
or ["authorization_code", "refresh_token"],

src/fastmcp/server/context.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,16 @@
1818
from mcp.server.lowlevel.server import request_ctx
1919
from mcp.shared.context import RequestContext
2020
from mcp.types import (
21-
AudioContent,
2221
ClientCapabilities,
2322
CreateMessageResult,
2423
GetPromptResult,
25-
ImageContent,
2624
IncludeContext,
2725
ModelHint,
2826
ModelPreferences,
2927
Root,
3028
SamplingCapability,
3129
SamplingMessage,
30+
SamplingMessageContentBlock,
3231
TextContent,
3332
)
3433
from mcp.types import CreateMessageRequestParams as SamplingParams
@@ -59,6 +58,7 @@
5958

6059

6160
T = TypeVar("T", default=Any)
61+
6262
_current_context: ContextVar[Context | None] = ContextVar("context", default=None) # type: ignore[assignment]
6363
_flush_lock = anyio.Lock()
6464

@@ -479,7 +479,7 @@ async def sample(
479479
temperature: float | None = None,
480480
max_tokens: int | None = None,
481481
model_preferences: ModelPreferences | str | list[str] | None = None,
482-
) -> TextContent | ImageContent | AudioContent:
482+
) -> SamplingMessageContentBlock | list[SamplingMessageContentBlock]:
483483
"""
484484
Send a sampling request to the client and await the response.
485485

tests/client/test_sampling.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def sampling_handler(
149149
"annotations": None,
150150
"_meta": None,
151151
},
152+
"_meta": None,
152153
},
153154
{
154155
"role": "user",
@@ -159,5 +160,6 @@ def sampling_handler(
159160
"annotations": None,
160161
"_meta": None,
161162
},
163+
"_meta": None,
162164
},
163165
]

tests/server/auth/test_oauth_proxy.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,8 @@ async def test_register_client(self, oauth_proxy):
420420
stored = await oauth_proxy.get_client("original-client")
421421
assert stored is not None
422422
assert stored.client_id == "original-client"
423-
assert stored.client_secret == "original-secret"
423+
# Proxy uses token_endpoint_auth_method="none", so client_secret is not stored
424+
assert stored.client_secret is None
424425

425426
async def test_get_registered_client(self, oauth_proxy):
426427
"""Test retrieving a registered client."""
@@ -1247,29 +1248,36 @@ async def test_token_endpoint_invalid_client_error(self, jwt_verifier):
12471248
class TestTokenHandlerErrorTransformation:
12481249
"""Tests for TokenHandler's OAuth 2.1 compliant error transformation."""
12491250

1250-
def test_transforms_client_auth_failure_to_invalid_client_401(self):
1251+
async def test_transforms_client_auth_failure_to_invalid_client_401(self):
12511252
"""Test that client authentication failures return invalid_client with 401."""
1252-
from mcp.server.auth.handlers.token import TokenErrorResponse
1253+
from unittest.mock import AsyncMock, patch
1254+
1255+
from mcp.server.auth.handlers.token import TokenHandler as SDKTokenHandler
12531256

12541257
from fastmcp.server.auth.oauth_proxy import TokenHandler
12551258

12561259
handler = TokenHandler(provider=Mock(), client_authenticator=Mock())
12571260

1258-
# Simulate error from ClientAuthenticator.authenticate() failure
1259-
error_response = TokenErrorResponse(
1260-
error="unauthorized_client",
1261-
error_description="Invalid client_id 'test-client-id'",
1261+
# Create a mock 401 response like the SDK returns for auth failures
1262+
mock_response = Mock()
1263+
mock_response.status_code = 401
1264+
mock_response.body = (
1265+
b'{"error":"unauthorized_client","error_description":"Invalid client_id"}'
12621266
)
12631267

1264-
response = handler.response(error_response)
1268+
# Patch the parent class's handle() to return our mock response
1269+
with patch.object(
1270+
SDKTokenHandler,
1271+
"handle",
1272+
new_callable=AsyncMock,
1273+
return_value=mock_response,
1274+
):
1275+
response = await handler.handle(Mock())
12651276

12661277
# Should transform to OAuth 2.1 compliant response
12671278
assert response.status_code == 401
12681279
assert b'"error":"invalid_client"' in response.body
1269-
assert (
1270-
b'"error_description":"Invalid client_id \'test-client-id\'"'
1271-
in response.body
1272-
)
1280+
assert b'"error_description":"Invalid client_id"' in response.body
12731281

12741282
def test_does_not_transform_grant_type_unauthorized_to_invalid_client(self):
12751283
"""Test that grant type authorization errors stay as unauthorized_client with 400."""

tests/server/auth/test_oauth_proxy_storage.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ async def test_register_and_get_client(self, jwt_verifier, temp_storage):
7575
client = await proxy.get_client("test-client-123")
7676
assert client is not None
7777
assert client.client_id == "test-client-123"
78-
assert client.client_secret == "secret-456"
78+
# Proxy uses token_endpoint_auth_method="none", so client_secret is not stored
79+
assert client.client_secret is None
7980
assert client.scope == "read write"
8081

8182
async def test_client_persists_across_proxy_instances(
@@ -96,7 +97,8 @@ async def test_client_persists_across_proxy_instances(
9697
proxy2 = self.create_proxy(jwt_verifier, storage=temp_storage)
9798
client = await proxy2.get_client("persistent-client")
9899
assert client is not None
99-
assert client.client_secret == "persistent-secret"
100+
# Proxy uses token_endpoint_auth_method="none", so client_secret is not stored
101+
assert client.client_secret is None
100102
assert client.scope == "openid profile"
101103

102104
async def test_nonexistent_client_returns_none(
@@ -199,7 +201,7 @@ async def test_storage_data_structure(self, jwt_verifier, temp_storage):
199201
"software_id": None,
200202
"software_version": None,
201203
"client_id": "structured-client",
202-
"client_secret": "secret",
204+
"client_secret": None,
203205
"client_id_issued_at": None,
204206
"client_secret_expires_at": None,
205207
"allowed_redirect_uri_patterns": None,

tests/server/middleware/test_logging.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def test_create_message_with_payloads(
144144
"event": "request_start",
145145
"source": "client",
146146
"method": "test_method",
147-
"payload": '{"method":"tools/call","params":{"_meta":null,"name":"test_method","arguments":{"param":"value"}}}',
147+
"payload": '{"method":"tools/call","params":{"task":null,"_meta":null,"name":"test_method","arguments":{"param":"value"}}}',
148148
"payload_type": "CallToolRequest",
149149
}
150150
)
@@ -159,7 +159,7 @@ def test_calculate_response_size(self, mock_context: MiddlewareContext[Any]):
159159
"event": "request_start",
160160
"source": "client",
161161
"method": "test_method",
162-
"payload_length": 98,
162+
"payload_length": 110,
163163
}
164164
)
165165

@@ -177,8 +177,8 @@ def test_calculate_response_size_with_token_estimation(
177177
"event": "request_start",
178178
"source": "client",
179179
"method": "test_method",
180-
"payload_tokens": 24,
181-
"payload_length": 98,
180+
"payload_tokens": 27,
181+
"payload_length": 110,
182182
}
183183
)
184184

@@ -303,7 +303,7 @@ async def test_on_message_with_pydantic_types_in_payload(
303303

304304
assert get_log_lines(caplog) == snapshot(
305305
[
306-
'{"event": "request_start", "method": "test_method", "source": "client", "payload": "{\\"method\\":\\"resources/read\\",\\"params\\":{\\"_meta\\":null,\\"uri\\":\\"test://example/1\\"}}", "payload_type": "ReadResourceRequest"}',
306+
'{"event": "request_start", "method": "test_method", "source": "client", "payload": "{\\"method\\":\\"resources/read\\",\\"params\\":{\\"task\\":null,\\"_meta\\":null,\\"uri\\":\\"test://example/1\\"}}", "payload_type": "ReadResourceRequest"}',
307307
'{"event": "request_success", "method": "test_method", "source": "client", "duration_ms": 0.02}',
308308
]
309309
)
@@ -365,7 +365,7 @@ def __str__(self) -> str:
365365

366366
assert get_log_lines(caplog) == snapshot(
367367
[
368-
'{"event": "request_start", "method": "test_method", "source": "client", "payload": "{\\"method\\":\\"tools/call\\",\\"params\\":{\\"_meta\\":null,\\"name\\":\\"test_method\\",\\"arguments\\":{\\"obj\\":\\"NON_SERIALIZABLE\\"}}}", "payload_type": "CallToolRequest"}',
368+
'{"event": "request_start", "method": "test_method", "source": "client", "payload": "{\\"method\\":\\"tools/call\\",\\"params\\":{\\"task\\":null,\\"_meta\\":null,\\"name\\":\\"test_method\\",\\"arguments\\":{\\"obj\\":\\"NON_SERIALIZABLE\\"}}}", "payload_type": "CallToolRequest"}',
369369
'{"event": "request_success", "method": "test_method", "source": "client", "duration_ms": 0.02}',
370370
]
371371
)
@@ -546,7 +546,7 @@ async def test_logging_middleware_with_payloads(
546546

547547
assert get_log_lines(caplog) == snapshot(
548548
[
549-
'event=request_start method=tools/call source=client payload={"_meta":null,"name":"simple_operation","arguments":{"data":"payload_test"}} payload_type=CallToolRequestParams',
549+
'event=request_start method=tools/call source=client payload={"task":null,"_meta":null,"name":"simple_operation","arguments":{"data":"payload_test"}} payload_type=CallToolRequestParams',
550550
"event=request_success method=tools/call source=client duration_ms=0.02",
551551
]
552552
)
@@ -570,7 +570,7 @@ async def test_structured_logging_middleware_produces_json(
570570

571571
assert get_log_lines(caplog) == snapshot(
572572
[
573-
'{"event": "request_start", "method": "tools/call", "source": "client", "payload": "{\\"_meta\\":null,\\"name\\":\\"simple_operation\\",\\"arguments\\":{\\"data\\":\\"json_test\\"}}", "payload_type": "CallToolRequestParams"}',
573+
'{"event": "request_start", "method": "tools/call", "source": "client", "payload": "{\\"task\\":null,\\"_meta\\":null,\\"name\\":\\"simple_operation\\",\\"arguments\\":{\\"data\\":\\"json_test\\"}}", "payload_type": "CallToolRequestParams"}',
574574
'{"event": "request_success", "method": "tools/call", "source": "client", "duration_ms": 0.02}',
575575
]
576576
)
@@ -665,6 +665,6 @@ async def test_logging_middleware_custom_configuration(
665665
# Check that our custom logger captured the logs
666666
log_output = log_buffer.getvalue()
667667
assert log_output == snapshot("""\
668-
event=request_start method=tools/call source=client payload={"_meta":null,"name":"simple_operation","arguments":{"data":"custom_test"}} payload_type=CallToolRequestParams
668+
event=request_start method=tools/call source=client payload={"task":null,"_meta":null,"name":"simple_operation","arguments":{"data":"custom_test"}} payload_type=CallToolRequestParams
669669
event=request_success method=tools/call source=client duration_ms=0.02
670670
""")

tests/server/test_auth_integration.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -366,18 +366,19 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient):
366366
assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke"
367367
assert metadata["response_types_supported"] == ["code"]
368368
assert metadata["code_challenge_methods_supported"] == ["S256"]
369-
assert metadata["token_endpoint_auth_methods_supported"] == [
370-
"client_secret_post"
371-
]
369+
assert set(metadata["token_endpoint_auth_methods_supported"]) == {
370+
"client_secret_post",
371+
"client_secret_basic",
372+
}
372373
assert metadata["grant_types_supported"] == [
373374
"authorization_code",
374375
"refresh_token",
375376
]
376377
assert metadata["service_documentation"] == "https://docs.example.com/"
377378

378379
async def test_token_validation_error(self, test_client: httpx.AsyncClient):
379-
"""Test token endpoint error - validation error."""
380-
# Missing required fields
380+
"""Test token endpoint error - missing client_id returns auth error."""
381+
# Missing required fields - SDK validates client_id first
381382
response = await test_client.post(
382383
"/token",
383384
data={
@@ -386,10 +387,11 @@ async def test_token_validation_error(self, test_client: httpx.AsyncClient):
386387
},
387388
)
388389
error_response = response.json()
389-
assert error_response["error"] == "invalid_request"
390-
assert (
391-
"error_description" in error_response
392-
) # Contains validation error messages
390+
# SDK validates client_id before other fields, returning unauthorized_client
391+
# (FastMCP's OAuthProxy transforms this to invalid_client, but this test
392+
# uses the SDK's create_auth_routes directly)
393+
assert error_response["error"] == "unauthorized_client"
394+
assert "error_description" in error_response
393395

394396
async def test_token_invalid_auth_code(
395397
self, test_client, registered_client, pkce_challenge

0 commit comments

Comments
 (0)