Skip to content

Commit a51ced2

Browse files
committed
Use classes for handlers
1 parent 26393a7 commit a51ced2

File tree

6 files changed

+87
-95
lines changed

6 files changed

+87
-95
lines changed

src/mcp/server/auth/handlers/authorize.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
"""
66

77
import logging
8-
from typing import Callable, Literal
8+
from dataclasses import dataclass
9+
from typing import Literal
910
from urllib.parse import urlencode, urlparse, urlunparse
1011

1112
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
@@ -117,8 +118,11 @@ class AnyHttpUrlModel(RootModel):
117118
root: AnyHttpUrl
118119

119120

120-
def create_authorization_handler(provider: OAuthServerProvider) -> Callable:
121-
async def authorization_handler(request: Request) -> Response:
121+
@dataclass
122+
class AuthorizationHandler:
123+
provider: OAuthServerProvider
124+
125+
async def handle(self, request: Request) -> Response:
122126
# implements authorization requests for grant_type=code;
123127
# see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
124128

@@ -134,7 +138,7 @@ async def error_response(
134138
if client is None and attempt_load_client:
135139
# make last-ditch attempt to load the client
136140
client_id = best_effort_extract_string("client_id", params)
137-
client = client_id and await provider.clients_store.get_client(
141+
client = client_id and await self.provider.clients_store.get_client(
138142
client_id
139143
)
140144
if redirect_uri is None and client:
@@ -200,7 +204,9 @@ async def error_response(
200204
)
201205

202206
# Get client information
203-
client = await provider.clients_store.get_client(auth_request.client_id)
207+
client = await self.provider.clients_store.get_client(
208+
auth_request.client_id,
209+
)
204210
if not client:
205211
# For client_id validation errors, return direct error (no redirect)
206212
return await error_response(
@@ -241,7 +247,10 @@ async def error_response(
241247
response = RedirectResponse(
242248
url="", status_code=302, headers={"Cache-Control": "no-store"}
243249
)
244-
response.headers["location"] = await provider.authorize(client, auth_params)
250+
response.headers["location"] = await self.provider.authorize(
251+
client,
252+
auth_params,
253+
)
245254
return response
246255

247256
except Exception as validation_error:
@@ -253,7 +262,6 @@ async def error_response(
253262
error="server_error", error_description="An unexpected error occurred"
254263
)
255264

256-
return authorization_handler
257265

258266

259267
def create_error_redirect(

src/mcp/server/auth/handlers/metadata.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,41 +4,22 @@
44
Corresponds to TypeScript file: src/server/auth/handlers/metadata.ts
55
"""
66

7-
from typing import Any, Callable
7+
from dataclasses import dataclass
8+
from typing import Any
89

910
from starlette.requests import Request
1011
from starlette.responses import JSONResponse, Response
1112

1213

13-
def create_metadata_handler(metadata: dict[str, Any]) -> Callable:
14-
"""
15-
Create a handler for OAuth 2.0 Authorization Server Metadata.
14+
@dataclass
15+
class MetadataHandler:
16+
metadata: dict[str, Any]
1617

17-
Corresponds to metadataHandler in src/server/auth/handlers/metadata.ts
18-
19-
Args:
20-
metadata: The metadata to return in the response
21-
22-
Returns:
23-
A Starlette endpoint handler function
24-
"""
25-
26-
async def metadata_handler(request: Request) -> Response:
27-
"""
28-
Handler for the OAuth 2.0 Authorization Server Metadata endpoint.
29-
30-
Args:
31-
request: The Starlette request
32-
33-
Returns:
34-
JSON response with the authorization server metadata
35-
"""
18+
async def handle(self, request: Request) -> Response:
3619
# Remove any None values from metadata
37-
clean_metadata = {k: v for k, v in metadata.items() if v is not None}
20+
clean_metadata = {k: v for k, v in self.metadata.items() if v is not None}
3821

3922
return JSONResponse(
4023
content=clean_metadata,
4124
headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour
4225
)
43-
44-
return metadata_handler

src/mcp/server/auth/handlers/register.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77
import secrets
88
import time
9-
from typing import Callable, Literal
9+
from dataclasses import dataclass
10+
from typing import Literal
1011
from uuid import uuid4
1112

1213
from pydantic import BaseModel, ValidationError
@@ -29,10 +30,11 @@ class ErrorResponse(BaseModel):
2930
error_description: str
3031

3132

32-
def create_registration_handler(
33-
clients_store: OAuthRegisteredClientsStore, client_secret_expiry_seconds: int | None
34-
) -> Callable:
35-
async def registration_handler(request: Request) -> Response:
33+
@dataclass
34+
class RegistrationHandler:
35+
clients_store: OAuthRegisteredClientsStore
36+
client_secret_expiry_seconds: int | None
37+
async def handle(self, request: Request) -> Response:
3638
# Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1
3739
try:
3840
# Parse request body as JSON
@@ -55,8 +57,8 @@ async def registration_handler(request: Request) -> Response:
5557

5658
client_id_issued_at = int(time.time())
5759
client_secret_expires_at = (
58-
client_id_issued_at + client_secret_expiry_seconds
59-
if client_secret_expiry_seconds is not None
60+
client_id_issued_at + self.client_secret_expiry_seconds
61+
if self.client_secret_expiry_seconds is not None
6062
else None
6163
)
6264

@@ -83,9 +85,7 @@ async def registration_handler(request: Request) -> Response:
8385
software_version=client_metadata.software_version,
8486
)
8587
# Register client
86-
await clients_store.register_client(client_info)
88+
await self.clients_store.register_client(client_info)
8789

8890
# Return client information
8991
return PydanticJSONResponse(content=client_info, status_code=201)
90-
91-
return registration_handler

src/mcp/server/auth/handlers/revoke.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Corresponds to TypeScript file: src/server/auth/handlers/revoke.ts
55
"""
66

7-
from typing import Callable
7+
from dataclasses import dataclass
88

99
from pydantic import ValidationError
1010
from starlette.requests import Request
@@ -27,10 +27,12 @@ class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest):
2727
pass
2828

2929

30-
def create_revocation_handler(
31-
provider: OAuthServerProvider, client_authenticator: ClientAuthenticator
32-
) -> Callable:
33-
async def revocation_handler(request: Request) -> Response:
30+
@dataclass
31+
class RevocationHandler:
32+
provider: OAuthServerProvider
33+
client_authenticator: ClientAuthenticator
34+
35+
async def handle(self, request: Request) -> Response:
3436
"""
3537
Handler for the OAuth 2.0 Token Revocation endpoint.
3638
"""
@@ -48,13 +50,13 @@ async def revocation_handler(request: Request) -> Response:
4850

4951
# Authenticate client
5052
try:
51-
client_auth_result = await client_authenticator(revocation_request)
53+
client_auth_result = await self.client_authenticator(revocation_request)
5254
except InvalidClientError as e:
5355
return PydanticJSONResponse(status_code=401, content=e.error_response())
5456

5557
# Revoke token
56-
if provider.revoke_token:
57-
await provider.revoke_token(client_auth_result, revocation_request)
58+
if self.provider.revoke_token:
59+
await self.provider.revoke_token(client_auth_result, revocation_request)
5860

5961
# Return successful empty response
6062
return Response(
@@ -64,5 +66,3 @@ async def revocation_handler(request: Request) -> Response:
6466
"Pragma": "no-cache",
6567
},
6668
)
67-
68-
return revocation_handler

src/mcp/server/auth/handlers/token.py

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
import base64
88
import hashlib
99
import time
10-
from typing import Annotated, Callable, Literal
10+
from dataclasses import dataclass
11+
from typing import Annotated, Literal
1112

1213
from pydantic import AnyHttpUrl, Field, RootModel, ValidationError
1314
from starlette.requests import Request
@@ -52,10 +53,12 @@ class TokenRequest(RootModel):
5253
]
5354

5455

55-
def create_token_handler(
56-
provider: OAuthServerProvider, client_authenticator: ClientAuthenticator
57-
) -> Callable:
58-
def response(obj: TokenSuccessResponse | TokenErrorResponse | ErrorResponse):
56+
@dataclass
57+
class TokenHandler:
58+
provider: OAuthServerProvider
59+
client_authenticator: ClientAuthenticator
60+
61+
def response(self, obj: TokenSuccessResponse | TokenErrorResponse | ErrorResponse):
5962
status_code = 200
6063
if isinstance(obj, TokenErrorResponse):
6164
status_code = 400
@@ -69,25 +72,25 @@ def response(obj: TokenSuccessResponse | TokenErrorResponse | ErrorResponse):
6972
},
7073
)
7174

72-
async def token_handler(request: Request):
75+
async def handle(self, request: Request):
7376
try:
7477
form_data = await request.form()
7578
token_request = TokenRequest.model_validate(dict(form_data)).root
7679
except ValidationError as validation_error:
77-
return response(
80+
return self.response(
7881
TokenErrorResponse(
7982
error="invalid_request",
8083
error_description=stringify_pydantic_error(validation_error),
8184
)
8285
)
8386

8487
try:
85-
client_info = await client_authenticator(token_request)
88+
client_info = await self.client_authenticator(token_request)
8689
except InvalidClientError as e:
87-
return response(e.error_response())
90+
return self.response(e.error_response())
8891

8992
if token_request.grant_type not in client_info.grant_types:
90-
return response(
93+
return self.response(
9194
TokenErrorResponse(
9295
error="unsupported_grant_type",
9396
error_description=(
@@ -101,12 +104,12 @@ async def token_handler(request: Request):
101104

102105
match token_request:
103106
case AuthorizationCodeRequest():
104-
auth_code = await provider.load_authorization_code(
107+
auth_code = await self.provider.load_authorization_code(
105108
client_info, token_request.code
106109
)
107110
if auth_code is None or auth_code.client_id != token_request.client_id:
108111
# if code belongs to different client, pretend it doesn't exist
109-
return response(
112+
return self.response(
110113
TokenErrorResponse(
111114
error="invalid_grant",
112115
error_description="authorization code does not exist",
@@ -116,7 +119,7 @@ async def token_handler(request: Request):
116119
# make auth codes expire after a deadline
117120
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
118121
if auth_code.expires_at < time.time():
119-
return response(
122+
return self.response(
120123
TokenErrorResponse(
121124
error="invalid_grant",
122125
error_description="authorization code has expired",
@@ -126,7 +129,7 @@ async def token_handler(request: Request):
126129
# verify redirect_uri doesn't change between /authorize and /tokens
127130
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
128131
if token_request.redirect_uri != auth_code.redirect_uri:
129-
return response(
132+
return self.response(
130133
TokenErrorResponse(
131134
error="invalid_request",
132135
error_description=(
@@ -144,28 +147,28 @@ async def token_handler(request: Request):
144147

145148
if hashed_code_verifier != auth_code.code_challenge:
146149
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
147-
return response(
150+
return self.response(
148151
TokenErrorResponse(
149152
error="invalid_grant",
150153
error_description="incorrect code_verifier",
151154
)
152155
)
153156

154157
# Exchange authorization code for tokens
155-
tokens = await provider.exchange_authorization_code(
158+
tokens = await self.provider.exchange_authorization_code(
156159
client_info, auth_code
157160
)
158161

159162
case RefreshTokenRequest():
160-
refresh_token = await provider.load_refresh_token(
163+
refresh_token = await self.provider.load_refresh_token(
161164
client_info, token_request.refresh_token
162165
)
163166
if (
164167
refresh_token is None
165168
or refresh_token.client_id != token_request.client_id
166169
):
167170
# if token belongs to different client, pretend it doesn't exist
168-
return response(
171+
return self.response(
169172
TokenErrorResponse(
170173
error="invalid_grant",
171174
error_description="refresh token does not exist",
@@ -174,7 +177,7 @@ async def token_handler(request: Request):
174177

175178
if refresh_token.expires_at and refresh_token.expires_at < time.time():
176179
# if the refresh token has expired, pretend it doesn't exist
177-
return response(
180+
return self.response(
178181
TokenErrorResponse(
179182
error="invalid_grant",
180183
error_description="refresh token has expired",
@@ -190,20 +193,19 @@ async def token_handler(request: Request):
190193

191194
for scope in scopes:
192195
if scope not in refresh_token.scopes:
193-
return response(
196+
return self.response(
194197
TokenErrorResponse(
195198
error="invalid_scope",
196199
error_description=(
197-
f"cannot request scope `{scope}` not provided by refresh token"
198-
),
200+
f"cannot request scope `{scope}` "
201+
"not provided by refresh token"
202+
),
199203
)
200204
)
201205

202206
# Exchange refresh token for new tokens
203-
tokens = await provider.exchange_refresh_token(
207+
tokens = await self.provider.exchange_refresh_token(
204208
client_info, refresh_token, scopes
205209
)
206210

207-
return response(tokens)
208-
209-
return token_handler
211+
return self.response(tokens)

0 commit comments

Comments
 (0)