2121 ClientAuthRequest ,
2222)
2323from mcp .server .auth .provider import OAuthServerProvider
24- from mcp .shared .auth import OAuthTokens
24+ from mcp .shared .auth import TokenErrorResponse , TokenSuccessResponse
2525
2626
2727class AuthorizationCodeRequest (ClientAuthRequest ):
@@ -54,53 +54,79 @@ class TokenRequest(RootModel):
5454def create_token_handler (
5555 provider : OAuthServerProvider , client_authenticator : ClientAuthenticator
5656) -> Callable :
57+ def response (obj : TokenSuccessResponse | TokenErrorResponse ):
58+ return PydanticJSONResponse (
59+ content = obj ,
60+ headers = {
61+ "Cache-Control" : "no-store" ,
62+ "Pragma" : "no-cache" ,
63+ },
64+ )
65+
5766 async def token_handler (request : Request ):
5867 try :
5968 form_data = await request .form ()
6069 token_request = TokenRequest .model_validate (dict (form_data )).root
61- except ValidationError as e :
62- raise InvalidRequestError (f"Invalid request body: { e } " )
70+ except ValidationError as validation_error :
71+ return response (TokenErrorResponse (
72+ error = "invalid_request" ,
73+ error_description = "\n " .join (e ['msg' ] for e in validation_error .errors ())
74+
75+ ))
6376 client_info = await client_authenticator (token_request )
6477
6578 if token_request .grant_type not in client_info .grant_types :
66- raise InvalidRequestError (
67- f"Unsupported grant type (supported grant types are "
79+ return response (TokenErrorResponse (
80+ error = "unsupported_grant_type" ,
81+ error_description = f"Unsupported grant type (supported grant types are "
6882 f"{ client_info .grant_types } )"
69- )
83+ ))
7084
71- tokens : OAuthTokens
85+ tokens : TokenSuccessResponse
7286
7387 match token_request :
7488 case AuthorizationCodeRequest ():
75- auth_code_metadata = await provider .load_authorization_code_metadata (
89+ auth_code = await provider .load_authorization_code (
7690 client_info , token_request .code
7791 )
78- if auth_code_metadata is None or auth_code_metadata .client_id != token_request .client_id :
79- raise InvalidRequestError ("Invalid authorization code" )
92+ if auth_code is None or auth_code .client_id != token_request .client_id :
93+ # if the authoriation code belongs to a different client, pretend it doesn't exist
94+ return response (TokenErrorResponse (
95+ error = "invalid_grant" ,
96+ error_description = f"authorization code does not exist"
97+ ))
8098
8199 # make auth codes expire after a deadline
82100 # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
83- expires_at = auth_code_metadata .issued_at + AUTH_CODE_TTL
101+ expires_at = auth_code .issued_at + AUTH_CODE_TTL
84102 if expires_at < time .time ():
85- raise InvalidRequestError ("authorization code has expired" )
103+ return response (TokenErrorResponse (
104+ error = "invalid_grant" ,
105+ error_description = f"authorization code has expired"
106+ ))
86107
87108 # verify redirect_uri doesn't change between /authorize and /tokens
88109 # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
89- if token_request .redirect_uri != auth_code_metadata .redirect_uri :
90- raise InvalidRequestError ("redirect_uri did not match redirect_uri used when authorization code was created" )
110+ if token_request .redirect_uri != auth_code .redirect_uri :
111+ return response (TokenErrorResponse (
112+ error = "invalid_request" ,
113+ error_description = f"redirect_uri did not match redirect_uri used when authorization code was created"
114+ ))
91115
92116 # Verify PKCE code verifier
93117 sha256 = hashlib .sha256 (token_request .code_verifier .encode ()).digest ()
94118 hashed_code_verifier = base64 .urlsafe_b64encode (sha256 ).decode ().rstrip ("=" )
95119
96- if hashed_code_verifier != auth_code_metadata .code_challenge :
97- raise InvalidRequestError (
98- "code_verifier does not match the challenge"
99- )
120+ if hashed_code_verifier != auth_code .code_challenge :
121+ # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
122+ return response (TokenErrorResponse (
123+ error = "invalid_grant" ,
124+ error_description = f"incorrect code_verifier"
125+ ))
100126
101127 # Exchange authorization code for tokens
102128 tokens = await provider .exchange_authorization_code (
103- client_info , token_request . code
129+ client_info , auth_code
104130 )
105131
106132 case RefreshTokenRequest ():
@@ -112,12 +138,6 @@ async def token_handler(request: Request):
112138 client_info , token_request .refresh_token , scopes
113139 )
114140
115- return PydanticJSONResponse (
116- content = tokens ,
117- headers = {
118- "Cache-Control" : "no-store" ,
119- "Pragma" : "no-cache" ,
120- },
121- )
141+ return response (tokens )
122142
123143 return token_handler
0 commit comments