Skip to content

Commit ad74aee

Browse files
committed
Improve validation for /token
1 parent 424bfeb commit ad74aee

File tree

5 files changed

+438
-33
lines changed

5 files changed

+438
-33
lines changed

src/mcp/server/auth/handlers/token.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,18 @@ class TokenRequest(RootModel):
4949
Field(discriminator="grant_type"),
5050
]
5151

52-
AUTH_CODE_TTL = 300 # seconds
5352

5453
def create_token_handler(
5554
provider: OAuthServerProvider, client_authenticator: ClientAuthenticator
5655
) -> Callable:
5756
def response(obj: TokenSuccessResponse | TokenErrorResponse):
57+
status_code = 200
58+
if isinstance(obj, TokenErrorResponse):
59+
status_code = 400
60+
5861
return PydanticJSONResponse(
5962
content=obj,
63+
status_code=status_code,
6064
headers={
6165
"Cache-Control": "no-store",
6266
"Pragma": "no-cache",
@@ -98,8 +102,7 @@ async def token_handler(request: Request):
98102

99103
# make auth codes expire after a deadline
100104
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
101-
expires_at = auth_code.issued_at + AUTH_CODE_TTL
102-
if expires_at < time.time():
105+
if auth_code.expires_at < time.time():
103106
return response(TokenErrorResponse(
104107
error="invalid_grant",
105108
error_description=f"authorization code has expired"
@@ -130,12 +133,34 @@ async def token_handler(request: Request):
130133
)
131134

132135
case RefreshTokenRequest():
136+
refresh_token = await provider.load_refresh_token(client_info, token_request.refresh_token)
137+
if refresh_token is None or refresh_token.client_id != token_request.client_id:
138+
# if the authoriation code belongs to a different client, pretend it doesn't exist
139+
return response(TokenErrorResponse(
140+
error="invalid_grant",
141+
error_description=f"refresh token does not exist"
142+
))
143+
144+
if refresh_token.expires_at and refresh_token.expires_at < time.time():
145+
# if the authoriation code belongs to a different client, pretend it doesn't exist
146+
return response(TokenErrorResponse(
147+
error="invalid_grant",
148+
error_description=f"refresh token has expired"
149+
))
150+
133151
# Parse scopes if provided
134-
scopes = token_request.scope.split(" ") if token_request.scope else None
152+
scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes
153+
154+
for scope in scopes:
155+
if scope not in refresh_token.scopes:
156+
return response(TokenErrorResponse(
157+
error="invalid_scope",
158+
error_description=f"cannot request scope `{scope}` not provided by refresh token"
159+
))
135160

136161
# Exchange refresh token for new tokens
137162
tokens = await provider.exchange_refresh_token(
138-
client_info, token_request.refresh_token, scopes
163+
client_info, refresh_token, scopes
139164
)
140165

141166
return response(tokens)

src/mcp/server/auth/middleware/bearer_auth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class AuthenticatedUser(SimpleUser):
2525
"""User with authentication info."""
2626

2727
def __init__(self, auth_info: AuthInfo):
28-
super().__init__(auth_info.user_id or "anonymous")
28+
super().__init__(auth_info.client_id)
2929
self.auth_info = auth_info
3030
self.scopes = auth_info.scopes
3131

src/mcp/server/auth/provider.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,18 @@ class AuthorizationParams(BaseModel):
3131
class AuthorizationCode(BaseModel):
3232
code: str
3333
scopes: list[str]
34-
issued_at: float
34+
expires_at: float
3535
client_id: str
3636
code_challenge: str
3737
redirect_uri: AnyHttpUrl
38+
39+
class RefreshToken(BaseModel):
40+
token: str
41+
client_id: str
42+
scopes: List[str]
43+
expires_at: Optional[int] = None
44+
45+
3846
class OAuthTokenRevocationRequest(BaseModel):
3947
"""
4048
# See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
@@ -156,11 +164,14 @@ async def exchange_authorization_code(
156164
"""
157165
...
158166

167+
async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshToken | None:
168+
...
169+
159170
async def exchange_refresh_token(
160171
self,
161172
client: OAuthClientInformationFull,
162-
refresh_token: str,
163-
scopes: Optional[List[str]] = None,
173+
refresh_token: RefreshToken,
174+
scopes: List[str],
164175
) -> TokenSuccessResponse:
165176
"""
166177
Exchanges a refresh token for an access token.

src/mcp/server/auth/types.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,3 @@ class AuthInfo(BaseModel):
2020
client_id: str
2121
scopes: List[str]
2222
expires_at: Optional[int] = None
23-
user_id: Optional[str] = None
24-
25-
class Config:
26-
extra = "ignore"

0 commit comments

Comments
 (0)