diff --git a/pyproject.toml b/pyproject.toml index dee1528..79e80cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,4 +26,8 @@ line-ending = "lf" [tool.ruff.lint] select = ["E", "F", "B", "I", "N", "UP", "A", "PTH", "W", "RUF", "C4", "PIE", "Q", "FLY"] # "ANN" -ignore = ["E501", "F401", "N806"] \ No newline at end of file +ignore = ["E501", "F401", "N806"] + +[tool.pyright] +executionEnvironments = [{ root = "src" }] +typeCheckingMode = "standard" diff --git a/src/auth/models.py b/src/auth/models.py new file mode 100644 index 0000000..f342468 --- /dev/null +++ b/src/auth/models.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel, Field + + +class LoginBodyModel(BaseModel): + service: str = Field(description="Service URL used for SFU's CAS system") + ticket: str = Field(description="Ticket return from SFU's CAS system") + redirect_url: str | None = Field(None, description="Optional redirect URL") diff --git a/src/auth/urls.py b/src/auth/urls.py index 113cfda..af60046 100644 --- a/src/auth/urls.py +++ b/src/auth/urls.py @@ -5,12 +5,14 @@ import requests # TODO: make this async import xmltodict -from fastapi import APIRouter, BackgroundTasks, HTTPException, Request +from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, Response from fastapi.responses import JSONResponse, PlainTextResponse, RedirectResponse import database from auth import crud -from constants import FRONTEND_ROOT_URL +from auth.models import LoginBodyModel +from constants import DOMAIN, IS_PROD, SAMESITE +from utils.shared_models import DetailModel _logger = logging.getLogger(__name__) @@ -32,27 +34,34 @@ def generate_session_id_b64(num_bytes: int) -> str: ) -# NOTE: logging in a second time invaldiates the last session_id -@router.get( +# NOTE: logging in a second time invalidates the last session_id +@router.post( "/login", - description="Login to the sfucsss.org. Must redirect to this endpoint from SFU's cas authentication service for correct parameters", + description="Create a login session.", + response_description="Successfully validated with SFU's CAS", + response_model=str, + responses={ + 307: { "description": "Successful validation, with redirect" }, + 400: { "description": "Origin is missing.", "model": DetailModel }, + 401: { "description": "Failed to validate ticket with SFU's CAS", "model": DetailModel } + }, + operation_id="login", ) async def login_user( - redirect_path: str, - redirect_fragment: str, - ticket: str, + request: Request, db_session: database.DBSession, background_tasks: BackgroundTasks, + body: LoginBodyModel ): # verify the ticket is valid - service = urllib.parse.quote(f"{FRONTEND_ROOT_URL}/api/auth/login?redirect_path={redirect_path}&redirect_fragment={redirect_fragment}") - service_validate_url = f"https://cas.sfu.ca/cas/serviceValidate?service={service}&ticket={ticket}" + service_url = body.service + service = urllib.parse.quote(service_url) + service_validate_url = f"https://cas.sfu.ca/cas/serviceValidate?service={service}&ticket={body.ticket}" cas_response = xmltodict.parse(requests.get(service_validate_url).text) if "cas:authenticationFailure" in cas_response["cas:serviceResponse"]: _logger.info(f"User failed to login, with response {cas_response}") - raise HTTPException(status_code=401, detail="authentication error, ticket likely invalid") - + raise HTTPException(status_code=401, detail="authentication error") else: session_id = generate_session_id_b64(256) computing_id = cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]["cas:user"] @@ -63,15 +72,29 @@ async def login_user( # clean old sessions after sending the response background_tasks.add_task(crud.task_clean_expired_user_sessions, db_session) - response = RedirectResponse(FRONTEND_ROOT_URL + redirect_path + "#" + redirect_fragment) + if body.redirect_url: + origin = request.headers.get("origin") + if origin: + response = RedirectResponse(origin + body.redirect_url) + else: + raise HTTPException(status_code=400, detail="bad origin") + else: + response = Response() + response.set_cookie( - key="session_id", value=session_id + key="session_id", + value=session_id, + secure=IS_PROD, + httponly=True, + samesite=SAMESITE, + domain=DOMAIN ) # this overwrites any past, possibly invalid, session_id return response @router.get( "/logout", + operation_id="logout", description="Logs out the current user by invalidating the session_id cookie", ) async def logout_user( @@ -94,6 +117,7 @@ async def logout_user( @router.get( "/user", + operation_id="get_user", description="Get info about the current user. Only accessible by that user", ) async def get_user( @@ -116,6 +140,7 @@ async def get_user( @router.patch( "/user", + operation_id="update_user", description="Update information for the currently logged in user. Only accessible by that user", ) async def update_user( diff --git a/src/constants.py b/src/constants.py index 97b0d0c..3b6a9c7 100644 --- a/src/constants.py +++ b/src/constants.py @@ -2,12 +2,13 @@ # TODO(future): replace new.sfucsss.org with sfucsss.org during migration # TODO(far-future): branch-specific root IP addresses (e.g., devbranch.sfucsss.org) -FRONTEND_ROOT_URL = "http://localhost:8080" if os.environ.get("LOCAL") == "true" else "https://new.sfucsss.org" -GITHUB_ORG_NAME = "CSSS-Test-Organization" if os.environ.get("LOCAL") == "true" else "CSSS" +ENV_LOCAL = os.environ.get("LOCAL") +IS_PROD = True if not ENV_LOCAL or ENV_LOCAL.lower() != "true" else False +GITHUB_ORG_NAME = "CSSS-Test-Organization" if not IS_PROD else "CSSS" W3_GUILD_ID = "1260652618875797504" CSSS_GUILD_ID = "228761314644852736" -ACTIVE_GUILD_ID = W3_GUILD_ID if os.environ.get("LOCAL") == "true" else CSSS_GUILD_ID +ACTIVE_GUILD_ID = W3_GUILD_ID if not IS_PROD else CSSS_GUILD_ID SESSION_ID_LEN = 512 # technically a max of 8 digits https://www.sfu.ca/computing/about/support/tips/sfu-userid.html @@ -25,3 +26,7 @@ # https://docs.github.com/en/enterprise-server@3.10/admin/identity-and-access-management/iam-configuration-reference/username-considerations-for-external-authentication GITHUB_USERNAME_LEN = 39 + +# COOKIE +SAMESITE="none" if IS_PROD else "lax" +DOMAIN=".sfucsss.org" if IS_PROD else None diff --git a/src/main.py b/src/main.py index 37b700f..0a5433b 100755 --- a/src/main.py +++ b/src/main.py @@ -1,9 +1,9 @@ import logging -import os from fastapi import FastAPI, Request, status from fastapi.encoders import jsonable_encoder from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse import auth.urls @@ -11,38 +11,47 @@ import elections.urls import officers.urls import permission.urls +from constants import IS_PROD logging.basicConfig(level=logging.DEBUG) database.setup_database() -_login_link = ( - "https://cas.sfu.ca/cas/login?service=" + ( - "http%3A%2F%2Flocalhost%3A8080" - if os.environ.get("LOCAL") == "true" - else "https%3A%2F%2Fnew.sfucsss.org" - ) + "%2Fapi%2Fauth%2Flogin%3Fredirect_path%3D%2Fapi%2Fapi%2Fdocs%2F%26redirect_fragment%3D" -) - # Enable OpenAPI docs only for local development -if os.environ.get("LOCAL") == "true": +if not IS_PROD: + print("Running local environment") + origins = [ + "http://localhost:4200", # default Angular + "http://localhost:8080", # for existing applications/sites + ] app = FastAPI( lifespan=database.lifespan, title="CSSS Site Backend", - description=f'To login, please click here

To logout from CAS click here', root_path="/api", ) -# if on production, disable vieweing the docs +# if on production, disable viewing the docs else: + print("Running production environment") + origins = [ + "https://sfucsss.org", + "https://test.sfucsss.org", + "https://admin.sfucsss.org" + ] app = FastAPI( lifespan=database.lifespan, title="CSSS Site Backend", - description=f'To login, please click here

To logout from CAS click here', root_path="/api", docs_url=None, # disables Swagger UI redoc_url=None, # disables ReDoc openapi_url=None # disables OpenAPI schema ) +app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"] +) app.include_router(auth.urls.router) app.include_router(elections.urls.router) @@ -55,7 +64,7 @@ async def read_root(): @app.exception_handler(RequestValidationError) async def validation_exception_handler( - request: Request, + _: Request, exception: RequestValidationError, ): return JSONResponse( diff --git a/src/utils/shared_models.py b/src/utils/shared_models.py index ceaa2e2..121ede4 100644 --- a/src/utils/shared_models.py +++ b/src/utils/shared_models.py @@ -3,3 +3,6 @@ class SuccessFailModel(BaseModel): success: bool + +class DetailModel(BaseModel): + detail: str