Skip to content

Commit 424bfeb

Browse files
committed
Refactor back to authorize()
1 parent e100c9d commit 424bfeb

File tree

5 files changed

+125
-75
lines changed

5 files changed

+125
-75
lines changed

src/mcp/server/auth/handlers/authorize.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
from mcp.server.auth.provider import AuthorizationParams, OAuthServerProvider
2020
from mcp.shared.auth import OAuthClientInformationFull
2121

22+
import logging
23+
24+
logger = logging.getLogger(__name__)
25+
2226

2327
class AuthorizationRequest(BaseModel):
2428
client_id: str = Field(..., description="The client ID")
@@ -122,28 +126,18 @@ async def authorization_handler(request: Request) -> Response:
122126
)
123127

124128
try:
125-
# Let the provider handle the authorization flow
126-
authorization_code = await provider.create_authorization_code(
127-
client, auth_params
128-
)
129+
# Let the provider pick the next URI to redirect to
129130
response = RedirectResponse(
130131
url="", status_code=302, headers={"Cache-Control": "no-store"}
131132
)
132-
133-
# Redirect with code
134-
parsed_uri = urlparse(str(auth_params.redirect_uri))
135-
query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs]
136-
query_params.append(("code", authorization_code))
137-
if auth_params.state:
138-
query_params.append(("state", auth_params.state))
139-
140-
redirect_url = urlunparse(
141-
parsed_uri._replace(query=urlencode(query_params))
133+
response.headers["location"] = await provider.authorize(
134+
client, auth_params
142135
)
143-
response.headers["location"] = redirect_url
144136

145137
return response
146138
except Exception as e:
139+
logger.exception("error from authorize()", exc_info=e)
140+
147141
return RedirectResponse(
148142
url=create_error_redirect(redirect_uri, e, auth_request.state),
149143
status_code=302,

src/mcp/server/auth/handlers/token.py

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
ClientAuthRequest,
2222
)
2323
from mcp.server.auth.provider import OAuthServerProvider
24-
from mcp.shared.auth import OAuthTokens
24+
from mcp.shared.auth import TokenErrorResponse, TokenSuccessResponse
2525

2626

2727
class AuthorizationCodeRequest(ClientAuthRequest):
@@ -54,53 +54,79 @@ class TokenRequest(RootModel):
5454
def create_token_handler(
5555
provider: OAuthServerProvider, client_authenticator: ClientAuthenticator
5656
) -> Callable:
57+
def response(obj: TokenSuccessResponse | TokenErrorResponse):
58+
return PydanticJSONResponse(
59+
content=obj,
60+
headers={
61+
"Cache-Control": "no-store",
62+
"Pragma": "no-cache",
63+
},
64+
)
65+
5766
async def token_handler(request: Request):
5867
try:
5968
form_data = await request.form()
6069
token_request = TokenRequest.model_validate(dict(form_data)).root
61-
except ValidationError as e:
62-
raise InvalidRequestError(f"Invalid request body: {e}")
70+
except ValidationError as validation_error:
71+
return response(TokenErrorResponse(
72+
error="invalid_request",
73+
error_description="\n".join(e['msg'] for e in validation_error.errors())
74+
75+
))
6376
client_info = await client_authenticator(token_request)
6477

6578
if token_request.grant_type not in client_info.grant_types:
66-
raise InvalidRequestError(
67-
f"Unsupported grant type (supported grant types are "
79+
return response(TokenErrorResponse(
80+
error="unsupported_grant_type",
81+
error_description=f"Unsupported grant type (supported grant types are "
6882
f"{client_info.grant_types})"
69-
)
83+
))
7084

71-
tokens: OAuthTokens
85+
tokens: TokenSuccessResponse
7286

7387
match token_request:
7488
case AuthorizationCodeRequest():
75-
auth_code_metadata = await provider.load_authorization_code_metadata(
89+
auth_code = await provider.load_authorization_code(
7690
client_info, token_request.code
7791
)
78-
if auth_code_metadata is None or auth_code_metadata.client_id != token_request.client_id:
79-
raise InvalidRequestError("Invalid authorization code")
92+
if auth_code is None or auth_code.client_id != token_request.client_id:
93+
# if the authoriation code belongs to a different client, pretend it doesn't exist
94+
return response(TokenErrorResponse(
95+
error="invalid_grant",
96+
error_description=f"authorization code does not exist"
97+
))
8098

8199
# make auth codes expire after a deadline
82100
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
83-
expires_at = auth_code_metadata.issued_at + AUTH_CODE_TTL
101+
expires_at = auth_code.issued_at + AUTH_CODE_TTL
84102
if expires_at < time.time():
85-
raise InvalidRequestError("authorization code has expired")
103+
return response(TokenErrorResponse(
104+
error="invalid_grant",
105+
error_description=f"authorization code has expired"
106+
))
86107

87108
# verify redirect_uri doesn't change between /authorize and /tokens
88109
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
89-
if token_request.redirect_uri != auth_code_metadata.redirect_uri:
90-
raise InvalidRequestError("redirect_uri did not match redirect_uri used when authorization code was created")
110+
if token_request.redirect_uri != auth_code.redirect_uri:
111+
return response(TokenErrorResponse(
112+
error="invalid_request",
113+
error_description=f"redirect_uri did not match redirect_uri used when authorization code was created"
114+
))
91115

92116
# Verify PKCE code verifier
93117
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
94118
hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
95119

96-
if hashed_code_verifier != auth_code_metadata.code_challenge:
97-
raise InvalidRequestError(
98-
"code_verifier does not match the challenge"
99-
)
120+
if hashed_code_verifier != auth_code.code_challenge:
121+
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
122+
return response(TokenErrorResponse(
123+
error="invalid_grant",
124+
error_description=f"incorrect code_verifier"
125+
))
100126

101127
# Exchange authorization code for tokens
102128
tokens = await provider.exchange_authorization_code(
103-
client_info, token_request.code
129+
client_info, auth_code
104130
)
105131

106132
case RefreshTokenRequest():
@@ -112,12 +138,6 @@ async def token_handler(request: Request):
112138
client_info, token_request.refresh_token, scopes
113139
)
114140

115-
return PydanticJSONResponse(
116-
content=tokens,
117-
headers={
118-
"Cache-Control": "no-store",
119-
"Pragma": "no-cache",
120-
},
121-
)
141+
return response(tokens)
122142

123143
return token_handler

src/mcp/server/auth/provider.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
"""
66

77
from typing import List, Literal, Optional, Protocol
8+
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
89

9-
from pydantic import AnyHttpUrl, BaseModel
10+
from pydantic import AnyHttpUrl, AnyUrl, BaseModel
1011

1112
from mcp.server.auth.types import AuthInfo
1213
from mcp.shared.auth import (
1314
OAuthClientInformationFull,
14-
OAuthTokens,
15+
TokenSuccessResponse,
1516
)
1617

1718

@@ -27,7 +28,9 @@ class AuthorizationParams(BaseModel):
2728
code_challenge: str
2829
redirect_uri: AnyHttpUrl
2930

30-
class AuthorizationCodeMeta(BaseModel):
31+
class AuthorizationCode(BaseModel):
32+
code: str
33+
scopes: list[str]
3134
issued_at: float
3235
client_id: str
3336
code_challenge: str
@@ -88,12 +91,33 @@ def clients_store(self) -> OAuthRegisteredClientsStore:
8891
"""
8992
...
9093

91-
async def create_authorization_code(
94+
async def authorize(
9295
self, client: OAuthClientInformationFull, params: AuthorizationParams
9396
) -> str:
9497
"""
95-
Generates and stores an authorization code as part of completing the /authorize
96-
OAuth step.
98+
Called as part of the /authorize endpoint, and returns a URL that the client
99+
will be redirected to.
100+
Many MCP implementations will redirect to a third-party provider to perform
101+
a second OAuth exchange with that provider. In this sort of setup, the client
102+
has an OAuth connection with the MCP server, and the MCP server has an OAuth
103+
connection with the 3rd-party provider. At the end of this flow, the client
104+
should be redirected to the redirect_uri from params.redirect_uri.
105+
106+
+--------+ +------------+ +-------------------+
107+
| | | | | |
108+
| Client | --> | MCP Server | --> | 3rd Party OAuth |
109+
| | | | | Server |
110+
+--------+ +------------+ +-------------------+
111+
| ^ |
112+
+------------+ | | |
113+
| | | | Redirect |
114+
|redirect_uri|<-----+ +------------------+
115+
| |
116+
+------------+
117+
118+
Implementations will need to define another handler on the MCP server return
119+
flow to perform the second redirect, and generates and stores an authorization
120+
code as part of completing the OAuth authorization step.
97121
98122
Implementations SHOULD generate an authorization code with at least 160 bits of
99123
entropy,
@@ -102,9 +126,9 @@ async def create_authorization_code(
102126
"""
103127
...
104128

105-
async def load_authorization_code_metadata(
129+
async def load_authorization_code(
106130
self, client: OAuthClientInformationFull, authorization_code: str
107-
) -> AuthorizationCodeMeta | None:
131+
) -> AuthorizationCode | None:
108132
"""
109133
Loads metadata for the authorization code challenge.
110134
@@ -118,8 +142,8 @@ async def load_authorization_code_metadata(
118142
...
119143

120144
async def exchange_authorization_code(
121-
self, client: OAuthClientInformationFull, authorization_code: str
122-
) -> OAuthTokens:
145+
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCode
146+
) -> TokenSuccessResponse:
123147
"""
124148
Exchanges an authorization code for an access token.
125149
@@ -137,7 +161,7 @@ async def exchange_refresh_token(
137161
client: OAuthClientInformationFull,
138162
refresh_token: str,
139163
scopes: Optional[List[str]] = None,
140-
) -> OAuthTokens:
164+
) -> TokenSuccessResponse:
141165
"""
142166
Exchanges a refresh token for an access token.
143167
@@ -178,3 +202,15 @@ async def revoke_token(
178202
request: The token revocation request.
179203
"""
180204
...
205+
206+
def construct_redirect_uri(redirect_uri_base: str, authorization_code: AuthorizationCode, state: Optional[str]) -> str:
207+
parsed_uri = urlparse(redirect_uri_base)
208+
query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query) for v in vs]
209+
query_params.append(("code", authorization_code.code))
210+
if state:
211+
query_params.append(("state", state))
212+
213+
redirect_uri = urlunparse(
214+
parsed_uri._replace(query=urlencode(query_params))
215+
)
216+
return redirect_uri

src/mcp/shared/auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pydantic import AnyHttpUrl, BaseModel, Field
1010

1111

12-
class OAuthErrorResponse(BaseModel):
12+
class TokenErrorResponse(BaseModel):
1313
"""
1414
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
1515
"""
@@ -19,7 +19,7 @@ class OAuthErrorResponse(BaseModel):
1919
error_uri: Optional[AnyHttpUrl] = None
2020

2121

22-
class OAuthTokens(BaseModel):
22+
class TokenSuccessResponse(BaseModel):
2323
"""
2424
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.1
2525
"""

0 commit comments

Comments
 (0)