Skip to content

Commit b3ce107

Browse files
committed
GH-22: Using starlette like exception handling
That is: - raising starlette.authentication.AuthenticationError - providing an on_error callback turning starlet 400 into 401 to keep same api - letting the user provide their own on_error when instantiating the middleware.
1 parent 089d648 commit b3ce107

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

src/fastapi_oauth2/middleware.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,13 @@
1616
from jose.jwt import encode as jwt_encode
1717
from starlette.authentication import AuthCredentials
1818
from starlette.authentication import AuthenticationBackend
19+
from starlette.authentication import AuthenticationError
1920
from starlette.authentication import BaseUser
2021
from starlette.middleware.authentication import AuthenticationMiddleware
2122
from starlette.requests import Request
23+
from starlette.requests import HTTPConnection
2224
from starlette.responses import PlainTextResponse
25+
from starlette.responses import Response
2326
from starlette.types import ASGIApp
2427
from starlette.types import Receive
2528
from starlette.types import Scope
@@ -111,9 +114,9 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]:
111114
try:
112115
token_data = Auth.jwt_decode(param)
113116
except JOSEError as e:
114-
raise OAuth2AuthenticationError(401, str(e))
117+
raise AuthenticationError(str(e))
115118
if token_data["exp"] and token_data["exp"] < int(datetime.now(timezone.utc).timestamp()):
116-
raise OAuth2AuthenticationError(401, "Token expired")
119+
raise AuthenticationError("Token expired")
117120

118121
user = User(token_data)
119122
auth = Auth(user.pop("scope", []))
@@ -138,6 +141,7 @@ def __init__(
138141
app: ASGIApp,
139142
config: Union[OAuth2Config, dict],
140143
callback: Callable[[Auth, User], Union[Awaitable[None], None]] = None,
144+
on_error: Callable[[HTTPConnection, AuthenticationError], Response] | None = None,
141145
**kwargs, # AuthenticationMiddleware kwargs
142146
) -> None:
143147
"""Initiates the middleware with the given configuration.
@@ -151,9 +155,13 @@ def __init__(
151155
elif not isinstance(config, OAuth2Config):
152156
raise TypeError("config is not a valid type")
153157
self.default_application_middleware = app
154-
self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), **kwargs)
158+
self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), on_error = on_error or self.on_error, **kwargs)
155159

156160
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
157161
if scope["type"] == "http":
158162
return await self.auth_middleware(scope, receive, send)
159163
await self.default_application_middleware(scope, receive, send)
164+
165+
@staticmethod
166+
def on_error(conn: HTTPConnection, exc: Exception) -> Response:
167+
return PlainTextResponse(str(exc), status_code=401)

0 commit comments

Comments
 (0)