diff --git a/docs/integration/integration.md b/docs/integration/integration.md index ca53dc1..30a2654 100644 --- a/docs/integration/integration.md +++ b/docs/integration/integration.md @@ -11,8 +11,9 @@ section covers its integration into a FastAPI app. The `OAuth2Middleware` is an authentication middleware which means that its usage makes the `user` and `auth` attributes available in the [request](https://www.starlette.io/requests/) context. It has a mandatory argument `config` of -[`OAuth2Config`](/integration/configuration#oauth2config) instance that has been discussed at the previous section and -an optional argument `callback` which is a callable that is called when the authentication succeeds. +[`OAuth2Config`](/integration/configuration#oauth2config) instance that has been discussed in the previous section and +optional arguments `callback` and `on_error` that accept callables as values and are called when the authentication +succeeds and fails correspondingly. ```python app: FastAPI @@ -20,10 +21,14 @@ app: FastAPI def on_auth_success(auth: Auth, user: User): """This could be async function as well.""" +def on_auth_error(conn: HTTPConnection, exc: Exception) -> Response: + return JSONResponse({"detail": str(exc)}, status_code=400) + app.add_middleware( OAuth2Middleware, config=OAuth2Config(...), callback=on_auth_success, + on_error=on_auth_error, ) ``` diff --git a/docs/references/tutorials.md b/docs/references/tutorials.md index 1ce1128..a49531a 100644 --- a/docs/references/tutorials.md +++ b/docs/references/tutorials.md @@ -115,7 +115,7 @@ async def error_handler(request: Request, exc: OAuth2AuthenticationError): return RedirectResponse(url="/login", status_code=303) ``` -The complete list of exceptions is the following. +The complete list of exceptions raised by the middleware is the following. - `OAuth2Error` - Base exception for all errors raised by the FastAPI OAuth2 library. - `OAuth2AuthenticationError` - An exception is raised when the authentication fails. diff --git a/src/fastapi_oauth2/__init__.py b/src/fastapi_oauth2/__init__.py index 5becc17..6849410 100644 --- a/src/fastapi_oauth2/__init__.py +++ b/src/fastapi_oauth2/__init__.py @@ -1 +1 @@ -__version__ = "1.0.0" +__version__ = "1.1.0" diff --git a/src/fastapi_oauth2/middleware.py b/src/fastapi_oauth2/middleware.py index 8481947..76ee47e 100644 --- a/src/fastapi_oauth2/middleware.py +++ b/src/fastapi_oauth2/middleware.py @@ -16,10 +16,12 @@ from jose.jwt import encode as jwt_encode from starlette.authentication import AuthCredentials from starlette.authentication import AuthenticationBackend +from starlette.authentication import AuthenticationError from starlette.authentication import BaseUser from starlette.middleware.authentication import AuthenticationMiddleware +from starlette.requests import HTTPConnection from starlette.requests import Request -from starlette.responses import PlainTextResponse +from starlette.responses import Response from starlette.types import ASGIApp from starlette.types import Receive from starlette.types import Scope @@ -28,7 +30,6 @@ from .claims import Claims from .config import OAuth2Config from .core import OAuth2Core -from .exceptions import OAuth2AuthenticationError class Auth(AuthCredentials): @@ -108,9 +109,12 @@ async def authenticate(self, request: Request) -> Optional[Tuple[Auth, User]]: if not scheme or not param: return Auth(), User() - token_data = Auth.jwt_decode(param) + try: + token_data = Auth.jwt_decode(param) + except JOSEError as e: + raise AuthenticationError(str(e)) if token_data["exp"] and token_data["exp"] < int(datetime.now(timezone.utc).timestamp()): - raise OAuth2AuthenticationError(401, "Token expired") + raise AuthenticationError("Token expired") user = User(token_data) auth = Auth(user.pop("scope", [])) @@ -135,7 +139,7 @@ def __init__( app: ASGIApp, config: Union[OAuth2Config, dict], callback: Callable[[Auth, User], Union[Awaitable[None], None]] = None, - **kwargs, # AuthenticationMiddleware kwargs + on_error: Optional[Callable[[HTTPConnection, AuthenticationError], Response]] = None, ) -> None: """Initiates the middleware with the given configuration. @@ -148,13 +152,10 @@ def __init__( elif not isinstance(config, OAuth2Config): raise TypeError("config is not a valid type") self.default_application_middleware = app - self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), **kwargs) + on_error = on_error or AuthenticationMiddleware.default_on_error + self.auth_middleware = AuthenticationMiddleware(app, backend=OAuth2Backend(config, callback), on_error=on_error) async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "http": - try: - return await self.auth_middleware(scope, receive, send) - except (JOSEError, Exception) as e: - middleware = PlainTextResponse(str(e), status_code=401) - return await middleware(scope, receive, send) + return await self.auth_middleware(scope, receive, send) await self.default_application_middleware(scope, receive, send) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index e33c6b7..2302d9b 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -1,5 +1,7 @@ import pytest +from fastapi.responses import JSONResponse from httpx import AsyncClient +from jose import jwt @pytest.mark.anyio @@ -26,3 +28,68 @@ async def test_middleware_on_logout(get_app): response = await client.get("/user") assert response.status_code == 403 # Forbidden + + +@pytest.mark.anyio +async def test_middleware_do_not_interfere_user_errors(get_app): + app = get_app() + + @app.get("/unexpected_error") + def my_entry_point(): + raise NameError # Intended code error + + async with AsyncClient(app=app, base_url="http://test") as client: + with pytest.raises(NameError): + await client.get("/unexpected_error") + + +@pytest.mark.anyio +async def test_middleware_ignores_custom_exceptions(get_app): + class MyCustomException(Exception): + pass + + app = get_app() + + @app.get("/custom_exception") + def my_entry_point(): + raise MyCustomException() + + async with AsyncClient(app=app, base_url="http://test") as client: + with pytest.raises(MyCustomException): + await client.get("/custom_exception") + + +@pytest.mark.anyio +async def test_middleware_ignores_handled_custom_exceptions(get_app): + class MyHandledException(Exception): + pass + + app = get_app() + + @app.exception_handler(MyHandledException) + async def unicorn_exception_handler(request, exc): + return JSONResponse( + status_code=418, + content={"details": "I am a custom Teapot!"}, + ) + + @app.get("/handled_exception") + def my_entry_point(): + raise MyHandledException() + + async with AsyncClient(app=app, base_url="http://test") as client: + response = await client.get("/handled_exception") + assert response.status_code == 418 # I am a teapot! + assert response.json() == {"details": "I am a custom Teapot!"} + + +@pytest.mark.anyio +async def test_middleware_reports_invalid_jwt(get_app): + async with AsyncClient(app=get_app(with_ssr=False), base_url="http://test") as client: + # Insert a bad token instead + badtoken = jwt.encode({"bad": "token"}, "badsecret", "HS256") + client.cookies.update(dict(Authorization=f"Bearer: {badtoken}")) + + response = await client.get("/user") + assert response.status_code == 400 + assert response.text == "Signature verification failed."