16
16
from jose .jwt import encode as jwt_encode
17
17
from starlette .authentication import AuthCredentials
18
18
from starlette .authentication import AuthenticationBackend
19
+ from starlette .authentication import AuthenticationError
19
20
from starlette .authentication import BaseUser
20
21
from starlette .middleware .authentication import AuthenticationMiddleware
21
22
from starlette .requests import Request
23
+ from starlette .requests import HTTPConnection
22
24
from starlette .responses import PlainTextResponse
25
+ from starlette .responses import Response
23
26
from starlette .types import ASGIApp
24
27
from starlette .types import Receive
25
28
from starlette .types import Scope
@@ -111,9 +114,9 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
111
114
try :
112
115
token_data = Auth .jwt_decode (param )
113
116
except JOSEError as e :
114
- raise OAuth2AuthenticationError ( 401 , str (e ))
117
+ raise AuthenticationError ( str (e ))
115
118
if token_data ["exp" ] and token_data ["exp" ] < int (datetime .now (timezone .utc ).timestamp ()):
116
- raise OAuth2AuthenticationError ( 401 , "Token expired" )
119
+ raise AuthenticationError ( "Token expired" )
117
120
118
121
user = User (token_data )
119
122
auth = Auth (user .pop ("scope" , []))
@@ -138,6 +141,7 @@ def __init__(
138
141
app : ASGIApp ,
139
142
config : Union [OAuth2Config , dict ],
140
143
callback : Callable [[Auth , User ], Union [Awaitable [None ], None ]] = None ,
144
+ on_error : Callable [[HTTPConnection , AuthenticationError ], Response ] | None = None ,
141
145
** kwargs , # AuthenticationMiddleware kwargs
142
146
) -> None :
143
147
"""Initiates the middleware with the given configuration.
@@ -151,9 +155,13 @@ def __init__(
151
155
elif not isinstance (config , OAuth2Config ):
152
156
raise TypeError ("config is not a valid type" )
153
157
self .default_application_middleware = app
154
- self .auth_middleware = AuthenticationMiddleware (app , backend = OAuth2Backend (config , callback ), ** kwargs )
158
+ self .auth_middleware = AuthenticationMiddleware (app , backend = OAuth2Backend (config , callback ), on_error = on_error or self . on_error , ** kwargs )
155
159
156
160
async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
157
161
if scope ["type" ] == "http" :
158
162
return await self .auth_middleware (scope , receive , send )
159
163
await self .default_application_middleware (scope , receive , send )
164
+
165
+ @staticmethod
166
+ def on_error (conn : HTTPConnection , exc : Exception ) -> Response :
167
+ return PlainTextResponse (str (exc ), status_code = 401 )
0 commit comments