55
66import requests # TODO: make this async
77import xmltodict
8- from fastapi import APIRouter , BackgroundTasks , HTTPException , Request
8+ from fastapi import APIRouter , BackgroundTasks , HTTPException , Request , Response
99from fastapi .responses import JSONResponse , PlainTextResponse , RedirectResponse
1010
1111import database
1212from auth import crud
13- from constants import FRONTEND_ROOT_URL
13+ from auth .models import LoginBodyModel
14+ from constants import DOMAIN , IS_PROD , SAMESITE
15+ from utils .shared_models import DetailModel
1416
1517_logger = logging .getLogger (__name__ )
1618
@@ -32,27 +34,34 @@ def generate_session_id_b64(num_bytes: int) -> str:
3234)
3335
3436
35- # NOTE: logging in a second time invaldiates the last session_id
36- @router .get (
37+ # NOTE: logging in a second time invalidates the last session_id
38+ @router .post (
3739 "/login" ,
38- description = "Login to the sfucsss.org. Must redirect to this endpoint from SFU's cas authentication service for correct parameters" ,
40+ description = "Create a login session." ,
41+ response_description = "Successfully validated with SFU's CAS" ,
42+ response_model = str ,
43+ responses = {
44+ 307 : { "description" : "Successful validation, with redirect" },
45+ 400 : { "description" : "Origin is missing." , "model" : DetailModel },
46+ 401 : { "description" : "Failed to validate ticket with SFU's CAS" , "model" : DetailModel }
47+ },
48+ operation_id = "login" ,
3949)
4050async def login_user (
41- redirect_path : str ,
42- redirect_fragment : str ,
43- ticket : str ,
51+ request : Request ,
4452 db_session : database .DBSession ,
4553 background_tasks : BackgroundTasks ,
54+ body : LoginBodyModel
4655):
4756 # verify the ticket is valid
48- service = urllib .parse .quote (f"{ FRONTEND_ROOT_URL } /api/auth/login?redirect_path={ redirect_path } &redirect_fragment={ redirect_fragment } " )
49- service_validate_url = f"https://cas.sfu.ca/cas/serviceValidate?service={ service } &ticket={ ticket } "
57+ service_url = body .service
58+ service = urllib .parse .quote (service_url )
59+ service_validate_url = f"https://cas.sfu.ca/cas/serviceValidate?service={ service } &ticket={ body .ticket } "
5060 cas_response = xmltodict .parse (requests .get (service_validate_url ).text )
5161
5262 if "cas:authenticationFailure" in cas_response ["cas:serviceResponse" ]:
5363 _logger .info (f"User failed to login, with response { cas_response } " )
54- raise HTTPException (status_code = 401 , detail = "authentication error, ticket likely invalid" )
55-
64+ raise HTTPException (status_code = 401 , detail = "authentication error" )
5665 else :
5766 session_id = generate_session_id_b64 (256 )
5867 computing_id = cas_response ["cas:serviceResponse" ]["cas:authenticationSuccess" ]["cas:user" ]
@@ -63,15 +72,29 @@ async def login_user(
6372 # clean old sessions after sending the response
6473 background_tasks .add_task (crud .task_clean_expired_user_sessions , db_session )
6574
66- response = RedirectResponse (FRONTEND_ROOT_URL + redirect_path + "#" + redirect_fragment )
75+ if body .redirect_url :
76+ origin = request .headers .get ("origin" )
77+ if origin :
78+ response = RedirectResponse (origin + body .redirect_url )
79+ else :
80+ raise HTTPException (status_code = 400 , detail = "bad origin" )
81+ else :
82+ response = Response ()
83+
6784 response .set_cookie (
68- key = "session_id" , value = session_id
85+ key = "session_id" ,
86+ value = session_id ,
87+ secure = IS_PROD ,
88+ httponly = True ,
89+ samesite = SAMESITE ,
90+ domain = DOMAIN
6991 ) # this overwrites any past, possibly invalid, session_id
7092 return response
7193
7294
7395@router .get (
7496 "/logout" ,
97+ operation_id = "logout" ,
7598 description = "Logs out the current user by invalidating the session_id cookie" ,
7699)
77100async def logout_user (
@@ -94,6 +117,7 @@ async def logout_user(
94117
95118@router .get (
96119 "/user" ,
120+ operation_id = "get_user" ,
97121 description = "Get info about the current user. Only accessible by that user" ,
98122)
99123async def get_user (
@@ -116,6 +140,7 @@ async def get_user(
116140
117141@router .patch (
118142 "/user" ,
143+ operation_id = "update_user" ,
119144 description = "Update information for the currently logged in user. Only accessible by that user" ,
120145)
121146async def update_user (
0 commit comments