Skip to content

Commit 2864a29

Browse files
Fixed middleware error handling and handled form and password validation via Pydantic with 422 status code
1 parent 90f0265 commit 2864a29

File tree

7 files changed

+390
-115
lines changed

7 files changed

+390
-115
lines changed

main.py

Lines changed: 76 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1+
# ToDo: Add CSRF protection to all POST, download, and sensitive data routes
2+
13
import logging
24
from typing import Optional
35
from contextlib import asynccontextmanager
46
from fastapi import FastAPI, Request, Depends, status
57
from fastapi.responses import RedirectResponse
68
from fastapi.staticfiles import StaticFiles
79
from fastapi.templating import Jinja2Templates
8-
from fastapi.exceptions import RequestValidationError, StarletteHTTPException, HTTPException
10+
from fastapi.exceptions import RequestValidationError, HTTPException, StarletteHTTPException
911
from sqlmodel import Session
1012
from routers import authentication, organization, role, user
11-
from utils.auth import get_authenticated_user, get_optional_user, NeedsNewTokens, get_user_from_reset_token
13+
from utils.auth import get_authenticated_user, get_optional_user, NeedsNewTokens, get_user_from_reset_token, PasswordValidationError
1214
from utils.db import User, get_session
1315

1416

1517
logger = logging.getLogger("uvicorn.error")
18+
logger.setLevel(logging.DEBUG)
1619

1720

1821
@asynccontextmanager
@@ -31,7 +34,10 @@ async def lifespan(app: FastAPI):
3134
templates = Jinja2Templates(directory="templates")
3235

3336

34-
# Middleware to handle the NeedsNewTokens exception
37+
# -- Exception Handling Middlewares --
38+
39+
40+
# Handle NeedsNewTokens by setting new tokens and redirecting to same page
3541
@app.exception_handler(NeedsNewTokens)
3642
async def needs_new_tokens_handler(request: Request, exc: NeedsNewTokens):
3743
response = RedirectResponse(
@@ -52,28 +58,83 @@ async def needs_new_tokens_handler(request: Request, exc: NeedsNewTokens):
5258
)
5359
return response
5460

55-
# TODO: Make sure this only catches server errors and not 307 redirects
56-
# Create a custom server error class that inherits from StarletteHTTPException?
57-
# @app.exception_handler(StarletteHTTPException)
58-
# async def http_exception_handler(request: Request, exc: StarletteHTTPException):
59-
# return templates.TemplateResponse(
60-
# "errors/error.html",
61-
# {"request": request, "status_code": exc.status_code, "detail": exc.detail},
62-
# status_code=exc.status_code,
63-
# )
61+
62+
# Handle PasswordValidationError by rendering the error page
63+
@app.exception_handler(PasswordValidationError)
64+
async def password_validation_exception_handler(request: Request, exc: PasswordValidationError):
65+
return templates.TemplateResponse(
66+
"errors/validation_error.html",
67+
{
68+
"request": request,
69+
"status_code": 422,
70+
"errors": {exc.detail["field"]: exc.detail["message"]}
71+
},
72+
status_code=422,
73+
)
6474

6575

76+
# Handle RequestValidationError by rendering the error page (TODO: use toast instead?)
6677
@app.exception_handler(RequestValidationError)
6778
async def validation_exception_handler(request: Request, exc: RequestValidationError):
79+
errors = {}
80+
for error in exc.errors():
81+
# Handle different error locations more carefully
82+
location = error["loc"]
83+
84+
# Skip type errors for the whole body
85+
if len(location) == 1 and location[0] == "body":
86+
continue
87+
88+
# For form fields, the location might be just (field_name,)
89+
# For JSON body, it might be (body, field_name)
90+
field_name = location[-1] # Take the last item in the location tuple
91+
errors[field_name] = error["msg"]
92+
6893
return templates.TemplateResponse(
69-
"errors/error.html",
70-
{"request": request, "status_code": 422, "detail": str(exc)},
94+
"errors/validation_error.html",
95+
{
96+
"request": request,
97+
"status_code": 422,
98+
"errors": errors
99+
},
71100
status_code=422,
72101
)
73102

74103

104+
# Handle StarletteHTTPException (including 404, 405, etc.) by rendering the error page
105+
@app.exception_handler(StarletteHTTPException)
106+
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
107+
# Don't handle redirects
108+
if exc.status_code in [301, 302, 303, 307, 308]:
109+
raise exc
110+
111+
return templates.TemplateResponse(
112+
"errors/error.html",
113+
{"request": request, "status_code": exc.status_code, "detail": exc.detail},
114+
status_code=exc.status_code,
115+
)
116+
117+
118+
# Add handler for uncaught exceptions (500 Internal Server Error)
119+
@app.exception_handler(Exception)
120+
async def general_exception_handler(request: Request, exc: Exception):
121+
# Log the error for debugging
122+
logger.error(f"Unhandled exception: {exc}", exc_info=True)
123+
124+
return templates.TemplateResponse(
125+
"errors/error.html",
126+
{
127+
"request": request,
128+
"status_code": 500,
129+
"detail": "Internal Server Error"
130+
},
131+
status_code=500,
132+
)
133+
134+
75135
# -- Unauthenticated Routes --
76136

137+
77138
# Define a dependency for common parameters
78139
async def common_unauthenticated_parameters(
79140
request: Request,
@@ -113,7 +174,7 @@ async def read_register(
113174
@app.get("/forgot_password")
114175
async def read_forgot_password(
115176
params: dict = Depends(common_unauthenticated_parameters),
116-
show_form: Optional[bool] = True,
177+
show_form: Optional[str] = "true",
117178
):
118179
if params["user"]:
119180
return RedirectResponse(url="/dashboard", status_code=302)
@@ -196,7 +257,6 @@ async def read_profile(
196257
app.include_router(role.router)
197258
app.include_router(user.router)
198259

199-
200260
if __name__ == "__main__":
201261
import uvicorn
202262

routers/authentication.py

Lines changed: 100 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
# auth.py
2+
from logging import getLogger
23
from typing import Optional
34
from datetime import datetime
4-
from fastapi import APIRouter, Depends, HTTPException, Form, BackgroundTasks
5+
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks, Form
56
from fastapi.responses import RedirectResponse
6-
from pydantic import BaseModel, EmailStr, ConfigDict
7+
from pydantic import BaseModel, EmailStr, ConfigDict, field_validator
78
from sqlmodel import Session, select
89
from utils.db import User
910
from utils.auth import (
1011
get_session,
11-
validate_password_strength,
1212
get_user_from_reset_token,
13+
create_password_validator,
14+
create_passwords_match_validator,
1315
oauth2_scheme_cookie,
1416
get_password_hash,
1517
verify_password,
@@ -19,14 +21,88 @@
1921
send_reset_email
2022
)
2123

24+
logger = getLogger("uvicorn.error")
25+
2226
router = APIRouter(prefix="/auth", tags=["auth"])
2327

2428

25-
class UserCreate(BaseModel):
29+
# -- Server Request and Response Models --
30+
31+
32+
class UserRegister(BaseModel):
2633
name: str
2734
email: EmailStr
2835
password: str
29-
organization_id: Optional[int] = None
36+
confirm_password: str
37+
38+
validate_password_strength = create_password_validator("password")
39+
validate_passwords_match = create_passwords_match_validator(
40+
"password", "confirm_password")
41+
42+
@classmethod
43+
async def as_form(
44+
cls,
45+
name: str = Form(...),
46+
email: EmailStr = Form(...),
47+
password: str = Form(...),
48+
confirm_password: str = Form(...)
49+
):
50+
return cls(
51+
name=name,
52+
email=email,
53+
password=password,
54+
confirm_password=confirm_password
55+
)
56+
57+
58+
class UserLogin(BaseModel):
59+
email: EmailStr
60+
password: str
61+
62+
@classmethod
63+
async def as_form(
64+
cls,
65+
email: EmailStr = Form(...),
66+
password: str = Form(...)
67+
):
68+
return cls(email=email, password=password)
69+
70+
71+
class UserForgotPassword(BaseModel):
72+
email: EmailStr
73+
74+
@classmethod
75+
async def as_form(
76+
cls,
77+
email: EmailStr = Form(...)
78+
):
79+
return cls(email=email)
80+
81+
82+
class UserResetPassword(BaseModel):
83+
email: EmailStr
84+
token: str
85+
new_password: str
86+
confirm_new_password: str
87+
88+
# Use the factory with a different field name
89+
validate_password_strength = create_password_validator("new_password")
90+
validate_passwords_match = create_passwords_match_validator(
91+
"new_password", "confirm_new_password")
92+
93+
@classmethod
94+
async def as_form(
95+
cls,
96+
email: EmailStr = Form(...),
97+
token: str = Form(...),
98+
new_password: str = Form(...),
99+
confirm_new_password: str = Form(...)
100+
):
101+
return cls(email=email, token=token,
102+
new_password=new_password, confirm_new_password=confirm_new_password)
103+
104+
105+
# -- DB Request and Response Models --
30106

31107

32108
class UserRead(BaseModel):
@@ -41,31 +117,14 @@ class UserRead(BaseModel):
41117
deleted: bool
42118

43119

44-
class UserForgotPassword(BaseModel):
45-
email: EmailStr
46-
47-
48-
class UserResetPassword(BaseModel):
49-
token: str
50-
new_password: str
120+
# -- Routes --
51121

52122

53123
@router.post("/register", response_class=RedirectResponse)
54124
async def register(
55-
name: str = Form(...),
56-
email: EmailStr = Form(...),
57-
password: str = Form(...),
58-
confirm_password: str = Form(...),
125+
user: UserRegister = Depends(UserRegister.as_form),
59126
session: Session = Depends(get_session),
60127
) -> RedirectResponse:
61-
if password != confirm_password:
62-
raise HTTPException(status_code=400, detail="Passwords do not match")
63-
64-
if not validate_password_strength(password):
65-
raise HTTPException(
66-
status_code=400, detail="Password does not satisfy the security policy")
67-
68-
user = UserCreate(name=name, email=email, password=password)
69128
db_user = session.exec(select(User).where(
70129
User.email == user.email)).first()
71130

@@ -92,13 +151,13 @@ async def register(
92151

93152

94153
@router.post("/login", response_class=RedirectResponse)
95-
def login(
96-
email: str = Form(...),
97-
password: str = Form(...),
154+
async def login(
155+
user: UserLogin = Depends(UserLogin.as_form),
98156
session: Session = Depends(get_session),
99157
) -> RedirectResponse:
100-
db_user = session.exec(select(User).where(User.email == email)).first()
101-
if not db_user or not verify_password(password, db_user.hashed_password):
158+
db_user = session.exec(select(User).where(
159+
User.email == user.email)).first()
160+
if not db_user or not verify_password(user.password, db_user.hashed_password):
102161
raise HTTPException(status_code=400, detail="Invalid credentials")
103162

104163
# Create access token
@@ -128,7 +187,7 @@ def login(
128187

129188
# Updated refresh_token endpoint
130189
@router.post("/refresh", response_class=RedirectResponse)
131-
def refresh_token(
190+
async def refresh_token(
132191
tokens: tuple[Optional[str], Optional[str]
133192
] = Depends(oauth2_scheme_cookie),
134193
session: Session = Depends(get_session),
@@ -173,46 +232,35 @@ def refresh_token(
173232
return response
174233

175234

176-
class EmailSchema(BaseModel):
177-
email: EmailStr
178-
179-
180-
class ResetSchema(BaseModel):
181-
token: str
182-
new_password: str
183-
184-
185235
@router.post("/forgot_password")
186-
def forgot_password(user: UserForgotPassword, background_tasks: BackgroundTasks, session: Session = Depends(get_session)):
236+
async def forgot_password(
237+
background_tasks: BackgroundTasks,
238+
user: UserForgotPassword = Depends(UserForgotPassword.as_form),
239+
session: Session = Depends(get_session)
240+
):
187241
db_user = session.exec(select(User).where(
188242
User.email == user.email)).first()
189243

190244
# TODO: Handle this in background task so we don't leak information via timing attacks
191245
if db_user:
192246
background_tasks.add_task(send_reset_email, user.email, session)
193247

194-
return RedirectResponse(url="/forgot_password", status_code=303, show_form=False)
248+
return RedirectResponse(url="/forgot_password?show_form=false", status_code=303)
195249

196250

197251
@router.post("/reset_password")
198-
def reset_password(
199-
email: str, token: str, new_password: str, confirm_new_password: str, session: Session = Depends(get_session)
252+
async def reset_password(
253+
user: UserResetPassword = Depends(UserResetPassword.as_form),
254+
session: Session = Depends(get_session)
200255
):
201-
if new_password != confirm_new_password:
202-
raise HTTPException(status_code=400, detail="Passwords do not match")
203-
204-
if not validate_password_strength(new_password):
205-
raise HTTPException(
206-
status_code=400, detail="Password does not satisfy the security policy")
207-
208256
authorized_user, reset_token = get_user_from_reset_token(
209-
email, token, session)
257+
user.email, user.token, session)
210258

211259
if not authorized_user:
212260
raise HTTPException(status_code=400, detail="Invalid or expired token")
213261

214262
# Update password and mark token as used
215-
authorized_user.hashed_password = get_password_hash(new_password)
263+
authorized_user.hashed_password = get_password_hash(user.new_password)
216264
reset_token.used = True
217265
session.commit()
218266
session.refresh(authorized_user)

0 commit comments

Comments
 (0)