1
1
# auth.py
2
+ from logging import getLogger
2
3
from typing import Optional
3
4
from datetime import datetime
4
- from fastapi import APIRouter , Depends , HTTPException , Form , BackgroundTasks
5
+ from fastapi import APIRouter , Depends , HTTPException , BackgroundTasks , Form
5
6
from fastapi .responses import RedirectResponse
6
- from pydantic import BaseModel , EmailStr , ConfigDict
7
+ from pydantic import BaseModel , EmailStr , ConfigDict , field_validator
7
8
from sqlmodel import Session , select
8
9
from utils .db import User
9
10
from utils .auth import (
10
11
get_session ,
11
- validate_password_strength ,
12
12
get_user_from_reset_token ,
13
+ create_password_validator ,
14
+ create_passwords_match_validator ,
13
15
oauth2_scheme_cookie ,
14
16
get_password_hash ,
15
17
verify_password ,
19
21
send_reset_email
20
22
)
21
23
24
+ logger = getLogger ("uvicorn.error" )
25
+
22
26
router = APIRouter (prefix = "/auth" , tags = ["auth" ])
23
27
24
28
25
- class UserCreate (BaseModel ):
29
+ # -- Server Request and Response Models --
30
+
31
+
32
+ class UserRegister (BaseModel ):
26
33
name : str
27
34
email : EmailStr
28
35
password : str
29
- organization_id : Optional [int ] = None
36
+ confirm_password : str
37
+
38
+ validate_password_strength = create_password_validator ("password" )
39
+ validate_passwords_match = create_passwords_match_validator (
40
+ "password" , "confirm_password" )
41
+
42
+ @classmethod
43
+ async def as_form (
44
+ cls ,
45
+ name : str = Form (...),
46
+ email : EmailStr = Form (...),
47
+ password : str = Form (...),
48
+ confirm_password : str = Form (...)
49
+ ):
50
+ return cls (
51
+ name = name ,
52
+ email = email ,
53
+ password = password ,
54
+ confirm_password = confirm_password
55
+ )
56
+
57
+
58
+ class UserLogin (BaseModel ):
59
+ email : EmailStr
60
+ password : str
61
+
62
+ @classmethod
63
+ async def as_form (
64
+ cls ,
65
+ email : EmailStr = Form (...),
66
+ password : str = Form (...)
67
+ ):
68
+ return cls (email = email , password = password )
69
+
70
+
71
+ class UserForgotPassword (BaseModel ):
72
+ email : EmailStr
73
+
74
+ @classmethod
75
+ async def as_form (
76
+ cls ,
77
+ email : EmailStr = Form (...)
78
+ ):
79
+ return cls (email = email )
80
+
81
+
82
+ class UserResetPassword (BaseModel ):
83
+ email : EmailStr
84
+ token : str
85
+ new_password : str
86
+ confirm_new_password : str
87
+
88
+ # Use the factory with a different field name
89
+ validate_password_strength = create_password_validator ("new_password" )
90
+ validate_passwords_match = create_passwords_match_validator (
91
+ "new_password" , "confirm_new_password" )
92
+
93
+ @classmethod
94
+ async def as_form (
95
+ cls ,
96
+ email : EmailStr = Form (...),
97
+ token : str = Form (...),
98
+ new_password : str = Form (...),
99
+ confirm_new_password : str = Form (...)
100
+ ):
101
+ return cls (email = email , token = token ,
102
+ new_password = new_password , confirm_new_password = confirm_new_password )
103
+
104
+
105
+ # -- DB Request and Response Models --
30
106
31
107
32
108
class UserRead (BaseModel ):
@@ -41,31 +117,14 @@ class UserRead(BaseModel):
41
117
deleted : bool
42
118
43
119
44
- class UserForgotPassword (BaseModel ):
45
- email : EmailStr
46
-
47
-
48
- class UserResetPassword (BaseModel ):
49
- token : str
50
- new_password : str
120
+ # -- Routes --
51
121
52
122
53
123
@router .post ("/register" , response_class = RedirectResponse )
54
124
async def register (
55
- name : str = Form (...),
56
- email : EmailStr = Form (...),
57
- password : str = Form (...),
58
- confirm_password : str = Form (...),
125
+ user : UserRegister = Depends (UserRegister .as_form ),
59
126
session : Session = Depends (get_session ),
60
127
) -> RedirectResponse :
61
- if password != confirm_password :
62
- raise HTTPException (status_code = 400 , detail = "Passwords do not match" )
63
-
64
- if not validate_password_strength (password ):
65
- raise HTTPException (
66
- status_code = 400 , detail = "Password does not satisfy the security policy" )
67
-
68
- user = UserCreate (name = name , email = email , password = password )
69
128
db_user = session .exec (select (User ).where (
70
129
User .email == user .email )).first ()
71
130
@@ -92,13 +151,13 @@ async def register(
92
151
93
152
94
153
@router .post ("/login" , response_class = RedirectResponse )
95
- def login (
96
- email : str = Form (...),
97
- password : str = Form (...),
154
+ async def login (
155
+ user : UserLogin = Depends (UserLogin .as_form ),
98
156
session : Session = Depends (get_session ),
99
157
) -> RedirectResponse :
100
- db_user = session .exec (select (User ).where (User .email == email )).first ()
101
- if not db_user or not verify_password (password , db_user .hashed_password ):
158
+ db_user = session .exec (select (User ).where (
159
+ User .email == user .email )).first ()
160
+ if not db_user or not verify_password (user .password , db_user .hashed_password ):
102
161
raise HTTPException (status_code = 400 , detail = "Invalid credentials" )
103
162
104
163
# Create access token
@@ -128,7 +187,7 @@ def login(
128
187
129
188
# Updated refresh_token endpoint
130
189
@router .post ("/refresh" , response_class = RedirectResponse )
131
- def refresh_token (
190
+ async def refresh_token (
132
191
tokens : tuple [Optional [str ], Optional [str ]
133
192
] = Depends (oauth2_scheme_cookie ),
134
193
session : Session = Depends (get_session ),
@@ -173,46 +232,35 @@ def refresh_token(
173
232
return response
174
233
175
234
176
- class EmailSchema (BaseModel ):
177
- email : EmailStr
178
-
179
-
180
- class ResetSchema (BaseModel ):
181
- token : str
182
- new_password : str
183
-
184
-
185
235
@router .post ("/forgot_password" )
186
- def forgot_password (user : UserForgotPassword , background_tasks : BackgroundTasks , session : Session = Depends (get_session )):
236
+ async def forgot_password (
237
+ background_tasks : BackgroundTasks ,
238
+ user : UserForgotPassword = Depends (UserForgotPassword .as_form ),
239
+ session : Session = Depends (get_session )
240
+ ):
187
241
db_user = session .exec (select (User ).where (
188
242
User .email == user .email )).first ()
189
243
190
244
# TODO: Handle this in background task so we don't leak information via timing attacks
191
245
if db_user :
192
246
background_tasks .add_task (send_reset_email , user .email , session )
193
247
194
- return RedirectResponse (url = "/forgot_password" , status_code = 303 , show_form = False )
248
+ return RedirectResponse (url = "/forgot_password?show_form=false " , status_code = 303 )
195
249
196
250
197
251
@router .post ("/reset_password" )
198
- def reset_password (
199
- email : str , token : str , new_password : str , confirm_new_password : str , session : Session = Depends (get_session )
252
+ async def reset_password (
253
+ user : UserResetPassword = Depends (UserResetPassword .as_form ),
254
+ session : Session = Depends (get_session )
200
255
):
201
- if new_password != confirm_new_password :
202
- raise HTTPException (status_code = 400 , detail = "Passwords do not match" )
203
-
204
- if not validate_password_strength (new_password ):
205
- raise HTTPException (
206
- status_code = 400 , detail = "Password does not satisfy the security policy" )
207
-
208
256
authorized_user , reset_token = get_user_from_reset_token (
209
- email , token , session )
257
+ user . email , user . token , session )
210
258
211
259
if not authorized_user :
212
260
raise HTTPException (status_code = 400 , detail = "Invalid or expired token" )
213
261
214
262
# Update password and mark token as used
215
- authorized_user .hashed_password = get_password_hash (new_password )
263
+ authorized_user .hashed_password = get_password_hash (user . new_password )
216
264
reset_token .used = True
217
265
session .commit ()
218
266
session .refresh (authorized_user )
0 commit comments