Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions src/auth/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from sqlalchemy.ext.asyncio import AsyncSession

from auth.tables import SiteUser, UserSession
from auth.types import SiteUserData

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -52,7 +51,7 @@ async def create_user_session(db_session: AsyncSession, session_id: str, computi
))


async def remove_user_session(db_session: AsyncSession, session_id: str) -> dict:
async def remove_user_session(db_session: AsyncSession, session_id: str):
query = sqlalchemy.select(UserSession).where(UserSession.session_id == session_id)
user_session = await db_session.scalars(query)
await db_session.delete(user_session.first())
Expand All @@ -74,7 +73,7 @@ async def task_clean_expired_user_sessions(db_session: AsyncSession):


# get the site user given a session ID; returns None when session is invalid
async def get_site_user(db_session: AsyncSession, session_id: str) -> None | SiteUserData:
async def get_site_user(db_session: AsyncSession, session_id: str) -> SiteUser | None:
query = (
sqlalchemy
.select(UserSession)
Expand All @@ -89,17 +88,7 @@ async def get_site_user(db_session: AsyncSession, session_id: str) -> None | Sit
.select(SiteUser)
.where(SiteUser.computing_id == user_session.computing_id)
)
user = await db_session.scalar(query)
if user is None:
return None

return SiteUserData(
user_session.computing_id,
user.first_logged_in.isoformat(),
user.last_logged_in.isoformat(),
user.profile_picture_url
)

return await db_session.scalar(query)

async def site_user_exists(db_session: AsyncSession, computing_id: str) -> bool:
user = await db_session.scalar(
Expand Down
18 changes: 17 additions & 1 deletion src/auth/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
from datetime import datetime

from pydantic import BaseModel, Field


class LoginBodyModel(BaseModel):
class LoginBodyParams(BaseModel):
service: str = Field(description="Service URL used for SFU's CAS system")
ticket: str = Field(description="Ticket return from SFU's CAS system")
redirect_url: str | None = Field(None, description="Optional redirect URL")

class UpdateUserParams(BaseModel):
profile_picture_url: str

class UserSessionModel(BaseModel):
computing_id: str
issue_time: datetime
session_id: str

class SiteUserModel(BaseModel):
computing_id: str
first_logged_in: datetime
last_logged_in: datetime
profile_picture_url: str | None = None
26 changes: 18 additions & 8 deletions src/auth/tables.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import datetime
from typing import Any

from sqlalchemy import Column, DateTime, ForeignKey, String, Text
from sqlalchemy import DateTime, ForeignKey, String, Text
from sqlalchemy.orm import Mapped, mapped_column

from constants import COMPUTING_ID_LEN, SESSION_ID_LEN
from database import Base
Expand All @@ -9,17 +11,17 @@
class UserSession(Base):
__tablename__ = "user_session"

computing_id = Column(
computing_id: Mapped[str] = mapped_column(
String(COMPUTING_ID_LEN),
ForeignKey("site_user.computing_id"),
# in psql pkey means non-null
primary_key=True,
)

# time the CAS ticket was issued
issue_time = Column(DateTime, nullable=False)
issue_time: Mapped[datetime] = mapped_column(DateTime, nullable=False)

session_id = Column(
session_id: Mapped[str] = mapped_column(
String(SESSION_ID_LEN), nullable=False, unique=True
) # the space needed to store 256 bytes in base64

Expand All @@ -29,15 +31,23 @@ class SiteUser(Base):
# see: https://stackoverflow.com/questions/22256124/cannot-create-a-database-table-named-user-in-postgresql
__tablename__ = "site_user"

computing_id = Column(
computing_id: Mapped[str] = mapped_column(
String(COMPUTING_ID_LEN),
primary_key=True,
)

# first and last time logged into the CSSS API
# note: default date (for pre-existing columns) is June 16th, 2024
first_logged_in = Column(DateTime, nullable=False, default=datetime(2024, 6, 16))
last_logged_in = Column(DateTime, nullable=False, default=datetime(2024, 6, 16))
first_logged_in: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime(2024, 6, 16))
last_logged_in: Mapped[datetime] = mapped_column(DateTime, nullable=False, default=datetime(2024, 6, 16))

# optional user information for display purposes
profile_picture_url = Column(Text, nullable=True)
profile_picture_url: Mapped[str | None] = mapped_column(Text, nullable=True)

def serialize(self) -> dict[str, str | int | bool | None]:
return {
"computing_id": self.computing_id,
"first_logged_in": self.first_logged_in.isoformat(),
"last_logged_in": self.last_logged_in.isoformat(),
"profile_picture_url": self.profile_picture_url
}
38 changes: 23 additions & 15 deletions src/auth/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import requests # TODO: make this async
import xmltodict
from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, Response
from fastapi.responses import JSONResponse, PlainTextResponse, RedirectResponse
from fastapi.responses import JSONResponse, RedirectResponse

import database
from auth import crud
from auth.models import LoginBodyModel
from auth.models import LoginBodyParams, SiteUserModel, UpdateUserParams
from constants import DOMAIN, IS_PROD, SAMESITE
from utils.shared_models import DetailModel
from utils.shared_models import DetailModel, MessageModel

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -51,7 +51,7 @@ async def login_user(
request: Request,
db_session: database.DBSession,
background_tasks: BackgroundTasks,
body: LoginBodyModel
body: LoginBodyParams
):
# verify the ticket is valid
service_url = body.service
Expand Down Expand Up @@ -94,8 +94,9 @@ async def login_user(

@router.get(
"/logout",
operation_id="logout",
description="Logs out the current user by invalidating the session_id cookie",
operation_id="logout",
response_model=MessageModel
)
async def logout_user(
request: Request,
Expand All @@ -119,6 +120,10 @@ async def logout_user(
"/user",
operation_id="get_user",
description="Get info about the current user. Only accessible by that user",
response_model=SiteUserModel,
responses={
401: { "description": "Not logged in.", "model": DetailModel }
},
)
async def get_user(
request: Request,
Expand All @@ -129,35 +134,38 @@ async def get_user(
"""
session_id = request.cookies.get("session_id", None)
if session_id is None:
raise HTTPException(status_code=401, detail="User must be authenticated to get their info")
raise HTTPException(status_code=401, detail="user must be authenticated to get their info")

user_info = await crud.get_site_user(db_session, session_id)
if user_info is None:
raise HTTPException(status_code=401, detail="Could not find user with session_id, please log in")
raise HTTPException(status_code=401, detail="could not find user with session_id, please log in")

return JSONResponse(user_info.serializable_dict())
return JSONResponse(user_info.serialize())


# TODO: We should change this so that the admins can change people's pictures too, so they can remove offensive stuff
@router.patch(
"/user",
operation_id="update_user",
description="Update information for the currently logged in user. Only accessible by that user",
response_model=str,
responses={
401: { "description": "Not logged in.", "model": DetailModel }
},
)
async def update_user(
profile_picture_url: str,
body: UpdateUserParams,
request: Request,
db_session: database.DBSession,
):
"""
Returns the info stored in the site_user table in the auth module, if the user is logged in.
"""
session_id = request.cookies.get("session_id", None)
session_id = request.cookies.get("session_id")
if session_id is None:
raise HTTPException(status_code=401, detail="User must be authenticated to get their info")
raise HTTPException(status_code=401, detail="user must be authenticated to get their info")

ok = await crud.update_site_user(db_session, session_id, profile_picture_url)
ok = await crud.update_site_user(db_session, session_id, body.profile_picture_url)
await db_session.commit()
if not ok:
raise HTTPException(status_code=401, detail="Could not find user with session_id, please log in")

return PlainTextResponse("ok")
raise HTTPException(status_code=401, detail="could not find user with session_id, please log in")
5 changes: 3 additions & 2 deletions src/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
AsyncConnection,
AsyncSession,
)
from sqlalchemy.orm import DeclarativeBase

convention = {
"ix": "ix_%(column_0_label)s", # index
Expand All @@ -21,8 +22,8 @@
"pk": "pk_%(table_name)s", # primary key
}

Base = sqlalchemy.orm.declarative_base()
Base.metadata = MetaData(naming_convention=convention)
class Base(DeclarativeBase):
metadata = MetaData(naming_convention=convention)

# from: https://medium.com/@tclaitken/setting-up-a-fastapi-app-with-async-sqlalchemy-2-0-pydantic-v2-e6c540be4308
class DatabaseSessionManager:
Expand Down
4 changes: 2 additions & 2 deletions src/elections/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from enum import Enum
from enum import StrEnum

from pydantic import BaseModel


class ElectionTypeEnum(str, Enum):
class ElectionTypeEnum(StrEnum):
GENERAL = "general_election"
BY_ELECTION = "by_election"
COUNCIL_REP = "council_rep_election"
Expand Down
3 changes: 3 additions & 0 deletions src/utils/shared_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ class SuccessFailModel(BaseModel):

class DetailModel(BaseModel):
detail: str

class MessageModel(BaseModel):
message: str