77import base64
88import hashlib
99import time
10- from typing import Annotated , Callable , Literal
10+ from dataclasses import dataclass
11+ from typing import Annotated , Literal
1112
1213from pydantic import AnyHttpUrl , Field , RootModel , ValidationError
1314from starlette .requests import Request
@@ -52,10 +53,12 @@ class TokenRequest(RootModel):
5253 ]
5354
5455
55- def create_token_handler (
56- provider : OAuthServerProvider , client_authenticator : ClientAuthenticator
57- ) -> Callable :
58- def response (obj : TokenSuccessResponse | TokenErrorResponse | ErrorResponse ):
56+ @dataclass
57+ class TokenHandler :
58+ provider : OAuthServerProvider
59+ client_authenticator : ClientAuthenticator
60+
61+ def response (self , obj : TokenSuccessResponse | TokenErrorResponse | ErrorResponse ):
5962 status_code = 200
6063 if isinstance (obj , TokenErrorResponse ):
6164 status_code = 400
@@ -69,25 +72,25 @@ def response(obj: TokenSuccessResponse | TokenErrorResponse | ErrorResponse):
6972 },
7073 )
7174
72- async def token_handler ( request : Request ):
75+ async def handle ( self , request : Request ):
7376 try :
7477 form_data = await request .form ()
7578 token_request = TokenRequest .model_validate (dict (form_data )).root
7679 except ValidationError as validation_error :
77- return response (
80+ return self . response (
7881 TokenErrorResponse (
7982 error = "invalid_request" ,
8083 error_description = stringify_pydantic_error (validation_error ),
8184 )
8285 )
8386
8487 try :
85- client_info = await client_authenticator (token_request )
88+ client_info = await self . client_authenticator (token_request )
8689 except InvalidClientError as e :
87- return response (e .error_response ())
90+ return self . response (e .error_response ())
8891
8992 if token_request .grant_type not in client_info .grant_types :
90- return response (
93+ return self . response (
9194 TokenErrorResponse (
9295 error = "unsupported_grant_type" ,
9396 error_description = (
@@ -101,12 +104,12 @@ async def token_handler(request: Request):
101104
102105 match token_request :
103106 case AuthorizationCodeRequest ():
104- auth_code = await provider .load_authorization_code (
107+ auth_code = await self . provider .load_authorization_code (
105108 client_info , token_request .code
106109 )
107110 if auth_code is None or auth_code .client_id != token_request .client_id :
108111 # if code belongs to different client, pretend it doesn't exist
109- return response (
112+ return self . response (
110113 TokenErrorResponse (
111114 error = "invalid_grant" ,
112115 error_description = "authorization code does not exist" ,
@@ -116,7 +119,7 @@ async def token_handler(request: Request):
116119 # make auth codes expire after a deadline
117120 # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
118121 if auth_code .expires_at < time .time ():
119- return response (
122+ return self . response (
120123 TokenErrorResponse (
121124 error = "invalid_grant" ,
122125 error_description = "authorization code has expired" ,
@@ -126,7 +129,7 @@ async def token_handler(request: Request):
126129 # verify redirect_uri doesn't change between /authorize and /tokens
127130 # see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
128131 if token_request .redirect_uri != auth_code .redirect_uri :
129- return response (
132+ return self . response (
130133 TokenErrorResponse (
131134 error = "invalid_request" ,
132135 error_description = (
@@ -144,28 +147,28 @@ async def token_handler(request: Request):
144147
145148 if hashed_code_verifier != auth_code .code_challenge :
146149 # see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
147- return response (
150+ return self . response (
148151 TokenErrorResponse (
149152 error = "invalid_grant" ,
150153 error_description = "incorrect code_verifier" ,
151154 )
152155 )
153156
154157 # Exchange authorization code for tokens
155- tokens = await provider .exchange_authorization_code (
158+ tokens = await self . provider .exchange_authorization_code (
156159 client_info , auth_code
157160 )
158161
159162 case RefreshTokenRequest ():
160- refresh_token = await provider .load_refresh_token (
163+ refresh_token = await self . provider .load_refresh_token (
161164 client_info , token_request .refresh_token
162165 )
163166 if (
164167 refresh_token is None
165168 or refresh_token .client_id != token_request .client_id
166169 ):
167170 # if token belongs to different client, pretend it doesn't exist
168- return response (
171+ return self . response (
169172 TokenErrorResponse (
170173 error = "invalid_grant" ,
171174 error_description = "refresh token does not exist" ,
@@ -174,7 +177,7 @@ async def token_handler(request: Request):
174177
175178 if refresh_token .expires_at and refresh_token .expires_at < time .time ():
176179 # if the refresh token has expired, pretend it doesn't exist
177- return response (
180+ return self . response (
178181 TokenErrorResponse (
179182 error = "invalid_grant" ,
180183 error_description = "refresh token has expired" ,
@@ -190,20 +193,19 @@ async def token_handler(request: Request):
190193
191194 for scope in scopes :
192195 if scope not in refresh_token .scopes :
193- return response (
196+ return self . response (
194197 TokenErrorResponse (
195198 error = "invalid_scope" ,
196199 error_description = (
197- f"cannot request scope `{ scope } ` not provided by refresh token"
198- ),
200+ f"cannot request scope `{ scope } ` "
201+ "not provided by refresh token"
202+ ),
199203 )
200204 )
201205
202206 # Exchange refresh token for new tokens
203- tokens = await provider .exchange_refresh_token (
207+ tokens = await self . provider .exchange_refresh_token (
204208 client_info , refresh_token , scopes
205209 )
206210
207- return response (tokens )
208-
209- return token_handler
211+ return self .response (tokens )
0 commit comments