Skip to content
Closed
60 changes: 60 additions & 0 deletions src/alembic/versions/2f1b67c68ba5_add_exam_bank_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""add exam bank tables

Revision ID: 2f1b67c68ba5
Revises: 3f19883760ae
Create Date: 2025-01-03 00:24:44.608869

"""
from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "2f1b67c68ba5"
down_revision: str | None = "3f19883760ae"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
op.create_table(
"professor",
sa.Column("professor_id", sa.Integer, primary_key=True, autoincrement=True),
sa.Column("name", sa.String(128), nullable=False),
sa.Column("info_url", sa.String(128), nullable=False),
sa.Column("computing_id", sa.String(32), sa.ForeignKey("user_session.computing_id"), nullable=True),
)

op.create_table(
"course",
sa.Column("course_id", sa.Integer, primary_key=True, autoincrement=True),
sa.Column("course_faculty", sa.String(12), nullable=False),
sa.Column("course_number", sa.String(12), nullable=False),
sa.Column("course_name", sa.String(96), nullable=False),
)

op.create_table(
"exam_metadata",
sa.Column("exam_id", sa.Integer, primary_key=True),
sa.Column("upload_date", sa.DateTime, nullable=False),
sa.Column("exam_pdf_size", sa.Integer, nullable=False),

sa.Column("author_id", sa.String(32), sa.ForeignKey("professor.professor_id"), nullable=False),
sa.Column("author_confirmed", sa.Boolean, nullable=False),
sa.Column("author_permission", sa.Boolean, nullable=False),

sa.Column("kind", sa.String(24), nullable=False),
sa.Column("course_id", sa.String(32), sa.ForeignKey("professor.professor_id"), nullable=True),
sa.Column("title", sa.String(96), nullable=True),
sa.Column("description", sa.Text, nullable=True),

sa.Column("date_string", sa.String(10), nullable=False),
)


def downgrade() -> None:
op.drop_table("exam_metadata")
op.drop_table("professor")
op.drop_table("course")
25 changes: 25 additions & 0 deletions src/alembic/versions/3f19883760ae_add_session_type_to_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""add session_type to auth

Revision ID: 3f19883760ae
Revises: 2a6ea95342dc
Create Date: 2025-01-03 00:16:50.579541

"""
from collections.abc import Sequence

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "3f19883760ae"
down_revision: str | None = "2a6ea95342dc"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
op.add_column("user_session", sa.Column("session_type", sa.String(48), nullable=False))

def downgrade() -> None:
op.drop_column("user_session", "session_type")
52 changes: 36 additions & 16 deletions src/auth/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@

_logger = logging.getLogger(__name__)

async def create_user_session(db_session: AsyncSession, session_id: str, computing_id: str):
async def create_user_session(
db_session: AsyncSession,
session_id: str,
computing_id: str,
session_type: str,
) -> None:
"""
Updates the past user session if one exists, so no duplicate sessions can ever occur.

Expand Down Expand Up @@ -46,50 +51,67 @@ async def create_user_session(db_session: AsyncSession, session_id: str, computi
existing_user.last_logged_in = datetime.now()
else:
db_session.add(UserSession(
session_id=session_id,
computing_id=computing_id,
issue_time=datetime.now(),
session_id=session_id,
session_type=session_type,
))


async def remove_user_session(db_session: AsyncSession, session_id: str) -> dict:
query = sqlalchemy.select(UserSession).where(UserSession.session_id == session_id)
user_session = await db_session.scalars(query)
user_session = await db_session.scalars(
sqlalchemy
.select(UserSession)
.where(UserSession.session_id == session_id)
)
await db_session.delete(user_session.first())


async def get_computing_id(db_session: AsyncSession, session_id: str) -> str | None:
query = sqlalchemy.select(UserSession).where(UserSession.session_id == session_id)
existing_user_session = (await db_session.scalars(query)).first()
existing_user_session = await db_session.scalar(
sqlalchemy
.select(UserSession)
.where(UserSession.session_id == session_id)
)
return existing_user_session.computing_id if existing_user_session else None


async def get_session_type(db_session: AsyncSession, session_id: str) -> str | None:
existing_user_session = await db_session.scalar(
sqlalchemy
.select(UserSession)
.where(UserSession.session_id == session_id)
)
return existing_user_session.session_type if existing_user_session else None


# remove all out of date user sessions
async def task_clean_expired_user_sessions(db_session: AsyncSession):
one_day_ago = datetime.now() - timedelta(days=0.5)

query = sqlalchemy.delete(UserSession).where(UserSession.issue_time < one_day_ago)
await db_session.execute(query)
await db_session.execute(
sqlalchemy
.delete(UserSession)
.where(UserSession.issue_time < one_day_ago)
)
await db_session.commit()


# get the site user given a session ID; returns None when session is invalid
async def get_site_user(db_session: AsyncSession, session_id: str) -> None | SiteUserData:
query = (
user_session = await db_session.scalar(
sqlalchemy
.select(UserSession)
.where(UserSession.session_id == session_id)
)
user_session = await db_session.scalar(query)
if user_session is None:
return None

query = (
user = await db_session.scalar(
sqlalchemy
.select(SiteUser)
.where(SiteUser.computing_id == user_session.computing_id)
)
user = await db_session.scalar(query)
if user is None:
return None

Expand All @@ -116,21 +138,19 @@ async def update_site_user(
session_id: str,
profile_picture_url: str
) -> bool:
query = (
user_session = await db_session.scalar(
sqlalchemy
.select(UserSession)
.where(UserSession.session_id == session_id)
)
user_session = await db_session.scalar(query)
if user_session is None:
return False

query = (
await db_session.execute(
sqlalchemy
.update(SiteUser)
.where(SiteUser.computing_id == user_session.computing_id)
.values(profile_picture_url = profile_picture_url)
)
await db_session.execute(query)

return True
4 changes: 3 additions & 1 deletion src/auth/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sqlalchemy import Column, DateTime, ForeignKey, String, Text

from constants import COMPUTING_ID_LEN, SESSION_ID_LEN
from constants import COMPUTING_ID_LEN, SESSION_ID_LEN, SESSION_TYPE_LEN
from database import Base


Expand All @@ -23,6 +23,8 @@ class UserSession(Base):
String(SESSION_ID_LEN), nullable=False, unique=True
) # the space needed to store 256 bytes in base64

# whether a user is faculty, csss-member, student, or just "sfu"
session_type = Column(String(SESSION_TYPE_LEN), nullable=False)

class SiteUser(Base):
# user is a reserved word in postgres
Expand Down
20 changes: 20 additions & 0 deletions src/auth/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
from dataclasses import dataclass


class SessionType:
# see: https://www.sfu.ca/information-systems/services/cas/cas-for-web-applications/
# for more info on the kinds of members
FACULTY = "faculty"
# TODO: what will happen to the maillists for authentication; are groups part of this?
CSSS_MEMBER = "csss member" # !cs-students maillist
STUDENT = "student"
ALUMNI = "alumni"
SFU = "sfu"

@staticmethod
def valid_session_type_list():
# values taken from https://www.sfu.ca/information-systems/services/cas/cas-for-web-applications.html
return [
"faculty",
"student",
"alumni",
"sfu"
]

@dataclass
class SiteUserData:
computing_id: str
Expand Down
38 changes: 32 additions & 6 deletions src/auth/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import database
from auth import crud
from auth.types import SessionType
from constants import FRONTEND_ROOT_URL

_logger = logging.getLogger(__name__)
Expand All @@ -31,7 +32,6 @@ def generate_session_id_b64(num_bytes: int) -> str:
tags=["authentication"],
)


# NOTE: logging in a second time invaldiates the last session_id
@router.get(
"/login",
Expand All @@ -47,17 +47,43 @@ async def login_user(
# verify the ticket is valid
service = urllib.parse.quote(f"{FRONTEND_ROOT_URL}/api/auth/login?redirect_path={redirect_path}&redirect_fragment={redirect_fragment}")
service_validate_url = f"https://cas.sfu.ca/cas/serviceValidate?service={service}&ticket={ticket}"
cas_response = xmltodict.parse(requests.get(service_validate_url).text)
cas_response_text = requests.get(service_validate_url).text
cas_response = xmltodict.parse(cas_response_text)

print("CAS RESPONSE ::")
print(cas_response_text)

if "cas:authenticationFailure" in cas_response["cas:serviceResponse"]:
_logger.info(f"User failed to login, with response {cas_response}")
raise HTTPException(status_code=401, detail="authentication error, ticket likely invalid")

else:
elif "cas:authenticationSuccess" in cas_response["cas:serviceResponse"]:
session_id = generate_session_id_b64(256)
computing_id = cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]["cas:user"]

await crud.create_user_session(db_session, session_id, computing_id)
# NOTE: it is the frontend's job to pass the correct authentication reuqest to CAS, otherwise we
# will only be able to give a user the "sfu" session_type (least privileged)
if "cas:maillist" in cas_response["cas:serviceResponse"]:
# maillist
# TODO: (ASK SFU IT) can alumni be in the cmpt-students maillist?
if cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]["cas:maillist"] == "cmpt-students":
session_type = SessionType.CSSS_MEMBER
else:
raise HTTPException(status_code=500, details="malformed cas:maillist authentication response; this is an SFU CAS error")
elif "cas:authtype" in cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]:
# sfu, alumni, faculty, student
session_type = cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]["cas:authtype"]
if session_type not in SessionType.valid_session_type_list():
raise HTTPException(status_code=500, detail=f"unexpected session type from SFU CAS of {session_type}")

if session_type == "alumni":
if "@" not in computing_id:
raise HTTPException(status_code=500, detail=f"invalid alumni computing_id response from CAS AUTH with value {session_type}")
computing_id = computing_id.split("@")[0]
else:
raise HTTPException(status_code=500, detail="malformed unknown authentication response; this is an SFU CAS error")

await crud.create_user_session(db_session, session_id, computing_id, session_type)
await db_session.commit()

# clean old sessions after sending the response
Expand All @@ -69,6 +95,8 @@ async def login_user(
) # this overwrites any past, possibly invalid, session_id
return response

else:
raise HTTPException(status_code=500, detail="malformed authentication response; this is an SFU CAS error")

@router.get(
"/logout",
Expand All @@ -91,7 +119,6 @@ async def logout_user(
response.delete_cookie(key="session_id")
return response


@router.get(
"/user",
description="Get info about the current user. Only accessible by that user",
Expand All @@ -113,7 +140,6 @@ async def get_user(

return JSONResponse(user_info.serializable_dict())


@router.patch(
"/user",
description="Update information for the currently logged in user. Only accessible by that user",
Expand Down
20 changes: 20 additions & 0 deletions src/auth/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from fastapi import HTTPException, Request

import auth.crud
import database


async def logged_in_or_raise(
request: Request,
db_session: database.DBSession
) -> tuple[str, str]:
"""gets the user's computing_id, or raises an exception if the current request is not logged in"""
session_id = request.cookies.get("session_id", None)
if session_id is None:
raise HTTPException(status_code=401)

session_computing_id = await auth.crud.get_computing_id(db_session, session_id)
if session_computing_id is None:
raise HTTPException(status_code=401)

return session_id, session_computing_id
3 changes: 3 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
COMPUTING_ID_LEN = 32
COMPUTING_ID_MAX = 8

# depends how large SFU maillists can be
SESSION_TYPE_LEN = 48

# see https://support.discord.com/hc/en-us/articles/4407571667351-How-to-Find-User-IDs-for-Law-Enforcement#:~:text=Each%20Discord%20user%20is%20assigned,user%20and%20cannot%20be%20changed.
# NOTE: the length got updated to 19 in july 2024. See https://www.reddit.com/r/discordapp/comments/ucrp1r/only_3_months_until_discord_ids_hit_19_digits/
# I set us to 32 just in case...
Expand Down
Loading
Loading