|
1 | | -# SPDX-FileCopyrightText: 2023-2024 MTS PJSC |
| 1 | +# SPDX-FileCopyrightText: 2023-2025 MTS PJSC |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 | import logging |
4 | | -from typing import Any |
| 4 | +from typing import Annotated, Any |
5 | 5 |
|
6 | | -from fastapi import Request |
| 6 | +from fastapi import Depends, FastAPI, Request |
7 | 7 |
|
8 | 8 | from syncmaster.db.models import User |
9 | 9 | from syncmaster.exceptions import EntityNotFoundError |
10 | 10 | from syncmaster.exceptions.auth import AuthorizationError |
| 11 | +from syncmaster.server.dependencies import Stub |
| 12 | +from syncmaster.server.providers.auth.base_provider import AuthProvider |
11 | 13 | from syncmaster.server.providers.auth.keycloak_provider import ( |
12 | 14 | KeycloakAuthProvider, |
13 | 15 | KeycloakOperationError, |
14 | 16 | ) |
| 17 | +from syncmaster.server.services.unit_of_work import UnitOfWork |
| 18 | +from syncmaster.server.settings.auth.oauth2_gateway import OAuth2GatewayProviderSettings |
15 | 19 |
|
16 | 20 | log = logging.getLogger(__name__) |
17 | 21 |
|
18 | 22 |
|
19 | 23 | class OAuth2GatewayProvider(KeycloakAuthProvider): |
| 24 | + def __init__( |
| 25 | + self, |
| 26 | + settings: Annotated[OAuth2GatewayProviderSettings, Depends(Stub(OAuth2GatewayProviderSettings))], |
| 27 | + unit_of_work: Annotated[UnitOfWork, Depends()], |
| 28 | + ) -> None: |
| 29 | + super().__init__(settings, unit_of_work) # type: ignore[arg-type] |
| 30 | + |
| 31 | + @classmethod |
| 32 | + def setup(cls, app: FastAPI) -> FastAPI: |
| 33 | + settings = OAuth2GatewayProviderSettings.model_validate( |
| 34 | + app.state.settings.auth.model_dump(exclude={"provider"}), |
| 35 | + ) |
| 36 | + log.info("Using %s provider with settings:\n%s", cls.__name__, settings) |
| 37 | + app.dependency_overrides[AuthProvider] = cls |
| 38 | + app.dependency_overrides[OAuth2GatewayProviderSettings] = lambda: settings |
| 39 | + return app |
| 40 | + |
20 | 41 | async def get_current_user(self, access_token: str | None, request: Request) -> User: # noqa: WPS231, WPS217 |
21 | 42 |
|
22 | 43 | if not access_token: |
|
0 commit comments