@@ -81,11 +81,10 @@ async def as_form(
81
81
82
82
class UserResetPassword (BaseModel ):
83
83
email : EmailStr
84
- token : str
84
+ token : Optional [ str ]
85
85
new_password : str
86
86
confirm_new_password : str
87
87
88
- # Use the factory with a different field name
89
88
validate_password_strength = create_password_validator ("new_password" )
90
89
validate_passwords_match = create_passwords_match_validator (
91
90
"new_password" , "confirm_new_password" )
@@ -94,12 +93,16 @@ class UserResetPassword(BaseModel):
94
93
async def as_form (
95
94
cls ,
96
95
email : EmailStr = Form (...),
97
- token : str = Form (... ),
96
+ token : str = Form (None ),
98
97
new_password : str = Form (...),
99
98
confirm_new_password : str = Form (...)
100
99
):
101
- return cls (email = email , token = token ,
102
- new_password = new_password , confirm_new_password = confirm_new_password )
100
+ return cls (
101
+ email = email ,
102
+ token = token ,
103
+ new_password = new_password ,
104
+ confirm_new_password = confirm_new_password
105
+ )
103
106
104
107
105
108
# --- DB Request and Response Models ---
@@ -256,8 +259,39 @@ async def forgot_password(
256
259
@router .post ("/reset_password" )
257
260
async def reset_password (
258
261
user : UserResetPassword = Depends (UserResetPassword .as_form ),
262
+ tokens : tuple [Optional [str ], Optional [str ]] = Depends (oauth2_scheme_cookie ),
259
263
session : Session = Depends (get_session )
260
264
):
265
+ access_token , _ = tokens
266
+
267
+ # Handle authenticated user
268
+ if access_token :
269
+ try :
270
+ decoded_token = validate_token (access_token )
271
+ if decoded_token and decoded_token .get ("sub" ) == user .email :
272
+ # User is authenticated and changing their own password
273
+ db_user = session .exec (select (User ).where (
274
+ User .email == user .email )).first ()
275
+ if not db_user :
276
+ raise HTTPException (status_code = 404 , detail = "User not found" )
277
+
278
+ # Update password
279
+ if db_user .password :
280
+ db_user .password .hashed_password = get_password_hash (user .new_password )
281
+ else :
282
+ db_user .password = UserPassword (
283
+ hashed_password = get_password_hash (user .new_password )
284
+ )
285
+ session .commit ()
286
+ return RedirectResponse (url = "/settings" , status_code = 303 )
287
+
288
+ except Exception as e :
289
+ logger .error (f"Error validating token: { e } " )
290
+
291
+ # Handle unauthenticated user with reset token
292
+ if not user .token :
293
+ raise HTTPException (status_code = 400 , detail = "Reset token required for unauthenticated password reset" )
294
+
261
295
authorized_user , reset_token = get_user_from_reset_token (
262
296
user .email , user .token , session )
263
297
@@ -270,16 +304,13 @@ async def reset_password(
270
304
user .new_password
271
305
)
272
306
else :
273
- logger .warning (
274
- "User password not found during password reset; creating new password for user" )
275
307
authorized_user .password = UserPassword (
276
308
hashed_password = get_password_hash (user .new_password )
277
309
)
278
310
279
311
reset_token .used = True
280
312
session .commit ()
281
- session .refresh (authorized_user )
282
-
313
+
283
314
return RedirectResponse (url = "/login" , status_code = 303 )
284
315
285
316
0 commit comments