Skip to content

Commit 5f60679

Browse files
committed
Adjust more things to fit spec
1 parent c994eb2 commit 5f60679

File tree

6 files changed

+65
-75
lines changed

6 files changed

+65
-75
lines changed

src/mcp/server/auth/handlers/authorize.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,7 @@ class AuthorizationRequest(BaseModel):
3939
description="Optional scope; if specified, should be "
4040
"a space-separated list of scope strings",
4141
)
42-
43-
class Config:
44-
extra = "ignore"
42+
4543

4644

4745
def validate_scope(

src/mcp/server/auth/handlers/revoke.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
Corresponds to TypeScript file: src/server/auth/handlers/revoke.ts
55
"""
66

7-
from typing import Callable
7+
from typing import Callable, Optional
88

9-
from pydantic import ValidationError
9+
from pydantic import BaseModel, ValidationError
1010
from starlette.requests import Request
1111
from starlette.responses import Response
1212

@@ -17,8 +17,8 @@
1717
ClientAuthenticator,
1818
ClientAuthRequest,
1919
)
20-
from mcp.server.auth.provider import OAuthServerProvider
21-
from mcp.shared.auth import OAuthTokenRevocationRequest
20+
from mcp.server.auth.provider import OAuthServerProvider, OAuthTokenRevocationRequest
21+
2222

2323

2424
class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest):
@@ -28,18 +28,6 @@ class RevocationRequest(OAuthTokenRevocationRequest, ClientAuthRequest):
2828
def create_revocation_handler(
2929
provider: OAuthServerProvider, client_authenticator: ClientAuthenticator
3030
) -> Callable:
31-
"""
32-
Create a handler for OAuth 2.0 Token Revocation.
33-
34-
Corresponds to revocationHandler in src/server/auth/handlers/revoke.ts
35-
36-
Args:
37-
provider: The OAuth server provider
38-
39-
Returns:
40-
A Starlette endpoint handler function
41-
"""
42-
4331
async def revocation_handler(request: Request) -> Response:
4432
"""
4533
Handler for the OAuth 2.0 Token Revocation endpoint.

src/mcp/server/auth/handlers/token.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66

77
import base64
88
import hashlib
9+
import time
910
from typing import Annotated, Callable, Literal, Optional, Union
1011

11-
from pydantic import Field, RootModel, ValidationError
12+
from pydantic import AnyHttpUrl, Field, RootModel, ValidationError
1213
from starlette.requests import Request
1314

1415
from mcp.server.auth.errors import (
@@ -24,13 +25,19 @@
2425

2526

2627
class AuthorizationCodeRequest(ClientAuthRequest):
28+
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
2729
grant_type: Literal["authorization_code"]
2830
code: str = Field(..., description="The authorization code")
31+
redirect_uri: AnyHttpUrl | None = Field(
32+
..., description="Must be the same as redirect URI provided in /authorize"
33+
)
34+
client_id: str
35+
# See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
2936
code_verifier: str = Field(..., description="PKCE code verifier")
30-
# TODO: this should take redirect_uri
3137

3238

3339
class RefreshTokenRequest(ClientAuthRequest):
40+
# See https://datatracker.ietf.org/doc/html/rfc6749#section-6
3441
grant_type: Literal["refresh_token"]
3542
refresh_token: str = Field(..., description="The refresh token")
3643
scope: Optional[str] = Field(None, description="Optional scope parameter")
@@ -42,7 +49,7 @@ class TokenRequest(RootModel):
4249
Field(discriminator="grant_type"),
4350
]
4451

45-
52+
AUTH_CODE_TTL = 300 # seconds
4653

4754
def create_token_handler(
4855
provider: OAuthServerProvider, client_authenticator: ClientAuthenticator
@@ -65,22 +72,28 @@ async def token_handler(request: Request):
6572

6673
match token_request:
6774
case AuthorizationCodeRequest():
68-
# TODO: verify that the redirect URIs match
69-
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
70-
# TODO: enforce TTL on the authorization code
71-
72-
# Verify PKCE code verifier
73-
expected_challenge = await provider.challenge_for_authorization_code(
75+
auth_code_metadata = await provider.load_authorization_code_metadata(
7476
client_info, token_request.code
7577
)
76-
if expected_challenge is None:
78+
if auth_code_metadata is None or auth_code_metadata.client_id != token_request.client_id:
7779
raise InvalidRequestError("Invalid authorization code")
7880

79-
# Calculate challenge from verifier
81+
# make auth codes expire after a deadline
82+
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
83+
expires_at = auth_code_metadata.issued_at + AUTH_CODE_TTL
84+
if expires_at < time.time():
85+
raise InvalidRequestError("authorization code has expired")
86+
87+
# verify redirect_uri doesn't change between /authorize and /tokens
88+
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
89+
if token_request.redirect_uri != auth_code_metadata.redirect_uri:
90+
raise InvalidRequestError("redirect_uri did not match redirect_uri used when authorization code was created")
91+
92+
# Verify PKCE code verifier
8093
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
81-
actual_challenge = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
94+
hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
8295

83-
if actual_challenge != expected_challenge:
96+
if hashed_code_verifier != auth_code_metadata.code_challenge:
8497
raise InvalidRequestError(
8598
"code_verifier does not match the challenge"
8699
)

src/mcp/server/auth/provider.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Corresponds to TypeScript file: src/server/auth/provider.ts
55
"""
66

7-
from typing import List, Optional, Protocol
7+
from typing import List, Literal, Optional, Protocol
88

99
from pydantic import AnyHttpUrl, BaseModel
1010

@@ -28,6 +28,18 @@ class AuthorizationParams(BaseModel):
2828
code_challenge: str
2929
redirect_uri: AnyHttpUrl
3030

31+
class AuthorizationCodeMeta(BaseModel):
32+
issued_at: float
33+
client_id: str
34+
code_challenge: str
35+
redirect_uri: AnyHttpUrl
36+
class OAuthTokenRevocationRequest(BaseModel):
37+
"""
38+
# See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
39+
"""
40+
41+
token: str
42+
token_type_hint: Optional[Literal["access_token", "refresh_token"]] = None
3143

3244
class OAuthRegisteredClientsStore(Protocol):
3345
"""
@@ -91,11 +103,11 @@ async def create_authorization_code(
91103
"""
92104
...
93105

94-
async def challenge_for_authorization_code(
106+
async def load_authorization_code_metadata(
95107
self, client: OAuthClientInformationFull, authorization_code: str
96-
) -> str | None:
108+
) -> AuthorizationCodeMeta | None:
97109
"""
98-
Returns the code_challenge that was used when the indicated authorization began.
110+
Loads metadata for the authorization code challenge.
99111
100112
Args:
101113
client: The client that requested the authorization code.

src/mcp/shared/auth.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,21 @@
1111

1212
class OAuthErrorResponse(BaseModel):
1313
"""
14-
OAuth 2.1 error response.
15-
16-
Corresponds to OAuthErrorResponseSchema in src/shared/auth.ts
14+
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
1715
"""
1816

19-
error: str
17+
error: Literal["invalid_request", "invalid_client", "invalid_grant", "unauthorized_client", "unsupported_grant_type", "invalid_scope"]
2018
error_description: Optional[str] = None
2119
error_uri: Optional[AnyHttpUrl] = None
2220

2321

2422
class OAuthTokens(BaseModel):
2523
"""
26-
OAuth 2.1 token response.
27-
28-
Corresponds to OAuthTokensSchema in src/shared/auth.ts
24+
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
2925
"""
3026

3127
access_token: str
32-
token_type: str
28+
token_type: Literal["bearer"] = "bearer"
3329
expires_in: Optional[int] = None
3430
scope: Optional[str] = None
3531
refresh_token: Optional[str] = None

tests/server/fastmcp/auth/test_auth_integration.py

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919

2020
from mcp.server.auth.errors import InvalidTokenError
2121
from mcp.server.auth.provider import (
22+
AuthorizationCodeMeta,
2223
AuthorizationParams,
2324
OAuthRegisteredClientsStore,
2425
OAuthServerProvider,
26+
OAuthTokenRevocationRequest,
2527
)
2628
from mcp.server.auth.router import (
2729
ClientRegistrationOptions,
@@ -32,7 +34,6 @@
3234
from mcp.server.fastmcp import FastMCP
3335
from mcp.shared.auth import (
3436
OAuthClientInformationFull,
35-
OAuthTokenRevocationRequest,
3637
OAuthTokens,
3738
)
3839
from mcp.types import JSONRPCRequest
@@ -74,32 +75,19 @@ async def create_authorization_code(
7475
code = f"code_{int(time.time())}"
7576

7677
# Store the code for later verification
77-
self.auth_codes[code] = {
78-
"client_id": client.client_id,
79-
"code_challenge": params.code_challenge,
80-
"redirect_uri": params.redirect_uri,
81-
"expires_at": int(time.time()) + 600, # 10 minutes
82-
}
78+
self.auth_codes[code] = AuthorizationCodeMeta(
79+
client_id= client.client_id,
80+
code_challenge= params.code_challenge,
81+
redirect_uri= params.redirect_uri,
82+
issued_at= time.time(),
83+
)
8384

8485
return code
8586

86-
async def challenge_for_authorization_code(
87+
async def load_authorization_code_metadata(
8788
self, client: OAuthClientInformationFull, authorization_code: str
88-
) -> str:
89-
# Get the stored code info
90-
code_info = self.auth_codes.get(authorization_code)
91-
if not code_info:
92-
raise InvalidTokenError("Invalid authorization code")
93-
94-
# Check if code is expired
95-
if code_info["expires_at"] < int(time.time()):
96-
raise InvalidTokenError("Authorization code has expired")
97-
98-
# Check if the code was issued to this client
99-
if code_info["client_id"] != client.client_id:
100-
raise InvalidTokenError("Authorization code was not issued to this client")
101-
102-
return code_info["code_challenge"]
89+
) -> AuthorizationCodeMeta | None:
90+
return self.auth_codes.get(authorization_code)
10391

10492
async def exchange_authorization_code(
10593
self, client: OAuthClientInformationFull, authorization_code: str
@@ -109,14 +97,6 @@ async def exchange_authorization_code(
10997
if not code_info:
11098
raise InvalidTokenError("Invalid authorization code")
11199

112-
# Check if code is expired
113-
if code_info["expires_at"] < int(time.time()):
114-
raise InvalidTokenError("Authorization code has expired")
115-
116-
# Check if the code was issued to this client
117-
if code_info["client_id"] != client.client_id:
118-
raise InvalidTokenError("Authorization code was not issued to this client")
119-
120100
# Generate an access token and refresh token
121101
access_token = f"access_{secrets.token_hex(32)}"
122102
refresh_token = f"refresh_{secrets.token_hex(32)}"
@@ -436,6 +416,7 @@ async def test_authorization_flow(
436416
"client_secret": client_info["client_secret"],
437417
"code": auth_code,
438418
"code_verifier": code_verifier,
419+
"redirect_uri": "https://client.example.com/callback",
439420
},
440421
)
441422
assert response.status_code == 200
@@ -465,6 +446,7 @@ async def test_authorization_flow(
465446
"client_id": client_info["client_id"],
466447
"client_secret": client_info["client_secret"],
467448
"refresh_token": refresh_token,
449+
"redirect_uri": "https://client.example.com/callback",
468450
},
469451
)
470452
assert response.status_code == 200
@@ -585,6 +567,7 @@ def test_tool(x: int) -> str:
585567
"client_secret": client_info["client_secret"],
586568
"code": auth_code,
587569
"code_verifier": code_verifier,
570+
"redirect_uri": "https://client.example.com/callback",
588571
},
589572
)
590573
assert response.status_code == 200

0 commit comments

Comments
 (0)