diff --git a/main.py b/main.py index ec8fe93..dc65574 100644 --- a/main.py +++ b/main.py @@ -33,7 +33,7 @@ async def lifespan(app: FastAPI): # Optional shutdown logic -app = FastAPI(lifespan=lifespan) +app: FastAPI = FastAPI(lifespan=lifespan) # Mount static files (e.g., CSS, JS) app.mount("/static", StaticFiles(directory="static"), name="static") @@ -169,10 +169,12 @@ async def read_home( @app.get("/login") async def read_login( - params: dict = Depends(common_unauthenticated_parameters) + params: dict = Depends(common_unauthenticated_parameters), + email_updated: Optional[str] = "false" ): if params["user"]: return RedirectResponse(url="/dashboard", status_code=302) + params["email_updated"] = email_updated return templates.TemplateResponse(params["request"], "authentication/login.html", params) @@ -256,14 +258,18 @@ async def read_dashboard( @app.get("/profile") async def read_profile( - params: dict = Depends(common_authenticated_parameters) + params: dict = Depends(common_authenticated_parameters), + email_update_requested: Optional[str] = "false", + email_updated: Optional[str] = "false" ): # Add image constraints to the template context params.update({ "max_file_size_mb": MAX_FILE_SIZE / (1024 * 1024), # Convert bytes to MB "min_dimension": MIN_DIMENSION, "max_dimension": MAX_DIMENSION, - "allowed_formats": list(ALLOWED_CONTENT_TYPES.keys()) + "allowed_formats": list(ALLOWED_CONTENT_TYPES.keys()), + "email_update_requested": email_update_requested, + "email_updated": email_updated }) return templates.TemplateResponse(params["request"], "users/profile.html", params) diff --git a/routers/authentication.py b/routers/authentication.py index 7583ec1..3b3d544 100644 --- a/routers/authentication.py +++ b/routers/authentication.py @@ -6,7 +6,7 @@ from fastapi.responses import RedirectResponse from pydantic import BaseModel, EmailStr, ConfigDict from sqlmodel import Session, select -from utils.models import User, UserPassword +from utils.models import User, UserPassword, DataIntegrityError from utils.auth import ( get_session, get_user_from_reset_token, @@ -18,13 +18,50 @@ create_access_token, create_refresh_token, validate_token, - send_reset_email + send_reset_email, + send_email_update_confirmation, + get_user_from_email_update_token, + get_authenticated_user ) logger = getLogger("uvicorn.error") router = APIRouter(prefix="/auth", tags=["auth"]) +# --- Custom Exceptions --- + + +class EmailAlreadyRegisteredError(HTTPException): + def __init__(self): + super().__init__( + status_code=409, + detail="This email is already registered" + ) + + +class InvalidCredentialsError(HTTPException): + def __init__(self): + super().__init__( + status_code=401, + detail="Invalid credentials" + ) + + +class InvalidResetTokenError(HTTPException): + def __init__(self): + super().__init__( + status_code=401, + detail="Invalid or expired password reset token; please request a new one" + ) + + +class InvalidEmailUpdateTokenError(HTTPException): + def __init__(self): + super().__init__( + status_code=401, + detail="Invalid or expired email update token; please request a new one" + ) + # --- Server Request and Response Models --- @@ -102,6 +139,17 @@ async def as_form( new_password=new_password, confirm_new_password=confirm_new_password) +class UpdateEmail(BaseModel): + new_email: EmailStr + + @classmethod + async def as_form( + cls, + new_email: EmailStr = Form(...) + ): + return cls(new_email=new_email) + + # --- DB Request and Response Models --- @@ -130,7 +178,7 @@ async def register( User.email == user.email)).first() if db_user: - raise HTTPException(status_code=400, detail="Email already registered") + raise EmailAlreadyRegisteredError() # Hash the password hashed_password = get_password_hash(user.password) @@ -147,9 +195,20 @@ async def register( refresh_token = create_refresh_token(data={"sub": db_user.email}) # Set cookie response = RedirectResponse(url="/", status_code=303) - response.set_cookie(key="access_token", value=access_token, httponly=True) - response.set_cookie(key="refresh_token", - value=refresh_token, httponly=True) + response.set_cookie( + key="access_token", + value=access_token, + httponly=True, + secure=True, + samesite="strict" + ) + response.set_cookie( + key="refresh_token", + value=refresh_token, + httponly=True, + secure=True, + samesite="strict" + ) return response @@ -164,7 +223,7 @@ async def login( User.email == user.email)).first() if not db_user or not db_user.password or not verify_password(user.password, db_user.password.hashed_password): - raise HTTPException(status_code=400, detail="Invalid credentials") + raise InvalidCredentialsError() # Create access token access_token = create_access_token( @@ -262,7 +321,7 @@ async def reset_password( user.email, user.token, session) if not authorized_user or not reset_token: - raise HTTPException(status_code=400, detail="Invalid or expired token") + raise InvalidResetTokenError() # Update password and mark token as used if authorized_user.password: @@ -289,3 +348,81 @@ def logout(): response.delete_cookie("access_token") response.delete_cookie("refresh_token") return response + + +@router.post("/update_email") +async def request_email_update( + update: UpdateEmail = Depends(UpdateEmail.as_form), + user: User = Depends(get_authenticated_user), + session: Session = Depends(get_session) +): + # Check if the new email is already registered + existing_user = session.exec( + select(User).where(User.email == update.new_email) + ).first() + + if existing_user: + raise EmailAlreadyRegisteredError() + + if not user.id: + raise DataIntegrityError(resource="User id") + + # Send confirmation email + send_email_update_confirmation( + current_email=user.email, + new_email=update.new_email, + user_id=user.id, + session=session + ) + + return RedirectResponse( + url="/profile?email_update_requested=true", + status_code=303 + ) + + +@router.get("/confirm_email_update") +async def confirm_email_update( + user_id: int, + token: str, + new_email: str, + session: Session = Depends(get_session) +): + user, update_token = get_user_from_email_update_token( + user_id, token, session + ) + + if not user or not update_token: + raise InvalidResetTokenError() + + # Update email and mark token as used + user.email = new_email + update_token.used = True + session.commit() + + # Create new tokens with the updated email + access_token = create_access_token(data={"sub": new_email, "fresh": True}) + refresh_token = create_refresh_token(data={"sub": new_email}) + + # Set cookies before redirecting + response = RedirectResponse( + url="/profile?email_updated=true", + status_code=303 + ) + + # Add secure cookie attributes + response.set_cookie( + key="access_token", + value=access_token, + httponly=True, + secure=True, + samesite="lax" + ) + response.set_cookie( + key="refresh_token", + value=refresh_token, + httponly=True, + secure=True, + samesite="lax" + ) + return response diff --git a/routers/organization.py b/routers/organization.py index c937ae3..afa6e5f 100644 --- a/routers/organization.py +++ b/routers/organization.py @@ -10,6 +10,8 @@ logger = getLogger("uvicorn.error") +router = APIRouter(prefix="/organizations", tags=["organizations"]) + # --- Custom Exceptions --- @@ -37,9 +39,6 @@ def __init__(self): ) -router = APIRouter(prefix="/organizations", tags=["organizations"]) - - # --- Server Request and Response Models --- diff --git a/routers/user.py b/routers/user.py index 1baf82a..135f355 100644 --- a/routers/user.py +++ b/routers/user.py @@ -16,7 +16,6 @@ class UpdateProfile(BaseModel): """Request model for updating user profile information""" name: str - email: EmailStr avatar_file: Optional[bytes] = None avatar_content_type: Optional[str] = None @@ -24,7 +23,6 @@ class UpdateProfile(BaseModel): async def as_form( cls, name: str = Form(...), - email: EmailStr = Form(...), avatar_file: Optional[UploadFile] = File(None), ): avatar_data = None @@ -36,7 +34,6 @@ async def as_form( return cls( name=name, - email=email, avatar_file=avatar_data, avatar_content_type=avatar_content_type ) @@ -73,7 +70,6 @@ async def update_profile( # Update user details user.name = user_profile.name - user.email = user_profile.email if user_profile.avatar_file: user.avatar_data = user_profile.avatar_file diff --git a/templates/components/header.html b/templates/components/header.html index d24bff6..6f8d04c 100644 --- a/templates/components/header.html +++ b/templates/components/header.html @@ -24,7 +24,11 @@