diff --git a/src/auth/crud.py b/src/auth/crud.py index eeda7a7..0c447a2 100644 --- a/src/auth/crud.py +++ b/src/auth/crud.py @@ -5,7 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from auth.tables import SiteUser, UserSession -from auth.types import SiteUserData _logger = logging.getLogger(__name__) @@ -52,7 +51,7 @@ async def create_user_session(db_session: AsyncSession, session_id: str, computi )) -async def remove_user_session(db_session: AsyncSession, session_id: str) -> dict: +async def remove_user_session(db_session: AsyncSession, session_id: str): query = sqlalchemy.select(UserSession).where(UserSession.session_id == session_id) user_session = await db_session.scalars(query) await db_session.delete(user_session.first()) @@ -74,7 +73,7 @@ async def task_clean_expired_user_sessions(db_session: AsyncSession): # get the site user given a session ID; returns None when session is invalid -async def get_site_user(db_session: AsyncSession, session_id: str) -> None | SiteUserData: +async def get_site_user(db_session: AsyncSession, session_id: str) -> SiteUser | None: query = ( sqlalchemy .select(UserSession) @@ -89,17 +88,7 @@ async def get_site_user(db_session: AsyncSession, session_id: str) -> None | Sit .select(SiteUser) .where(SiteUser.computing_id == user_session.computing_id) ) - user = await db_session.scalar(query) - if user is None: - return None - - return SiteUserData( - user_session.computing_id, - user.first_logged_in.isoformat(), - user.last_logged_in.isoformat(), - user.profile_picture_url - ) - + return await db_session.scalar(query) async def site_user_exists(db_session: AsyncSession, computing_id: str) -> bool: user = await db_session.scalar( diff --git a/src/auth/models.py b/src/auth/models.py index f342468..a428586 100644 --- a/src/auth/models.py +++ b/src/auth/models.py @@ -1,7 +1,23 @@ +from datetime import datetime + from pydantic import BaseModel, Field -class LoginBodyModel(BaseModel): +class LoginBodyParams(BaseModel): service: str = Field(description="Service URL used for SFU's CAS system") ticket: str = Field(description="Ticket return from SFU's CAS system") redirect_url: str | None = Field(None, description="Optional redirect URL") + +class UpdateUserParams(BaseModel): + profile_picture_url: str + +class UserSessionModel(BaseModel): + computing_id: str + issue_time: datetime + session_id: str + +class SiteUserModel(BaseModel): + computing_id: str + first_logged_in: datetime + last_logged_in: datetime + profile_picture_url: str | None = None diff --git a/src/auth/tables.py b/src/auth/tables.py index 9599018..b6ffe07 100644 --- a/src/auth/tables.py +++ b/src/auth/tables.py @@ -1,6 +1,7 @@ from datetime import datetime -from sqlalchemy import Column, DateTime, ForeignKey, String, Text +from sqlalchemy import DateTime, ForeignKey, String, Text, func +from sqlalchemy.orm import Mapped, mapped_column from constants import COMPUTING_ID_LEN, SESSION_ID_LEN from database import Base @@ -9,17 +10,18 @@ class UserSession(Base): __tablename__ = "user_session" - computing_id = Column( + computing_id: Mapped[str] = mapped_column( String(COMPUTING_ID_LEN), ForeignKey("site_user.computing_id"), # in psql pkey means non-null primary_key=True, ) + # TODO: Make all timestamps uneditable later # time the CAS ticket was issued - issue_time = Column(DateTime, nullable=False) + issue_time: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.now()) - session_id = Column( + session_id: Mapped[str] = mapped_column( String(SESSION_ID_LEN), nullable=False, unique=True ) # the space needed to store 256 bytes in base64 @@ -29,15 +31,22 @@ class SiteUser(Base): # see: https://stackoverflow.com/questions/22256124/cannot-create-a-database-table-named-user-in-postgresql __tablename__ = "site_user" - computing_id = Column( + computing_id: Mapped[str] = mapped_column( String(COMPUTING_ID_LEN), primary_key=True, ) # first and last time logged into the CSSS API - # note: default date (for pre-existing columns) is June 16th, 2024 - first_logged_in = Column(DateTime, nullable=False, default=datetime(2024, 6, 16)) - last_logged_in = Column(DateTime, nullable=False, default=datetime(2024, 6, 16)) + first_logged_in: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.now()) + last_logged_in: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.now()) # optional user information for display purposes - profile_picture_url = Column(Text, nullable=True) + profile_picture_url: Mapped[str | None] = mapped_column(Text, nullable=True) + + def serialize(self) -> dict[str, str | int | bool | None]: + return { + "computing_id": self.computing_id, + "first_logged_in": self.first_logged_in.isoformat(), + "last_logged_in": self.last_logged_in.isoformat(), + "profile_picture_url": self.profile_picture_url + } diff --git a/src/auth/urls.py b/src/auth/urls.py index af60046..290141a 100644 --- a/src/auth/urls.py +++ b/src/auth/urls.py @@ -6,13 +6,13 @@ import requests # TODO: make this async import xmltodict from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, Response -from fastapi.responses import JSONResponse, PlainTextResponse, RedirectResponse +from fastapi.responses import JSONResponse, RedirectResponse import database from auth import crud -from auth.models import LoginBodyModel +from auth.models import LoginBodyParams, SiteUserModel, UpdateUserParams from constants import DOMAIN, IS_PROD, SAMESITE -from utils.shared_models import DetailModel +from utils.shared_models import DetailModel, MessageModel _logger = logging.getLogger(__name__) @@ -51,7 +51,7 @@ async def login_user( request: Request, db_session: database.DBSession, background_tasks: BackgroundTasks, - body: LoginBodyModel + body: LoginBodyParams ): # verify the ticket is valid service_url = body.service @@ -94,8 +94,9 @@ async def login_user( @router.get( "/logout", - operation_id="logout", description="Logs out the current user by invalidating the session_id cookie", + operation_id="logout", + response_model=MessageModel ) async def logout_user( request: Request, @@ -119,6 +120,10 @@ async def logout_user( "/user", operation_id="get_user", description="Get info about the current user. Only accessible by that user", + response_model=SiteUserModel, + responses={ + 401: { "description": "Not logged in.", "model": DetailModel } + }, ) async def get_user( request: Request, @@ -129,35 +134,38 @@ async def get_user( """ session_id = request.cookies.get("session_id", None) if session_id is None: - raise HTTPException(status_code=401, detail="User must be authenticated to get their info") + raise HTTPException(status_code=401, detail="user must be authenticated to get their info") user_info = await crud.get_site_user(db_session, session_id) if user_info is None: - raise HTTPException(status_code=401, detail="Could not find user with session_id, please log in") + raise HTTPException(status_code=401, detail="could not find user with session_id, please log in") - return JSONResponse(user_info.serializable_dict()) + return JSONResponse(user_info.serialize()) +# TODO: We should change this so that the admins can change people's pictures too, so they can remove offensive stuff @router.patch( "/user", operation_id="update_user", description="Update information for the currently logged in user. Only accessible by that user", + response_model=str, + responses={ + 401: { "description": "Not logged in.", "model": DetailModel } + }, ) async def update_user( - profile_picture_url: str, + body: UpdateUserParams, request: Request, db_session: database.DBSession, ): """ Returns the info stored in the site_user table in the auth module, if the user is logged in. """ - session_id = request.cookies.get("session_id", None) + session_id = request.cookies.get("session_id") if session_id is None: - raise HTTPException(status_code=401, detail="User must be authenticated to get their info") + raise HTTPException(status_code=401, detail="user must be authenticated to get their info") - ok = await crud.update_site_user(db_session, session_id, profile_picture_url) + ok = await crud.update_site_user(db_session, session_id, body.profile_picture_url) await db_session.commit() if not ok: - raise HTTPException(status_code=401, detail="Could not find user with session_id, please log in") - - return PlainTextResponse("ok") + raise HTTPException(status_code=401, detail="could not find user with session_id, please log in") diff --git a/src/database.py b/src/database.py index c1abf41..2e872bd 100644 --- a/src/database.py +++ b/src/database.py @@ -12,6 +12,7 @@ AsyncConnection, AsyncSession, ) +from sqlalchemy.orm import DeclarativeBase convention = { "ix": "ix_%(column_0_label)s", # index @@ -21,8 +22,8 @@ "pk": "pk_%(table_name)s", # primary key } -Base = sqlalchemy.orm.declarative_base() -Base.metadata = MetaData(naming_convention=convention) +class Base(DeclarativeBase): + metadata = MetaData(naming_convention=convention) # from: https://medium.com/@tclaitken/setting-up-a-fastapi-app-with-async-sqlalchemy-2-0-pydantic-v2-e6c540be4308 class DatabaseSessionManager: diff --git a/src/elections/models.py b/src/elections/models.py index 2b39614..ca7f1e7 100644 --- a/src/elections/models.py +++ b/src/elections/models.py @@ -1,9 +1,9 @@ -from enum import Enum +from enum import StrEnum from pydantic import BaseModel -class ElectionTypeEnum(str, Enum): +class ElectionTypeEnum(StrEnum): GENERAL = "general_election" BY_ELECTION = "by_election" COUNCIL_REP = "council_rep_election" diff --git a/src/utils/shared_models.py b/src/utils/shared_models.py index 121ede4..9f5766d 100644 --- a/src/utils/shared_models.py +++ b/src/utils/shared_models.py @@ -6,3 +6,6 @@ class SuccessFailModel(BaseModel): class DetailModel(BaseModel): detail: str + +class MessageModel(BaseModel): + message: str