|
1 | 1 | # SPDX-FileCopyrightText: 2023-2024 MTS PJSC |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 | import logging |
4 | | -from typing import Annotated, Any |
| 4 | +from typing import Annotated, Any, NoReturn |
5 | 5 |
|
6 | 6 | from fastapi import Depends, FastAPI, Request |
7 | | -from keycloak import KeycloakOpenID |
| 7 | +from jwcrypto.common import JWException |
| 8 | +from keycloak import KeycloakOpenID, KeycloakOperationError |
8 | 9 |
|
| 10 | +from syncmaster.db.models import User |
9 | 11 | from syncmaster.exceptions import EntityNotFoundError |
10 | 12 | from syncmaster.exceptions.auth import AuthorizationError |
11 | 13 | from syncmaster.exceptions.redirect import RedirectException |
@@ -63,83 +65,75 @@ async def get_token_authorization_code_grant( |
63 | 65 | ) -> dict[str, Any]: |
64 | 66 | try: |
65 | 67 | redirect_uri = redirect_uri or self.settings.keycloak.redirect_uri |
66 | | - token = self.keycloak_openid.token( |
| 68 | + token = await self.keycloak_openid.a_token( |
67 | 69 | grant_type="authorization_code", |
68 | 70 | code=code, |
69 | 71 | redirect_uri=redirect_uri, |
70 | 72 | ) |
71 | 73 | return token |
72 | | - except Exception as e: |
| 74 | + except KeycloakOperationError as e: |
73 | 75 | raise AuthorizationError("Failed to get token") from e |
74 | 76 |
|
75 | | - async def get_current_user(self, access_token: str, *args, **kwargs) -> Any: |
| 77 | + async def get_current_user(self, access_token: str | None, **kwargs) -> User: |
76 | 78 | request: Request = kwargs["request"] |
77 | | - refresh_token = request.session.get("refresh_token") |
78 | | - |
79 | 79 | if not access_token: |
80 | 80 | log.debug("No access token found in session.") |
81 | | - self.redirect_to_auth(request.url.path) |
| 81 | + await self.redirect_to_auth(request.url.path) |
82 | 82 |
|
83 | 83 | try: |
84 | 84 | # if user is disabled or blocked in Keycloak after the token is issued, he will |
85 | 85 | # remain authorized until the token expires (not more than 15 minutes in MTS SSO) |
86 | | - token_info = self.keycloak_openid.decode_token(token=access_token) |
87 | | - except Exception as e: |
| 86 | + token_info = await self.keycloak_openid.a_decode_token(token=access_token) |
| 87 | + except (KeycloakOperationError, JWException) as e: |
88 | 88 | log.info("Access token is invalid or expired: %s", e) |
89 | | - token_info = None |
| 89 | + token_info = {} |
90 | 90 |
|
| 91 | + refresh_token = request.session.get("refresh_token") |
91 | 92 | if not token_info and refresh_token: |
92 | 93 | log.debug("Access token invalid. Attempting to refresh.") |
93 | 94 |
|
94 | 95 | try: |
95 | | - new_tokens = await self.refresh_access_token(refresh_token) |
| 96 | + new_tokens = await self.keycloak_openid.a_refresh_token(refresh_token) |
96 | 97 |
|
97 | 98 | new_access_token = new_tokens.get("access_token") |
98 | 99 | new_refresh_token = new_tokens.get("refresh_token") |
99 | 100 | request.session["access_token"] = new_access_token |
100 | 101 | request.session["refresh_token"] = new_refresh_token |
101 | 102 |
|
102 | | - token_info = self.keycloak_openid.decode_token( |
103 | | - token=new_access_token, |
104 | | - ) |
| 103 | + token_info = await self.keycloak_openid.a_decode_token(token=new_access_token) |
105 | 104 | log.debug("Access token refreshed and decoded successfully.") |
106 | | - except Exception as e: |
| 105 | + except (KeycloakOperationError, JWException) as e: |
107 | 106 | log.debug("Failed to refresh access token: %s", e) |
108 | | - self.redirect_to_auth(request.url.path) |
| 107 | + await self.redirect_to_auth(request.url.path) |
109 | 108 |
|
110 | 109 | # these names are hardcoded in keycloak: |
111 | 110 | # https://github.com/keycloak/keycloak/blob/3ca3a4ad349b4d457f6829eaf2ae05f1e01408be/core/src/main/java/org/keycloak/representations/IDToken.java |
112 | 111 | user_id = token_info.get("sub") |
| 112 | + if not user_id: |
| 113 | + raise AuthorizationError("Invalid token payload") |
| 114 | + |
113 | 115 | login = token_info.get("preferred_username") |
114 | 116 | email = token_info.get("email") |
115 | 117 | first_name = token_info.get("given_name") |
116 | 118 | middle_name = token_info.get("middle_name") |
117 | 119 | last_name = token_info.get("family_name") |
118 | 120 |
|
119 | | - if not user_id: |
120 | | - raise AuthorizationError("Invalid token payload") |
121 | | - |
122 | | - async with self._uow: |
123 | | - try: |
124 | | - user = await self._uow.user.read_by_username(login) |
125 | | - except EntityNotFoundError: |
126 | | - user = await self._uow.user.create( |
| 121 | + try: |
| 122 | + return await self._uow.user.read_by_username(login) |
| 123 | + except EntityNotFoundError: |
| 124 | + async with self._uow: |
| 125 | + return await self._uow.user.create( |
127 | 126 | username=login, |
128 | 127 | email=email, |
129 | 128 | first_name=first_name, |
130 | 129 | middle_name=middle_name, |
131 | 130 | last_name=last_name, |
132 | 131 | is_active=True, |
133 | 132 | ) |
134 | | - return user |
135 | | - |
136 | | - async def refresh_access_token(self, refresh_token: str) -> dict[str, Any]: |
137 | | - new_tokens = self.keycloak_openid.refresh_token(refresh_token) |
138 | | - return new_tokens |
139 | 133 |
|
140 | | - def redirect_to_auth(self, path: str) -> None: |
| 134 | + async def redirect_to_auth(self, path: str) -> NoReturn: |
141 | 135 | state = generate_state(path) |
142 | | - auth_url = self.keycloak_openid.auth_url( |
| 136 | + auth_url = await self.keycloak_openid.a_auth_url( |
143 | 137 | redirect_uri=self.settings.keycloak.redirect_uri, |
144 | 138 | scope=self.settings.keycloak.scope, |
145 | 139 | state=state, |
|
0 commit comments