diff --git a/src/alembic/versions/2f1b67c68ba5_add_exam_bank_tables.py b/src/alembic/versions/2f1b67c68ba5_add_exam_bank_tables.py new file mode 100644 index 0000000..72e8c15 --- /dev/null +++ b/src/alembic/versions/2f1b67c68ba5_add_exam_bank_tables.py @@ -0,0 +1,60 @@ +"""add exam bank tables + +Revision ID: 2f1b67c68ba5 +Revises: 3f19883760ae +Create Date: 2025-01-03 00:24:44.608869 + +""" +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "2f1b67c68ba5" +down_revision: str | None = "3f19883760ae" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "professor", + sa.Column("professor_id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("name", sa.String(128), nullable=False), + sa.Column("info_url", sa.String(128), nullable=False), + sa.Column("computing_id", sa.String(32), sa.ForeignKey("user_session.computing_id"), nullable=True), + ) + + op.create_table( + "course", + sa.Column("course_id", sa.Integer, primary_key=True, autoincrement=True), + sa.Column("course_faculty", sa.String(12), nullable=False), + sa.Column("course_number", sa.String(12), nullable=False), + sa.Column("course_name", sa.String(96), nullable=False), + ) + + op.create_table( + "exam_metadata", + sa.Column("exam_id", sa.Integer, primary_key=True), + sa.Column("upload_date", sa.DateTime, nullable=False), + sa.Column("exam_pdf_size", sa.Integer, nullable=False), + + sa.Column("author_id", sa.String(32), sa.ForeignKey("professor.professor_id"), nullable=False), + sa.Column("author_confirmed", sa.Boolean, nullable=False), + sa.Column("author_permission", sa.Boolean, nullable=False), + + sa.Column("kind", sa.String(24), nullable=False), + sa.Column("course_id", sa.String(32), sa.ForeignKey("professor.professor_id"), nullable=True), + sa.Column("title", sa.String(96), nullable=True), + sa.Column("description", sa.Text, nullable=True), + + sa.Column("date_string", sa.String(10), nullable=False), + ) + + +def downgrade() -> None: + op.drop_table("exam_metadata") + op.drop_table("professor") + op.drop_table("course") diff --git a/src/alembic/versions/3f19883760ae_add_session_type_to_auth.py b/src/alembic/versions/3f19883760ae_add_session_type_to_auth.py new file mode 100644 index 0000000..b59ea36 --- /dev/null +++ b/src/alembic/versions/3f19883760ae_add_session_type_to_auth.py @@ -0,0 +1,25 @@ +"""add session_type to auth + +Revision ID: 3f19883760ae +Revises: 2a6ea95342dc +Create Date: 2025-01-03 00:16:50.579541 + +""" +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "3f19883760ae" +down_revision: str | None = "2a6ea95342dc" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.add_column("user_session", sa.Column("session_type", sa.String(48), nullable=False)) + +def downgrade() -> None: + op.drop_column("user_session", "session_type") diff --git a/src/auth/crud.py b/src/auth/crud.py index eeda7a7..0c0c0a5 100644 --- a/src/auth/crud.py +++ b/src/auth/crud.py @@ -9,7 +9,12 @@ _logger = logging.getLogger(__name__) -async def create_user_session(db_session: AsyncSession, session_id: str, computing_id: str): +async def create_user_session( + db_session: AsyncSession, + session_id: str, + computing_id: str, + session_type: str, +) -> None: """ Updates the past user session if one exists, so no duplicate sessions can ever occur. @@ -46,50 +51,67 @@ async def create_user_session(db_session: AsyncSession, session_id: str, computi existing_user.last_logged_in = datetime.now() else: db_session.add(UserSession( - session_id=session_id, computing_id=computing_id, issue_time=datetime.now(), + session_id=session_id, + session_type=session_type, )) async def remove_user_session(db_session: AsyncSession, session_id: str) -> dict: - query = sqlalchemy.select(UserSession).where(UserSession.session_id == session_id) - user_session = await db_session.scalars(query) + user_session = await db_session.scalars( + sqlalchemy + .select(UserSession) + .where(UserSession.session_id == session_id) + ) await db_session.delete(user_session.first()) async def get_computing_id(db_session: AsyncSession, session_id: str) -> str | None: - query = sqlalchemy.select(UserSession).where(UserSession.session_id == session_id) - existing_user_session = (await db_session.scalars(query)).first() + existing_user_session = await db_session.scalar( + sqlalchemy + .select(UserSession) + .where(UserSession.session_id == session_id) + ) return existing_user_session.computing_id if existing_user_session else None +async def get_session_type(db_session: AsyncSession, session_id: str) -> str | None: + existing_user_session = await db_session.scalar( + sqlalchemy + .select(UserSession) + .where(UserSession.session_id == session_id) + ) + return existing_user_session.session_type if existing_user_session else None + + # remove all out of date user sessions async def task_clean_expired_user_sessions(db_session: AsyncSession): one_day_ago = datetime.now() - timedelta(days=0.5) - query = sqlalchemy.delete(UserSession).where(UserSession.issue_time < one_day_ago) - await db_session.execute(query) + await db_session.execute( + sqlalchemy + .delete(UserSession) + .where(UserSession.issue_time < one_day_ago) + ) await db_session.commit() # get the site user given a session ID; returns None when session is invalid async def get_site_user(db_session: AsyncSession, session_id: str) -> None | SiteUserData: - query = ( + user_session = await db_session.scalar( sqlalchemy .select(UserSession) .where(UserSession.session_id == session_id) ) - user_session = await db_session.scalar(query) if user_session is None: return None - query = ( + user = await db_session.scalar( sqlalchemy .select(SiteUser) .where(SiteUser.computing_id == user_session.computing_id) ) - user = await db_session.scalar(query) if user is None: return None @@ -116,21 +138,19 @@ async def update_site_user( session_id: str, profile_picture_url: str ) -> bool: - query = ( + user_session = await db_session.scalar( sqlalchemy .select(UserSession) .where(UserSession.session_id == session_id) ) - user_session = await db_session.scalar(query) if user_session is None: return False - query = ( + await db_session.execute( sqlalchemy .update(SiteUser) .where(SiteUser.computing_id == user_session.computing_id) .values(profile_picture_url = profile_picture_url) ) - await db_session.execute(query) return True diff --git a/src/auth/tables.py b/src/auth/tables.py index 9599018..03887d5 100644 --- a/src/auth/tables.py +++ b/src/auth/tables.py @@ -2,7 +2,7 @@ from sqlalchemy import Column, DateTime, ForeignKey, String, Text -from constants import COMPUTING_ID_LEN, SESSION_ID_LEN +from constants import COMPUTING_ID_LEN, SESSION_ID_LEN, SESSION_TYPE_LEN from database import Base @@ -23,6 +23,8 @@ class UserSession(Base): String(SESSION_ID_LEN), nullable=False, unique=True ) # the space needed to store 256 bytes in base64 + # whether a user is faculty, csss-member, student, or just "sfu" + session_type = Column(String(SESSION_TYPE_LEN), nullable=False) class SiteUser(Base): # user is a reserved word in postgres diff --git a/src/auth/types.py b/src/auth/types.py index 6587ca3..598724b 100644 --- a/src/auth/types.py +++ b/src/auth/types.py @@ -1,6 +1,26 @@ from dataclasses import dataclass +class SessionType: + # see: https://www.sfu.ca/information-systems/services/cas/cas-for-web-applications/ + # for more info on the kinds of members + FACULTY = "faculty" + # TODO: what will happen to the maillists for authentication; are groups part of this? + CSSS_MEMBER = "csss member" # !cs-students maillist + STUDENT = "student" + ALUMNI = "alumni" + SFU = "sfu" + + @staticmethod + def valid_session_type_list(): + # values taken from https://www.sfu.ca/information-systems/services/cas/cas-for-web-applications.html + return [ + "faculty", + "student", + "alumni", + "sfu" + ] + @dataclass class SiteUserData: computing_id: str diff --git a/src/auth/urls.py b/src/auth/urls.py index 113cfda..6c75f08 100644 --- a/src/auth/urls.py +++ b/src/auth/urls.py @@ -10,6 +10,7 @@ import database from auth import crud +from auth.types import SessionType from constants import FRONTEND_ROOT_URL _logger = logging.getLogger(__name__) @@ -31,7 +32,6 @@ def generate_session_id_b64(num_bytes: int) -> str: tags=["authentication"], ) - # NOTE: logging in a second time invaldiates the last session_id @router.get( "/login", @@ -47,17 +47,43 @@ async def login_user( # verify the ticket is valid service = urllib.parse.quote(f"{FRONTEND_ROOT_URL}/api/auth/login?redirect_path={redirect_path}&redirect_fragment={redirect_fragment}") service_validate_url = f"https://cas.sfu.ca/cas/serviceValidate?service={service}&ticket={ticket}" - cas_response = xmltodict.parse(requests.get(service_validate_url).text) + cas_response_text = requests.get(service_validate_url).text + cas_response = xmltodict.parse(cas_response_text) + + print("CAS RESPONSE ::") + print(cas_response_text) if "cas:authenticationFailure" in cas_response["cas:serviceResponse"]: _logger.info(f"User failed to login, with response {cas_response}") raise HTTPException(status_code=401, detail="authentication error, ticket likely invalid") - else: + elif "cas:authenticationSuccess" in cas_response["cas:serviceResponse"]: session_id = generate_session_id_b64(256) computing_id = cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]["cas:user"] - await crud.create_user_session(db_session, session_id, computing_id) + # NOTE: it is the frontend's job to pass the correct authentication reuqest to CAS, otherwise we + # will only be able to give a user the "sfu" session_type (least privileged) + if "cas:maillist" in cas_response["cas:serviceResponse"]: + # maillist + # TODO: (ASK SFU IT) can alumni be in the cmpt-students maillist? + if cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]["cas:maillist"] == "cmpt-students": + session_type = SessionType.CSSS_MEMBER + else: + raise HTTPException(status_code=500, details="malformed cas:maillist authentication response; this is an SFU CAS error") + elif "cas:authtype" in cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]: + # sfu, alumni, faculty, student + session_type = cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]["cas:authtype"] + if session_type not in SessionType.valid_session_type_list(): + raise HTTPException(status_code=500, detail=f"unexpected session type from SFU CAS of {session_type}") + + if session_type == "alumni": + if "@" not in computing_id: + raise HTTPException(status_code=500, detail=f"invalid alumni computing_id response from CAS AUTH with value {session_type}") + computing_id = computing_id.split("@")[0] + else: + raise HTTPException(status_code=500, detail="malformed unknown authentication response; this is an SFU CAS error") + + await crud.create_user_session(db_session, session_id, computing_id, session_type) await db_session.commit() # clean old sessions after sending the response @@ -69,6 +95,8 @@ async def login_user( ) # this overwrites any past, possibly invalid, session_id return response + else: + raise HTTPException(status_code=500, detail="malformed authentication response; this is an SFU CAS error") @router.get( "/logout", @@ -91,7 +119,6 @@ async def logout_user( response.delete_cookie(key="session_id") return response - @router.get( "/user", description="Get info about the current user. Only accessible by that user", @@ -113,7 +140,6 @@ async def get_user( return JSONResponse(user_info.serializable_dict()) - @router.patch( "/user", description="Update information for the currently logged in user. Only accessible by that user", diff --git a/src/auth/utils.py b/src/auth/utils.py new file mode 100644 index 0000000..ea86a8d --- /dev/null +++ b/src/auth/utils.py @@ -0,0 +1,20 @@ +from fastapi import HTTPException, Request + +import auth.crud +import database + + +async def logged_in_or_raise( + request: Request, + db_session: database.DBSession +) -> tuple[str, str]: + """gets the user's computing_id, or raises an exception if the current request is not logged in""" + session_id = request.cookies.get("session_id", None) + if session_id is None: + raise HTTPException(status_code=401) + + session_computing_id = await auth.crud.get_computing_id(db_session, session_id) + if session_computing_id is None: + raise HTTPException(status_code=401) + + return session_id, session_computing_id diff --git a/src/constants.py b/src/constants.py index 97b0d0c..49dc571 100644 --- a/src/constants.py +++ b/src/constants.py @@ -14,6 +14,9 @@ COMPUTING_ID_LEN = 32 COMPUTING_ID_MAX = 8 +# depends how large SFU maillists can be +SESSION_TYPE_LEN = 48 + # see https://support.discord.com/hc/en-us/articles/4407571667351-How-to-Find-User-IDs-for-Law-Enforcement#:~:text=Each%20Discord%20user%20is%20assigned,user%20and%20cannot%20be%20changed. # NOTE: the length got updated to 19 in july 2024. See https://www.reddit.com/r/discordapp/comments/ucrp1r/only_3_months_until_discord_ids_hit_19_digits/ # I set us to 32 just in case... diff --git a/src/exambank/tables.py b/src/exambank/tables.py new file mode 100644 index 0000000..bea4a09 --- /dev/null +++ b/src/exambank/tables.py @@ -0,0 +1,55 @@ +from datetime import datetime +from types import ExamKind + +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text +from sqlalchemy.orm import relationship + +from constants import COMPUTING_ID_LEN, SESSION_ID_LEN, SESSION_TYPE_LEN +from database import Base + +# TODO: determine what info will need to be in the spreadsheet, then moved here + +class ExamMetadata(Base): + __tablename__ = "exam_metadata" + + # exam_id is the number used to access the exam + exam_id = Column(Integer, primary_key=True) + upload_date = Column(DateTime, nullable=False) + exam_pdf_size = Column(Integer, nullable=False) # in bytes + + author_id = Column(String(COMPUTING_ID_LEN), ForeignKey("professor.professor_id"), nullable=False) + # whether this is the confirmed author of the exam, or just suspected + author_confirmed = Column(Boolean, nullable=False) + # true if the professor has given permission for us to use their exam + author_permission = Column(Boolean, nullable=False) + + kind = Column(String(24), nullable=False) + course_id = Column(String(COMPUTING_ID_LEN), ForeignKey("course.professor_id"), nullable=True) + title = Column(String(96), nullable=True) # Something like "Midterm 2" or "Computational Geometry Final" + description = Column(Text, nullable=True) # For a natural language description of the contents + + # formatted as xxxx-xx-xx, include x for unknown dates + date_string = Column(String(10), nullable=False) + +# TODO: eventually hook the following tables in with the rest of the site & coursys api + +class Professor(Base): + __tablename__ = "professor" + + professor_id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(128), nullable=False) + info_url = Column(String(128), nullable=False) # A url which provides more information about the professor + + # we may not know a professor's computing_id + computing_id = Column(String(COMPUTING_ID_LEN), ForeignKey("user_session.computing_id"), nullable=True) + +class Course(Base): + __tablename__ = "course" + + course_id = Column(Integer, primary_key=True, autoincrement=True) + + # formatted f"{faculty} {course_number}", ie. CMPT 300 + course_faculty = Column(String(12), nullable=False) + course_number = Column(String(12), nullable=False) + course_name = Column(String(96), nullable=False) + diff --git a/src/exambank/types.py b/src/exambank/types.py new file mode 100644 index 0000000..c68cc49 --- /dev/null +++ b/src/exambank/types.py @@ -0,0 +1,7 @@ +class ExamKind: + FINAL = "final" + MIDTERM = "midterm" + QUIZ = "quiz" + ASSIGNMENT = "assignment" + NOTES = "notes" + MISC = "misc" diff --git a/src/exambank/urls.py b/src/exambank/urls.py new file mode 100644 index 0000000..cfa9699 --- /dev/null +++ b/src/exambank/urls.py @@ -0,0 +1,88 @@ +import os + +import sqlalchemy +from fastapi import APIRouter, HTTPException, JSONResponse, Request, Response +from tables import Course, ExamMetadata, Professor + +import database +import exambank.crud +from auth.utils import logged_in_or_raise +from exambank.watermark import apply_watermark, create_watermark, raster_pdf +from permission.types import ExamBankAccess +from utils import path_in_dir + +# all exams are stored here, and for the time being must be manually moved here +EXAM_BANK_DIR = "/opt/csss-site/media/exam-bank" + +router = APIRouter( + prefix="/exam-bank", + tags=["exam-bank"], +) + +# TODO: update endpoints to use crud functions -> don't use crud actually; refactor to do that later + +@router.get( + "/metadata" +) +async def exam_metadata( + request: Request, + db_session: database.DBSession, +): + _, _ = await logged_in_or_raise(request, db_session) + await ExamBankAccess.has_permission_or_raise(request, errmsg="user must have exam bank access permission") + + """ + courses = [f.name for f in os.scandir(f"{EXAM_BANK_DIR}") if f.is_dir()] + if course_id_starts_with is not None: + courses = [course for course in courses if course.startswith(course_id_starts_with)] + + exam_list = exambank.crud.all_exams(db_session, course_id_starts_with) + return JSONResponse([exam.serializable_dict() for exam in exam_list]) + """ + + # TODO: test that the joins work correctly + exams = await db_session.scalar( + sqlalchemy + .select(ExamMetadata, Professor, Course) + .join(Professor) + .join(Course, isouter=True) # we want to have null values if the course is not known + .order_by(Course.course_number) + ) + + print(exams) + + # TODO: serialize exams somehow + return JSONResponse(exams) + +# TODO: implement endpoint to fetch exams +""" +@router.get( + "/exam/{exam_id}" +) +async def get_exam( + request: Request, + db_session: database.DBSession, + exam_id: int, +): + _, session_computing_id = await logged_in_or_raise(request, db_session) + await ExamBankAccess.has_permission_or_raise(request, errmsg="user must have exam bank access permission") + + # number exams with an exam_id pkey + # TODO: store resource locations in a db table & simply look them up + + meta = exambank.crud.exam_metadata(db_session, exam_id) + if meta is None: + raise HTTPException(status_code=400, detail=f"could not find the exam with exam_id={exam_id}") + + exam_path = f"{EXAM_BANK_DIR}/{meta.pdf_path}" + if not path_in_dir(exam_path, EXAM_BANK_DIR): + raise HTTPException(status_code=500, detail="Found dangerous pdf path, exiting") + + # TODO: test this works nicely + watermark = create_watermark(session_computing_id, 20) + watermarked_pdf = apply_watermark(exam_path, watermark) + image_bytes = raster_pdf(watermarked_pdf) + + headers = { "Content-Disposition": f'inline; filename="{meta.course_id}_{exam_id}_{session_computing_id}.pdf"' } + return Response(content=image_bytes, headers=headers, media_type="application/pdf") +""" diff --git a/src/exambank/watermark.py b/src/exambank/watermark.py index 7f124d6..a381f57 100644 --- a/src/exambank/watermark.py +++ b/src/exambank/watermark.py @@ -1,3 +1,4 @@ +from datetime import datetime from io import BytesIO from pathlib import Path @@ -12,8 +13,8 @@ BORDER = 20 def create_watermark( - computing_id: str, - density: int = 5 + computing_id: str, + density: int = 5 ) -> BytesIO: """ Returns a PDF with one page containing the watermark as text. @@ -40,7 +41,6 @@ def create_watermark( warning_pdf.setFillColor(colors.grey, alpha=0.75) warning_pdf.setFont("Helvetica", 14) - from datetime import datetime warning_pdf.drawString(BORDER, BORDER, f"This exam was generated by {computing_id} at {datetime.now()}") warning_pdf.save() @@ -50,6 +50,7 @@ def create_watermark( watermark_pdf = PdfWriter() stamp_pdf = PdfReader(stamp_buffer) warning_pdf = PdfReader(warning_buffer) + # Destructively merges in place stamp_pdf.pages[0].merge_page(warning_pdf.pages[0]) watermark_pdf.add_page(stamp_pdf.pages[0]) @@ -60,9 +61,9 @@ def create_watermark( return watermark_buffer def apply_watermark( - pdf_path: Path | str, - # expect a BytesIO instance (at position 0), accept a file/path - stamp: BytesIO | Path | str, + pdf_path: Path | str, + # expect a BytesIO instance (at position 0), accept a file/path + stamp: BytesIO | Path | str, ) -> BytesIO: # process file stamp_page = PdfReader(stamp).pages[0] @@ -77,12 +78,11 @@ def apply_watermark( watermarked_pdf = BytesIO() writer.write(watermarked_pdf) watermarked_pdf.seek(0) - return watermarked_pdf def raster_pdf( - pdf_path: BytesIO, - dpi: int = 300 + pdf_path: BytesIO, + dpi: int = 300 ) -> BytesIO: raster_buffer = BytesIO() # adapted from https://github.com/pymupdf/PyMuPDF/discussions/1183 @@ -97,14 +97,17 @@ def raster_pdf( tarpage.insert_image(tarpage.rect, stream=pix.pil_tobytes("PNG")) target.save(raster_buffer) + raster_buffer.seek(0) return raster_buffer +# TODO: not sure what this function does, but let's remove it? def raster_pdf_from_path( - pdf_path: Path | str, - dpi: int = 300 + pdf_path: Path | str, + dpi: int = 300 ) -> BytesIO: raster_buffer = BytesIO() + # adapted from https://github.com/pymupdf/PyMuPDF/discussions/1183 with pymupdf.open(filename=pdf_path) as doc: page_count = doc.page_count @@ -117,5 +120,6 @@ def raster_pdf_from_path( tarpage.insert_image(tarpage.rect, stream=pix.pil_tobytes("PNG")) target.save(raster_buffer) + raster_buffer.seek(0) return raster_buffer diff --git a/src/load_test_db.py b/src/load_test_db.py index b9bc596..84eee1e 100644 --- a/src/load_test_db.py +++ b/src/load_test_db.py @@ -9,6 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from auth.crud import create_user_session, update_site_user +from auth.types import SessionType from database import SQLALCHEMY_TEST_DATABASE_URL, Base, DatabaseSessionManager from officers.constants import OfficerPosition from officers.crud import create_new_officer_info, create_new_officer_term, update_officer_info, update_officer_term @@ -58,15 +59,15 @@ async def reset_db(engine): print(f"new tables: {table_list}") async def load_test_auth_data(db_session: AsyncSession): - await create_user_session(db_session, "temp_id_314", "abc314") + await create_user_session(db_session, "temp_id_314", "abc314", SessionType.SFU) await update_site_user(db_session, "temp_id_314", "www.my_profile_picture_url.ca/test") await db_session.commit() async def load_test_officers_data(db_session: AsyncSession): print("login the 3 users, putting them in the site users table") - await create_user_session(db_session, "temp_id_1", "abc11") - await create_user_session(db_session, "temp_id_2", "abc22") - await create_user_session(db_session, "temp_id_3", "abc33") + await create_user_session(db_session, "temp_id_1", "abc11", SessionType.SFU) + await create_user_session(db_session, "temp_id_2", "abc22", SessionType.SFU) + await create_user_session(db_session, "temp_id_3", "abc33", SessionType.FACULTY) await db_session.commit() print("add officer info") @@ -216,7 +217,7 @@ async def load_test_officers_data(db_session: AsyncSession): async def load_sysadmin(db_session: AsyncSession): # put your computing id here for testing purposes print(f"loading new sysadmin '{SYSADMIN_COMPUTING_ID}'") - await create_user_session(db_session, f"temp_id_{SYSADMIN_COMPUTING_ID}", SYSADMIN_COMPUTING_ID) + await create_user_session(db_session, f"temp_id_{SYSADMIN_COMPUTING_ID}", SYSADMIN_COMPUTING_ID, SessionType.CSSS_MEMBER) await create_new_officer_info(db_session, OfficerInfo( legal_name="Gabe Schulz", discord_id=None, diff --git a/src/officers/urls.py b/src/officers/urls.py index decdb10..901cafc 100755 --- a/src/officers/urls.py +++ b/src/officers/urls.py @@ -7,6 +7,7 @@ import database import officers.crud import utils +from auth.utils import logged_in_or_raise from officers.tables import OfficerInfo, OfficerTerm from officers.types import InitialOfficerInfo, OfficerInfoUpload, OfficerTermUpload from permission.types import OfficerPrivateInfo, WebsiteAdmin @@ -37,21 +38,6 @@ async def has_officer_private_info_access( has_private_access = await OfficerPrivateInfo.has_permission(db_session, computing_id) return session_id, computing_id, has_private_access -async def logged_in_or_raise( - request: Request, - db_session: database.DBSession -) -> tuple[str, str]: - """gets the user's computing_id, or raises an exception if the current request is not logged in""" - session_id = request.cookies.get("session_id", None) - if session_id is None: - raise HTTPException(status_code=401) - - session_computing_id = await auth.crud.get_computing_id(db_session, session_id) - if session_computing_id is None: - raise HTTPException(status_code=401) - - return session_id, session_computing_id - # ---------------------------------------- # # endpoints diff --git a/src/permission/types.py b/src/permission/types.py index 659ed32..6766dff 100644 --- a/src/permission/types.py +++ b/src/permission/types.py @@ -1,11 +1,13 @@ from datetime import date from typing import ClassVar -from fastapi import HTTPException +from fastapi import HTTPException, Request +import auth.crud import database import officers.crud import utils +from auth.types import SessionType from data.semesters import step_semesters from officers.constants import OfficerPosition @@ -48,11 +50,51 @@ async def has_permission(db_session: database.DBSession, computing_id: str) -> b return True return False + @staticmethod + async def validate_request(db_session: database.DBSession, request: Request): + """ + Checks if the provided request satisfies these permissions, and raises the neccessary + exceptions if not + """ + session_id = request.cookies.get("session_id", None) + if session_id is None: + raise HTTPException(status_code=401, detail="must be logged in") + + computing_id = await auth.crud.get_computing_id(db_session, session_id) + if not await WebsiteAdmin.has_permission(db_session, computing_id): + raise HTTPException(status_code=401, detail="must have website admin permissions") + @staticmethod async def has_permission_or_raise( db_session: database.DBSession, computing_id: str, - errmsg:str = "must have website admin permissions" + errmsg: str = "must have website admin permissions" ) -> bool: if not await WebsiteAdmin.has_permission(db_session, computing_id): raise HTTPException(status_code=401, detail=errmsg) + +class ExamBankAccess: + @staticmethod + async def has_permission( + db_session: database.DBSession, + request: Request, + ) -> bool: + session_id = request.cookies.get("session_id", None) + if session_id is None: + return False + + if await auth.crud.get_session_type(db_session, session_id) == SessionType.FACULTY: + return True + + # the only non-faculty who can view exams are website admins + computing_id = await auth.crud.get_computing_id(db_session, session_id) + return await WebsiteAdmin.has_permission(db_session, computing_id) + + @staticmethod + async def has_permission_or_raise( + db_session: database.DBSession, + request: Request, + errmsg: str = "must have exam bank access permissions" + ): + if not await ExamBankAccess.has_permission(db_session, request): + raise HTTPException(status_code=401, detail=errmsg) diff --git a/src/utils.py b/src/utils.py index acf5ad0..a3a9cfe 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,5 +1,6 @@ import re from datetime import date, datetime +from pathlib import Path from sqlalchemy import Select @@ -34,6 +35,16 @@ def is_active_officer(query: Select) -> Select: ) ) +# TODO: test this +def path_in_dir(path: str, parent_dir: str): + """ + Determine if path is in parent_dir. A useful check for input + validation, to avoid leaking secrets + """ + parent = Path(parent_dir).resolve() + child = Path(path).resolve() + return parent in child.parents + def has_started_term(query: Select) -> bool: return query.where( OfficerTerm.start_date <= date.today() diff --git a/tests/integration/test_officers.py b/tests/integration/test_officers.py index dd8ba0c..8d357b6 100644 --- a/tests/integration/test_officers.py +++ b/tests/integration/test_officers.py @@ -7,6 +7,7 @@ import load_test_db from auth.crud import create_user_session +from auth.types import SessionType from database import SQLALCHEMY_TEST_DATABASE_URL, DatabaseSessionManager from main import app from officers.constants import OfficerPosition @@ -169,7 +170,7 @@ async def test__endpoints_admin(client, database_setup): # login as website admin session_id = "temp_id_" + load_test_db.SYSADMIN_COMPUTING_ID async with database_setup.session() as db_session: - await create_user_session(db_session, session_id, load_test_db.SYSADMIN_COMPUTING_ID) + await create_user_session(db_session, session_id, load_test_db.SYSADMIN_COMPUTING_ID, SessionType.CSSS_MEMBER) client.cookies = { "session_id": session_id }