diff --git a/pyproject.toml b/pyproject.toml
index dee1528..79e80cf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,4 +26,8 @@ line-ending = "lf"
[tool.ruff.lint]
select = ["E", "F", "B", "I", "N", "UP", "A", "PTH", "W", "RUF", "C4", "PIE", "Q", "FLY"] # "ANN"
-ignore = ["E501", "F401", "N806"]
\ No newline at end of file
+ignore = ["E501", "F401", "N806"]
+
+[tool.pyright]
+executionEnvironments = [{ root = "src" }]
+typeCheckingMode = "standard"
diff --git a/src/auth/models.py b/src/auth/models.py
new file mode 100644
index 0000000..f342468
--- /dev/null
+++ b/src/auth/models.py
@@ -0,0 +1,7 @@
+from pydantic import BaseModel, Field
+
+
+class LoginBodyModel(BaseModel):
+ service: str = Field(description="Service URL used for SFU's CAS system")
+ ticket: str = Field(description="Ticket return from SFU's CAS system")
+ redirect_url: str | None = Field(None, description="Optional redirect URL")
diff --git a/src/auth/urls.py b/src/auth/urls.py
index 113cfda..af60046 100644
--- a/src/auth/urls.py
+++ b/src/auth/urls.py
@@ -5,12 +5,14 @@
import requests # TODO: make this async
import xmltodict
-from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
+from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, Response
from fastapi.responses import JSONResponse, PlainTextResponse, RedirectResponse
import database
from auth import crud
-from constants import FRONTEND_ROOT_URL
+from auth.models import LoginBodyModel
+from constants import DOMAIN, IS_PROD, SAMESITE
+from utils.shared_models import DetailModel
_logger = logging.getLogger(__name__)
@@ -32,27 +34,34 @@ def generate_session_id_b64(num_bytes: int) -> str:
)
-# NOTE: logging in a second time invaldiates the last session_id
-@router.get(
+# NOTE: logging in a second time invalidates the last session_id
+@router.post(
"/login",
- description="Login to the sfucsss.org. Must redirect to this endpoint from SFU's cas authentication service for correct parameters",
+ description="Create a login session.",
+ response_description="Successfully validated with SFU's CAS",
+ response_model=str,
+ responses={
+ 307: { "description": "Successful validation, with redirect" },
+ 400: { "description": "Origin is missing.", "model": DetailModel },
+ 401: { "description": "Failed to validate ticket with SFU's CAS", "model": DetailModel }
+ },
+ operation_id="login",
)
async def login_user(
- redirect_path: str,
- redirect_fragment: str,
- ticket: str,
+ request: Request,
db_session: database.DBSession,
background_tasks: BackgroundTasks,
+ body: LoginBodyModel
):
# verify the ticket is valid
- service = urllib.parse.quote(f"{FRONTEND_ROOT_URL}/api/auth/login?redirect_path={redirect_path}&redirect_fragment={redirect_fragment}")
- service_validate_url = f"https://cas.sfu.ca/cas/serviceValidate?service={service}&ticket={ticket}"
+ service_url = body.service
+ service = urllib.parse.quote(service_url)
+ service_validate_url = f"https://cas.sfu.ca/cas/serviceValidate?service={service}&ticket={body.ticket}"
cas_response = xmltodict.parse(requests.get(service_validate_url).text)
if "cas:authenticationFailure" in cas_response["cas:serviceResponse"]:
_logger.info(f"User failed to login, with response {cas_response}")
- raise HTTPException(status_code=401, detail="authentication error, ticket likely invalid")
-
+ raise HTTPException(status_code=401, detail="authentication error")
else:
session_id = generate_session_id_b64(256)
computing_id = cas_response["cas:serviceResponse"]["cas:authenticationSuccess"]["cas:user"]
@@ -63,15 +72,29 @@ async def login_user(
# clean old sessions after sending the response
background_tasks.add_task(crud.task_clean_expired_user_sessions, db_session)
- response = RedirectResponse(FRONTEND_ROOT_URL + redirect_path + "#" + redirect_fragment)
+ if body.redirect_url:
+ origin = request.headers.get("origin")
+ if origin:
+ response = RedirectResponse(origin + body.redirect_url)
+ else:
+ raise HTTPException(status_code=400, detail="bad origin")
+ else:
+ response = Response()
+
response.set_cookie(
- key="session_id", value=session_id
+ key="session_id",
+ value=session_id,
+ secure=IS_PROD,
+ httponly=True,
+ samesite=SAMESITE,
+ domain=DOMAIN
) # this overwrites any past, possibly invalid, session_id
return response
@router.get(
"/logout",
+ operation_id="logout",
description="Logs out the current user by invalidating the session_id cookie",
)
async def logout_user(
@@ -94,6 +117,7 @@ async def logout_user(
@router.get(
"/user",
+ operation_id="get_user",
description="Get info about the current user. Only accessible by that user",
)
async def get_user(
@@ -116,6 +140,7 @@ async def get_user(
@router.patch(
"/user",
+ operation_id="update_user",
description="Update information for the currently logged in user. Only accessible by that user",
)
async def update_user(
diff --git a/src/constants.py b/src/constants.py
index 97b0d0c..3b6a9c7 100644
--- a/src/constants.py
+++ b/src/constants.py
@@ -2,12 +2,13 @@
# TODO(future): replace new.sfucsss.org with sfucsss.org during migration
# TODO(far-future): branch-specific root IP addresses (e.g., devbranch.sfucsss.org)
-FRONTEND_ROOT_URL = "http://localhost:8080" if os.environ.get("LOCAL") == "true" else "https://new.sfucsss.org"
-GITHUB_ORG_NAME = "CSSS-Test-Organization" if os.environ.get("LOCAL") == "true" else "CSSS"
+ENV_LOCAL = os.environ.get("LOCAL")
+IS_PROD = True if not ENV_LOCAL or ENV_LOCAL.lower() != "true" else False
+GITHUB_ORG_NAME = "CSSS-Test-Organization" if not IS_PROD else "CSSS"
W3_GUILD_ID = "1260652618875797504"
CSSS_GUILD_ID = "228761314644852736"
-ACTIVE_GUILD_ID = W3_GUILD_ID if os.environ.get("LOCAL") == "true" else CSSS_GUILD_ID
+ACTIVE_GUILD_ID = W3_GUILD_ID if not IS_PROD else CSSS_GUILD_ID
SESSION_ID_LEN = 512
# technically a max of 8 digits https://www.sfu.ca/computing/about/support/tips/sfu-userid.html
@@ -25,3 +26,7 @@
# https://docs.github.com/en/enterprise-server@3.10/admin/identity-and-access-management/iam-configuration-reference/username-considerations-for-external-authentication
GITHUB_USERNAME_LEN = 39
+
+# COOKIE
+SAMESITE="none" if IS_PROD else "lax"
+DOMAIN=".sfucsss.org" if IS_PROD else None
diff --git a/src/main.py b/src/main.py
index 37b700f..0a5433b 100755
--- a/src/main.py
+++ b/src/main.py
@@ -1,9 +1,9 @@
import logging
-import os
from fastapi import FastAPI, Request, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import auth.urls
@@ -11,38 +11,47 @@
import elections.urls
import officers.urls
import permission.urls
+from constants import IS_PROD
logging.basicConfig(level=logging.DEBUG)
database.setup_database()
-_login_link = (
- "https://cas.sfu.ca/cas/login?service=" + (
- "http%3A%2F%2Flocalhost%3A8080"
- if os.environ.get("LOCAL") == "true"
- else "https%3A%2F%2Fnew.sfucsss.org"
- ) + "%2Fapi%2Fauth%2Flogin%3Fredirect_path%3D%2Fapi%2Fapi%2Fdocs%2F%26redirect_fragment%3D"
-)
-
# Enable OpenAPI docs only for local development
-if os.environ.get("LOCAL") == "true":
+if not IS_PROD:
+ print("Running local environment")
+ origins = [
+ "http://localhost:4200", # default Angular
+ "http://localhost:8080", # for existing applications/sites
+ ]
app = FastAPI(
lifespan=database.lifespan,
title="CSSS Site Backend",
- description=f'To login, please click here
To logout from CAS click here',
root_path="/api",
)
-# if on production, disable vieweing the docs
+# if on production, disable viewing the docs
else:
+ print("Running production environment")
+ origins = [
+ "https://sfucsss.org",
+ "https://test.sfucsss.org",
+ "https://admin.sfucsss.org"
+ ]
app = FastAPI(
lifespan=database.lifespan,
title="CSSS Site Backend",
- description=f'To login, please click here
To logout from CAS click here',
root_path="/api",
docs_url=None, # disables Swagger UI
redoc_url=None, # disables ReDoc
openapi_url=None # disables OpenAPI schema
)
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=origins,
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"]
+)
app.include_router(auth.urls.router)
app.include_router(elections.urls.router)
@@ -55,7 +64,7 @@ async def read_root():
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(
- request: Request,
+ _: Request,
exception: RequestValidationError,
):
return JSONResponse(
diff --git a/src/utils/shared_models.py b/src/utils/shared_models.py
index ceaa2e2..121ede4 100644
--- a/src/utils/shared_models.py
+++ b/src/utils/shared_models.py
@@ -3,3 +3,6 @@
class SuccessFailModel(BaseModel):
success: bool
+
+class DetailModel(BaseModel):
+ detail: str