Skip to content

Commit e88b4aa

Browse files
committed
Move around the response models to be closer to the handlers
1 parent 1202029 commit e88b4aa

File tree

7 files changed

+103
-104
lines changed

7 files changed

+103
-104
lines changed

src/mcp/server/auth/handlers/authorize.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,25 @@ class AuthorizationRequest(BaseModel):
5353
)
5454

5555

56+
AuthorizationErrorCode = Literal[
57+
"invalid_request",
58+
"unauthorized_client",
59+
"access_denied",
60+
"unsupported_response_type",
61+
"invalid_scope",
62+
"server_error",
63+
"temporarily_unavailable",
64+
]
65+
66+
67+
class AuthorizationErrorResponse(BaseModel):
68+
error: AuthorizationErrorCode
69+
error_description: str
70+
error_uri: AnyUrl | None = None
71+
# must be set if provided in the request
72+
state: str | None = None
73+
74+
5675
def validate_scope(
5776
requested_scope: str | None, client: OAuthClientInformationFull
5877
) -> list[str] | None:
@@ -84,25 +103,6 @@ def validate_redirect_uri(
84103
)
85104

86105

87-
ErrorCode = Literal[
88-
"invalid_request",
89-
"unauthorized_client",
90-
"access_denied",
91-
"unsupported_response_type",
92-
"invalid_scope",
93-
"server_error",
94-
"temporarily_unavailable",
95-
]
96-
97-
98-
class ErrorResponse(BaseModel):
99-
error: ErrorCode
100-
error_description: str
101-
error_uri: AnyUrl | None = None
102-
# must be set if provided in the request
103-
state: str | None = None
104-
105-
106106
def best_effort_extract_string(
107107
key: str, params: None | FormData | QueryParams
108108
) -> str | None:
@@ -132,7 +132,9 @@ async def handle(self, request: Request) -> Response:
132132
params = None
133133

134134
async def error_response(
135-
error: ErrorCode, error_description: str, attempt_load_client: bool = True
135+
error: AuthorizationErrorCode,
136+
error_description: str,
137+
attempt_load_client: bool = True,
136138
):
137139
nonlocal client, redirect_uri, state
138140
if client is None and attempt_load_client:
@@ -157,7 +159,7 @@ async def error_response(
157159
# make last-ditch effort to load state
158160
state = best_effort_extract_string("state", params)
159161

160-
error_resp = ErrorResponse(
162+
error_resp = AuthorizationErrorResponse(
161163
error=error,
162164
error_description=error_description,
163165
state=state,
@@ -194,7 +196,7 @@ async def error_response(
194196
auth_request = AuthorizationRequest.model_validate(params)
195197
state = auth_request.state # Update with validated state
196198
except ValidationError as validation_error:
197-
error: ErrorCode = "invalid_request"
199+
error: AuthorizationErrorCode = "invalid_request"
198200
for e in validation_error.errors():
199201
if e["loc"] == ("response_type",) and e["type"] == "literal_error":
200202
error = "unsupported_response_type"
@@ -264,11 +266,11 @@ async def error_response(
264266

265267

266268
def create_error_redirect(
267-
redirect_uri: AnyUrl, error: Exception | ErrorResponse
269+
redirect_uri: AnyUrl, error: Exception | AuthorizationErrorResponse
268270
) -> str:
269271
parsed_uri = urlparse(str(redirect_uri))
270272

271-
if isinstance(error, ErrorResponse):
273+
if isinstance(error, AuthorizationErrorResponse):
272274
# Convert ErrorResponse to dict
273275
error_dict = error.model_dump(exclude_none=True)
274276
query_params = {}

src/mcp/server/auth/handlers/register.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Literal
1111
from uuid import uuid4
1212

13-
from pydantic import BaseModel, ValidationError
13+
from pydantic import BaseModel, RootModel, ValidationError
1414
from starlette.requests import Request
1515
from starlette.responses import Response
1616

@@ -20,7 +20,13 @@
2020
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
2121

2222

23-
class ErrorResponse(BaseModel):
23+
class RegistrationRequest(RootModel):
24+
# this wrapper is a no-op; it's just to separate out the types exposed to the
25+
# provider from what we use in the HTTP handler
26+
root: OAuthClientMetadata
27+
28+
29+
class RegistrationErrorResponse(BaseModel):
2430
error: Literal[
2531
"invalid_redirect_uri",
2632
"invalid_client_metadata",
@@ -43,7 +49,7 @@ async def handle(self, request: Request) -> Response:
4349
client_metadata = OAuthClientMetadata.model_validate(body)
4450
except ValidationError as validation_error:
4551
return PydanticJSONResponse(
46-
content=ErrorResponse(
52+
content=RegistrationErrorResponse(
4753
error="invalid_client_metadata",
4854
error_description=stringify_pydantic_error(validation_error),
4955
),

src/mcp/server/auth/handlers/revoke.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,34 @@
55
"""
66

77
from dataclasses import dataclass
8+
from typing import Literal
89

9-
from pydantic import ValidationError
10+
from pydantic import BaseModel, ValidationError
1011
from starlette.requests import Request
1112
from starlette.responses import Response
1213

1314
from mcp.server.auth.errors import (
14-
InvalidClientError,
1515
stringify_pydantic_error,
1616
)
1717
from mcp.server.auth.json_response import PydanticJSONResponse
1818
from mcp.server.auth.middleware.client_auth import (
1919
ClientAuthenticator,
20-
ClientAuthRequest,
2120
)
22-
from mcp.server.auth.provider import OAuthServerProvider, OAuthTokenRevocationRequest
23-
from mcp.shared.auth import TokenErrorResponse
21+
from mcp.server.auth.provider import OAuthServerProvider
2422

2523

26-
class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest):
27-
pass
24+
class RevocationRequest(BaseModel):
25+
"""
26+
# See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
27+
"""
28+
29+
token: str
30+
token_type_hint: Literal["access_token", "refresh_token"] | None = None
31+
32+
33+
class RevocationErrorResponse(BaseModel):
34+
error: Literal["invalid_request",]
35+
error_description: str | None = None
2836

2937

3038
@dataclass
@@ -42,21 +50,16 @@ async def handle(self, request: Request) -> Response:
4250
except ValidationError as e:
4351
return PydanticJSONResponse(
4452
status_code=400,
45-
content=TokenErrorResponse(
53+
content=RevocationErrorResponse(
4654
error="invalid_request",
4755
error_description=stringify_pydantic_error(e),
4856
),
4957
)
5058

51-
# Authenticate client
52-
try:
53-
client_auth_result = await self.client_authenticator(revocation_request)
54-
except InvalidClientError as e:
55-
return PydanticJSONResponse(status_code=401, content=e.error_response())
56-
5759
# Revoke token
58-
if self.provider.revoke_token:
59-
await self.provider.revoke_token(client_auth_result, revocation_request)
60+
await self.provider.revoke_token(
61+
revocation_request.token, revocation_request.token_type_hint
62+
)
6063

6164
# Return successful empty response
6265
return Response(

src/mcp/server/auth/handlers/token.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from dataclasses import dataclass
1111
from typing import Annotated, Literal
1212

13-
from pydantic import AnyHttpUrl, Field, RootModel, ValidationError
13+
from pydantic import AnyHttpUrl, BaseModel, Field, RootModel, ValidationError
1414
from starlette.requests import Request
1515

1616
from mcp.server.auth.errors import (
@@ -24,7 +24,7 @@
2424
ClientAuthRequest,
2525
)
2626
from mcp.server.auth.provider import OAuthServerProvider
27-
from mcp.shared.auth import TokenErrorResponse, TokenSuccessResponse
27+
from mcp.shared.auth import OAuthToken
2828

2929

3030
class AuthorizationCodeRequest(ClientAuthRequest):
@@ -53,6 +53,30 @@ class TokenRequest(RootModel):
5353
]
5454

5555

56+
class TokenErrorResponse(BaseModel):
57+
"""
58+
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
59+
"""
60+
61+
error: Literal[
62+
"invalid_request",
63+
"invalid_client",
64+
"invalid_grant",
65+
"unauthorized_client",
66+
"unsupported_grant_type",
67+
"invalid_scope",
68+
]
69+
error_description: str | None = None
70+
error_uri: AnyHttpUrl | None = None
71+
72+
73+
class TokenSuccessResponse(RootModel):
74+
# this is just a wrapper over OAuthToken; the only reason we do this
75+
# is to have some separation between the HTTP response type, and the
76+
# type returned by the provider
77+
root: OAuthToken
78+
79+
5680
@dataclass
5781
class TokenHandler:
5882
provider: OAuthServerProvider
@@ -100,7 +124,7 @@ async def handle(self, request: Request):
100124
)
101125
)
102126

103-
tokens: TokenSuccessResponse
127+
tokens: OAuthToken
104128

105129
match token_request:
106130
case AuthorizationCodeRequest():
@@ -208,4 +232,4 @@ async def handle(self, request: Request):
208232
client_info, refresh_token, scopes
209233
)
210234

211-
return self.response(tokens)
235+
return self.response(TokenSuccessResponse(root=tokens))

src/mcp/server/auth/provider.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from mcp.server.auth.types import AuthInfo
1313
from mcp.shared.auth import (
1414
OAuthClientInformationFull,
15-
TokenSuccessResponse,
15+
OAuthToken,
1616
)
1717

1818

@@ -45,15 +45,6 @@ class RefreshToken(BaseModel):
4545
expires_at: int | None = None
4646

4747

48-
class OAuthTokenRevocationRequest(BaseModel):
49-
"""
50-
# See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
51-
"""
52-
53-
token: str
54-
token_type_hint: Literal["access_token", "refresh_token"] | None = None
55-
56-
5748
class OAuthRegisteredClientsStore(Protocol):
5849
"""
5950
Interface for storing and retrieving registered OAuth clients.
@@ -149,7 +140,7 @@ async def load_authorization_code(
149140

150141
async def exchange_authorization_code(
151142
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
152-
) -> TokenSuccessResponse:
143+
) -> OAuthToken:
153144
"""
154145
Exchanges an authorization code for an access token.
155146
@@ -171,7 +162,7 @@ async def exchange_refresh_token(
171162
client: OAuthClientInformationFull,
172163
refresh_token: RefreshToken,
173164
scopes: list[str],
174-
) -> TokenSuccessResponse:
165+
) -> OAuthToken:
175166
"""
176167
Exchanges a refresh token for an access token.
177168
@@ -198,16 +189,20 @@ async def load_access_token(self, token: str) -> AuthInfo | None:
198189
...
199190

200191
async def revoke_token(
201-
self, client: OAuthClientInformationFull, request: OAuthTokenRevocationRequest
192+
self,
193+
token: str,
194+
token_type_hint: Literal["access_token", "refresh_token"] | None = None,
202195
) -> None:
203196
"""
204197
Revokes an access or refresh token.
205198
206199
If the given token is invalid or already revoked, this method should do nothing.
207200
208201
Args:
209-
client: The client revoking the token.
210-
request: The token revocation request.
202+
token: the token to revoke
203+
token_type_hint: hint about the type of token to revoke; optional. if the
204+
token cannot be located using this hint, the provider MUST extend its search
205+
to include all tokens.
211206
"""
212207
...
213208

src/mcp/shared/auth.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,7 @@
99
from pydantic import AnyHttpUrl, BaseModel, Field
1010

1111

12-
class TokenErrorResponse(BaseModel):
13-
"""
14-
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
15-
"""
16-
17-
error: Literal[
18-
"invalid_request",
19-
"invalid_client",
20-
"invalid_grant",
21-
"unauthorized_client",
22-
"unsupported_grant_type",
23-
"invalid_scope",
24-
]
25-
error_description: Optional[str] = None
26-
error_uri: Optional[AnyHttpUrl] = None
27-
28-
29-
class TokenSuccessResponse(BaseModel):
12+
class OAuthToken(BaseModel):
3013
"""
3114
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
3215
"""

0 commit comments

Comments
 (0)