12
12
from datetime import UTC , datetime , timedelta
13
13
from typing import Optional
14
14
from fastapi import Depends , Cookie , HTTPException , status
15
+ from fastapi .responses import RedirectResponse
15
16
from utils .db import get_session
16
17
from utils .models import User , PasswordResetToken
17
18
@@ -180,7 +181,8 @@ def validate_token_and_get_user(
180
181
if decoded_token :
181
182
user_email = decoded_token .get ("sub" )
182
183
user = session .exec (select (User ).where (
183
- User .email == user_email )).first ()
184
+ User .email == user_email
185
+ )).first ()
184
186
if user :
185
187
if token_type == "refresh" :
186
188
new_access_token = create_access_token (
@@ -215,6 +217,14 @@ def get_user_from_tokens(
215
217
return None , None , None
216
218
217
219
220
+ class AuthenticationError (HTTPException ):
221
+ def __init__ (self ):
222
+ super ().__init__ (
223
+ status_code = status .HTTP_303_SEE_OTHER ,
224
+ headers = {"Location" : "/login" }
225
+ )
226
+
227
+
218
228
def get_authenticated_user (
219
229
tokens : tuple [Optional [str ], Optional [str ]
220
230
] = Depends (oauth2_scheme_cookie ),
@@ -228,11 +238,7 @@ def get_authenticated_user(
228
238
raise NeedsNewTokens (user , new_access_token , new_refresh_token )
229
239
return user
230
240
231
- # If both tokens are invalid or missing, redirect to login
232
- raise HTTPException (
233
- status_code = status .HTTP_307_TEMPORARY_REDIRECT ,
234
- headers = {"Location" : "/login" }
235
- )
241
+ raise AuthenticationError ()
236
242
237
243
238
244
def get_optional_user (
@@ -275,7 +281,9 @@ def generate_password_reset_url(email: str, token: str) -> str:
275
281
276
282
def send_reset_email (email : str , session : Session ):
277
283
# Check for an existing unexpired token
278
- user = session .exec (select (User ).where (User .email == email )).first ()
284
+ user = session .exec (select (User ).where (
285
+ User .email == email
286
+ )).first ()
279
287
if user :
280
288
existing_token = session .exec (
281
289
select (PasswordResetToken )
@@ -316,18 +324,19 @@ def send_reset_email(email: str, session: Session):
316
324
317
325
318
326
def get_user_from_reset_token (email : str , token : str , session : Session ) -> tuple [Optional [User ], Optional [PasswordResetToken ]]:
319
- reset_token = session .exec (select (PasswordResetToken ).where (
320
- PasswordResetToken .token == token ,
321
- PasswordResetToken .expires_at > datetime .now (UTC ),
322
- PasswordResetToken .used == False
323
- )).first ()
327
+ result = session .exec (
328
+ select (User , PasswordResetToken )
329
+ .where (
330
+ User .email == email ,
331
+ PasswordResetToken .token == token ,
332
+ PasswordResetToken .expires_at > datetime .now (UTC ),
333
+ PasswordResetToken .used == False ,
334
+ PasswordResetToken .user_id == User .id
335
+ )
336
+ ).first ()
324
337
325
- if not reset_token :
338
+ if not result :
326
339
return None , None
327
340
328
- user = session .exec (select (User ).where (
329
- User .email == email ,
330
- User .id == reset_token .user_id
331
- )).first ()
332
-
341
+ user , reset_token = result
333
342
return user , reset_token
0 commit comments