Skip to content

Commit 1544000

Browse files
committed
Add client_secret_basic authentication support
Add support for HTTP Basic Authentication (client_secret_basic) as a client authentication method for the token and revoke endpoints, alongside the existing client_secret_post method. This improves compatibility with OAuth servers like Keycloak that use Basic auth. Key changes: - Update OAuthClientMetadata to accept "client_secret_basic" as valid token_endpoint_auth_method - Return 401 status for authentication failures (was 400) - Update metadata endpoints to advertise both auth methods - Add tests for both auth methods and edge cases
1 parent 47d35f0 commit 1544000

File tree

7 files changed

+235
-46
lines changed

7 files changed

+235
-46
lines changed

src/mcp/server/auth/handlers/revoke.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,28 +40,25 @@ async def handle(self, request: Request) -> Response:
4040
Handler for the OAuth 2.0 Token Revocation endpoint.
4141
"""
4242
try:
43-
form_data = await request.form()
44-
revocation_request = RevocationRequest.model_validate(dict(form_data))
45-
except ValidationError as e:
43+
client = await self.client_authenticator.authenticate_request(request)
44+
except AuthenticationError as e:
4645
return PydanticJSONResponse(
47-
status_code=400,
46+
status_code=401,
4847
content=RevocationErrorResponse(
49-
error="invalid_request",
50-
error_description=stringify_pydantic_error(e),
48+
error="unauthorized_client",
49+
error_description=e.message,
5150
),
5251
)
5352

54-
# Authenticate client
5553
try:
56-
client = await self.client_authenticator.authenticate(
57-
revocation_request.client_id, revocation_request.client_secret
58-
)
59-
except AuthenticationError as e:
54+
form_data = await request.form()
55+
revocation_request = RevocationRequest.model_validate(dict(form_data))
56+
except ValidationError as e:
6057
return PydanticJSONResponse(
61-
status_code=401,
58+
status_code=400,
6259
content=RevocationErrorResponse(
63-
error="unauthorized_client",
64-
error_description=e.message,
60+
error="invalid_request",
61+
error_description=stringify_pydantic_error(e),
6562
),
6663
)
6764

src/mcp/server/auth/handlers/token.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,22 @@ def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
9191
)
9292

9393
async def handle(self, request: Request):
94+
try:
95+
client_info = await self.client_authenticator.authenticate_request(request)
96+
except AuthenticationError as e:
97+
# Authentication failures should return 401
98+
return PydanticJSONResponse(
99+
content=TokenErrorResponse(
100+
error="unauthorized_client",
101+
error_description=e.message,
102+
),
103+
status_code=401,
104+
headers={
105+
"Cache-Control": "no-store",
106+
"Pragma": "no-cache",
107+
},
108+
)
109+
94110
try:
95111
form_data = await request.form()
96112
token_request = TokenRequest.model_validate(dict(form_data)).root
@@ -102,19 +118,6 @@ async def handle(self, request: Request):
102118
)
103119
)
104120

105-
try:
106-
client_info = await self.client_authenticator.authenticate(
107-
client_id=token_request.client_id,
108-
client_secret=token_request.client_secret,
109-
)
110-
except AuthenticationError as e:
111-
return self.response(
112-
TokenErrorResponse(
113-
error="unauthorized_client",
114-
error_description=e.message,
115-
)
116-
)
117-
118121
if token_request.grant_type not in client_info.grant_types:
119122
return self.response(
120123
TokenErrorResponse(

src/mcp/server/auth/middleware/client_auth.py

Lines changed: 58 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import base64
12
import time
23
from typing import Any
34

5+
from starlette.requests import Request
6+
47
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
58
from mcp.shared.auth import OAuthClientInformationFull
69

@@ -30,19 +33,67 @@ def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
3033
"""
3134
self.provider = provider
3235

33-
async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull:
34-
# Look up client information
35-
client = await self.provider.get_client(client_id)
36+
async def authenticate_request(self, request: Request) -> OAuthClientInformationFull:
37+
"""
38+
Authenticate a client from an HTTP request.
39+
40+
Extracts client credentials from the appropriate location based on the
41+
client's registered authentication method and validates them.
42+
43+
Args:
44+
request: The HTTP request containing client credentials
45+
46+
Returns:
47+
The authenticated client information
48+
49+
Raises:
50+
AuthenticationError: If authentication fails
51+
"""
52+
form_data = await request.form()
53+
client_id = form_data.get("client_id")
54+
if not client_id:
55+
raise AuthenticationError("Missing client_id")
56+
57+
client = await self.provider.get_client(str(client_id))
3658
if not client:
3759
raise AuthenticationError("Invalid client_id")
3860

39-
# If client from the store expects a secret, validate that the request provides
40-
# that secret
61+
request_client_secret = None
62+
auth_header = request.headers.get("Authorization", "")
63+
64+
if client.token_endpoint_auth_method == "client_secret_basic":
65+
if not auth_header.startswith("Basic "):
66+
raise AuthenticationError("Missing or invalid Basic authentication in Authorization header")
67+
68+
try:
69+
encoded_credentials = auth_header[6:] # Remove "Basic " prefix
70+
decoded = base64.b64decode(encoded_credentials).decode("utf-8")
71+
if ":" not in decoded:
72+
raise ValueError("Invalid Basic auth format")
73+
basic_client_id, request_client_secret = decoded.split(":", 1)
74+
75+
if basic_client_id != client_id:
76+
raise AuthenticationError("Client ID mismatch in Basic auth")
77+
except AuthenticationError:
78+
raise
79+
except Exception:
80+
raise AuthenticationError("Invalid Basic authentication header")
81+
82+
elif client.token_endpoint_auth_method == "client_secret_post":
83+
request_client_secret = form_data.get("client_secret")
84+
if request_client_secret:
85+
request_client_secret = str(request_client_secret)
86+
87+
elif client.token_endpoint_auth_method == "none":
88+
request_client_secret = None
89+
else:
90+
raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}")
91+
4192
if client.client_secret:
42-
if not client_secret:
93+
if not request_client_secret:
4394
raise AuthenticationError("Client secret is required")
4495

45-
if client.client_secret != client_secret:
96+
if client.client_secret != request_client_secret:
4697
raise AuthenticationError("Invalid client_secret")
4798

4899
if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):

src/mcp/server/auth/routes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def build_metadata(
164164
response_types_supported=["code"],
165165
response_modes_supported=None,
166166
grant_types_supported=["authorization_code", "refresh_token"],
167-
token_endpoint_auth_methods_supported=["client_secret_post"],
167+
token_endpoint_auth_methods_supported=["client_secret_post", "client_secret_basic"],
168168
token_endpoint_auth_signing_alg_values_supported=None,
169169
service_documentation=service_documentation_url,
170170
ui_locales_supported=None,
@@ -181,7 +181,7 @@ def build_metadata(
181181
# Add revocation endpoint if supported
182182
if revocation_options.enabled:
183183
metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH)
184-
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"]
184+
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post", "client_secret_basic"]
185185

186186
return metadata
187187

src/mcp/shared/auth.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,7 @@ class OAuthClientMetadata(BaseModel):
4242
"""
4343

4444
redirect_uris: list[AnyUrl] = Field(..., min_length=1)
45-
# token_endpoint_auth_method: this implementation only supports none &
46-
# client_secret_post;
47-
# ie: we do not support client_secret_basic
48-
token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post"
45+
token_endpoint_auth_method: Literal["none", "client_secret_post", "client_secret_basic"] = "client_secret_post"
4946
# grant_types: this implementation only supports authorization_code & refresh_token
5047
grant_types: list[Literal["authorization_code", "refresh_token"]] = [
5148
"authorization_code",

tests/client/test_auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -832,10 +832,10 @@ def test_build_metadata(
832832
"registration_endpoint": Is(registration_endpoint),
833833
"scopes_supported": ["read", "write", "admin"],
834834
"grant_types_supported": ["authorization_code", "refresh_token"],
835-
"token_endpoint_auth_methods_supported": ["client_secret_post"],
835+
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
836836
"service_documentation": Is(service_documentation_url),
837837
"revocation_endpoint": Is(revocation_endpoint),
838-
"revocation_endpoint_auth_methods_supported": ["client_secret_post"],
838+
"revocation_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
839839
"code_challenge_methods_supported": ["S256"],
840840
}
841841
)

tests/server/fastmcp/auth/test_auth_integration.py

Lines changed: 145 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
import httpx
1414
import pytest
15-
from pydantic import AnyHttpUrl
15+
from pydantic import AnyHttpUrl, AnyUrl
1616
from starlette.applications import Starlette
1717

1818
from mcp.server.auth.provider import (
@@ -357,7 +357,7 @@ async def test_metadata_endpoint(self, test_client: httpx.AsyncClient):
357357
assert metadata["revocation_endpoint"] == "https://auth.example.com/revoke"
358358
assert metadata["response_types_supported"] == ["code"]
359359
assert metadata["code_challenge_methods_supported"] == ["S256"]
360-
assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post"]
360+
assert metadata["token_endpoint_auth_methods_supported"] == ["client_secret_post", "client_secret_basic"]
361361
assert metadata["grant_types_supported"] == [
362362
"authorization_code",
363363
"refresh_token",
@@ -376,8 +376,8 @@ async def test_token_validation_error(self, test_client: httpx.AsyncClient):
376376
},
377377
)
378378
error_response = response.json()
379-
assert error_response["error"] == "invalid_request"
380-
assert "error_description" in error_response # Contains validation error messages
379+
assert error_response["error"] == "unauthorized_client"
380+
assert "error_description" in error_response # Contains error message
381381

382382
@pytest.mark.anyio
383383
async def test_token_invalid_auth_code(
@@ -942,6 +942,147 @@ async def test_client_registration_invalid_grant_type(self, test_client: httpx.A
942942
assert error_data["error"] == "invalid_client_metadata"
943943
assert error_data["error_description"] == "grant_types must be authorization_code and refresh_token"
944944

945+
@pytest.mark.anyio
946+
async def test_client_secret_basic_authentication(
947+
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
948+
):
949+
"""Test that client_secret_basic authentication works correctly."""
950+
client_metadata = {
951+
"redirect_uris": ["https://client.example.com/callback"],
952+
"client_name": "Basic Auth Client",
953+
"token_endpoint_auth_method": "client_secret_basic",
954+
"grant_types": ["authorization_code", "refresh_token"],
955+
}
956+
957+
response = await test_client.post("/register", json=client_metadata)
958+
assert response.status_code == 201
959+
client_info = response.json()
960+
assert client_info["token_endpoint_auth_method"] == "client_secret_basic"
961+
962+
auth_code = f"code_{int(time.time())}"
963+
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
964+
code=auth_code,
965+
client_id=client_info["client_id"],
966+
code_challenge=pkce_challenge["code_challenge"],
967+
redirect_uri=AnyUrl("https://client.example.com/callback"),
968+
redirect_uri_provided_explicitly=True,
969+
scopes=["read", "write"],
970+
expires_at=time.time() + 600,
971+
)
972+
973+
credentials = f"{client_info['client_id']}:{client_info['client_secret']}"
974+
encoded_credentials = base64.b64encode(credentials.encode()).decode()
975+
976+
response = await test_client.post(
977+
"/token",
978+
headers={"Authorization": f"Basic {encoded_credentials}"},
979+
data={
980+
"grant_type": "authorization_code",
981+
"client_id": client_info["client_id"],
982+
"code": auth_code,
983+
"code_verifier": pkce_challenge["code_verifier"],
984+
"redirect_uri": "https://client.example.com/callback",
985+
},
986+
)
987+
assert response.status_code == 200
988+
token_response = response.json()
989+
assert "access_token" in token_response
990+
991+
@pytest.mark.anyio
992+
async def test_wrong_auth_method_without_valid_credentials_fails(
993+
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
994+
):
995+
"""Test that using the wrong authentication method fails when credentials are missing."""
996+
client_metadata = {
997+
"redirect_uris": ["https://client.example.com/callback"],
998+
"client_name": "Post Auth Client",
999+
"token_endpoint_auth_method": "client_secret_post",
1000+
"grant_types": ["authorization_code", "refresh_token"],
1001+
}
1002+
1003+
response = await test_client.post("/register", json=client_metadata)
1004+
assert response.status_code == 201
1005+
client_info = response.json()
1006+
assert client_info["token_endpoint_auth_method"] == "client_secret_post"
1007+
1008+
auth_code = f"code_{int(time.time())}"
1009+
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
1010+
code=auth_code,
1011+
client_id=client_info["client_id"],
1012+
code_challenge=pkce_challenge["code_challenge"],
1013+
redirect_uri=AnyUrl("https://client.example.com/callback"),
1014+
redirect_uri_provided_explicitly=True,
1015+
scopes=["read", "write"],
1016+
expires_at=time.time() + 600,
1017+
)
1018+
1019+
# Try to use Basic auth when client_secret_post is registered (without secret in body)
1020+
# This should fail because the secret is missing from the expected location
1021+
1022+
credentials = f"{client_info['client_id']}:{client_info['client_secret']}"
1023+
encoded_credentials = base64.b64encode(credentials.encode()).decode()
1024+
1025+
response = await test_client.post(
1026+
"/token",
1027+
headers={"Authorization": f"Basic {encoded_credentials}"},
1028+
data={
1029+
"grant_type": "authorization_code",
1030+
"client_id": client_info["client_id"],
1031+
# client_secret NOT in body where it should be
1032+
"code": auth_code,
1033+
"code_verifier": pkce_challenge["code_verifier"],
1034+
"redirect_uri": "https://client.example.com/callback",
1035+
},
1036+
)
1037+
assert response.status_code == 401
1038+
error_response = response.json()
1039+
assert error_response["error"] == "unauthorized_client"
1040+
assert "Client secret is required" in error_response["error_description"]
1041+
1042+
@pytest.mark.anyio
1043+
async def test_basic_auth_without_header_fails(
1044+
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
1045+
):
1046+
"""Test that omitting Basic auth when client_secret_basic is registered fails."""
1047+
client_metadata = {
1048+
"redirect_uris": ["https://client.example.com/callback"],
1049+
"client_name": "Basic Auth Client",
1050+
"token_endpoint_auth_method": "client_secret_basic",
1051+
"grant_types": ["authorization_code", "refresh_token"],
1052+
}
1053+
1054+
response = await test_client.post("/register", json=client_metadata)
1055+
assert response.status_code == 201
1056+
client_info = response.json()
1057+
assert client_info["token_endpoint_auth_method"] == "client_secret_basic"
1058+
1059+
auth_code = f"code_{int(time.time())}"
1060+
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
1061+
code=auth_code,
1062+
client_id=client_info["client_id"],
1063+
code_challenge=pkce_challenge["code_challenge"],
1064+
redirect_uri=AnyUrl("https://client.example.com/callback"),
1065+
redirect_uri_provided_explicitly=True,
1066+
scopes=["read", "write"],
1067+
expires_at=time.time() + 600,
1068+
)
1069+
1070+
response = await test_client.post(
1071+
"/token",
1072+
data={
1073+
"grant_type": "authorization_code",
1074+
"client_id": client_info["client_id"],
1075+
"client_secret": client_info["client_secret"], # Secret in body (ignored)
1076+
"code": auth_code,
1077+
"code_verifier": pkce_challenge["code_verifier"],
1078+
"redirect_uri": "https://client.example.com/callback",
1079+
},
1080+
)
1081+
assert response.status_code == 401
1082+
error_response = response.json()
1083+
assert error_response["error"] == "unauthorized_client"
1084+
assert "Missing or invalid Basic authentication" in error_response["error_description"]
1085+
9451086

9461087
class TestAuthorizeEndpointErrors:
9471088
"""Test error handling in the OAuth authorization endpoint."""

0 commit comments

Comments
 (0)