11import logging
22from datetime import datetime , timedelta
3- from typing import Optional
43
54import sqlalchemy
6- from auth .tables import SiteUser , UserSession
75from sqlalchemy .ext .asyncio import AsyncSession
86
7+ from auth .tables import SiteUser , UserSession
8+ from auth .types import SiteUserData
9+
910
1011async def create_user_session (db_session : AsyncSession , session_id : str , computing_id : str ):
1112 """
@@ -35,13 +36,6 @@ async def create_user_session(db_session: AsyncSession, session_id: str, computi
3536 # update the last time the user logged in to now
3637 existing_user .last_logged_in = datetime .now ()
3738 else :
38- new_user_session = UserSession (
39- issue_time = datetime .now (),
40- session_id = session_id ,
41- computing_id = computing_id ,
42- )
43- db_session .add (new_user_session )
44-
4539 # add new user to User table if it's their first time logging in
4640 query = sqlalchemy .select (SiteUser ).where (SiteUser .computing_id == computing_id )
4741 existing_user = (await db_session .scalars (query )).first ()
@@ -53,30 +47,20 @@ async def create_user_session(db_session: AsyncSession, session_id: str, computi
5347 )
5448 db_session .add (new_user )
5549
50+ new_user_session = UserSession (
51+ issue_time = datetime .now (),
52+ session_id = session_id ,
53+ computing_id = computing_id ,
54+ )
55+ db_session .add (new_user_session )
56+
5657
5758async def remove_user_session (db_session : AsyncSession , session_id : str ) -> dict :
5859 query = sqlalchemy .select (UserSession ).where (UserSession .session_id == session_id )
5960 user_session = await db_session .scalars (query )
6061 await db_session .delete (user_session .first ())
6162
6263
63- async def check_user_session (db_session : AsyncSession , session_id : str ) -> dict :
64- query = sqlalchemy .select (UserSession ).where (UserSession .session_id == session_id )
65- existing_user_session = (await db_session .scalars (query )).first ()
66-
67- if existing_user_session :
68- query = sqlalchemy .select (SiteUser ).where (SiteUser .computing_id == existing_user_session .computing_id )
69- existing_user = (await db_session .scalars (query )).first ()
70- return {
71- "is_valid" : True ,
72- "computing_id" : existing_user_session .computing_id ,
73- "first_logged_in" : existing_user .first_logged_in .isoformat (),
74- "last_logged_in" : existing_user .last_logged_in .isoformat ()
75- }
76- else :
77- return {"is_valid" : False }
78-
79-
8064async def get_computing_id (db_session : AsyncSession , session_id : str ) -> str | None :
8165 query = sqlalchemy .select (UserSession ).where (UserSession .session_id == session_id )
8266 existing_user_session = (await db_session .scalars (query )).first ()
@@ -92,7 +76,8 @@ async def task_clean_expired_user_sessions(db_session: AsyncSession):
9276 await db_session .commit ()
9377
9478
95- async def user_info (db_session : AsyncSession , session_id : str ) -> None | dict :
79+ # get the site user given a session ID; returns None when session is invalid
80+ async def get_site_user (db_session : AsyncSession , session_id : str ) -> None | SiteUserData :
9681 query = (
9782 sqlalchemy
9883 .select (UserSession )
@@ -111,8 +96,43 @@ async def user_info(db_session: AsyncSession, session_id: str) -> None | dict:
11196 if user is None :
11297 return None
11398
114- return {
115- "computing_id" : user_session .computing_id ,
116- "first_logged_in" : user .first_logged_in .isoformat (),
117- "last_logged_in" : user .last_logged_in .isoformat ()
118- }
99+ return SiteUserData (
100+ user_session .computing_id ,
101+ user .first_logged_in .isoformat (),
102+ user .last_logged_in .isoformat (),
103+ user .profile_picture_url
104+ )
105+
106+
107+ # update the optional user info for a given site user (e.g., display name, profile picture, ...)
108+ async def update_site_user (
109+ db_session : AsyncSession ,
110+ session_id : str ,
111+ profile_picture_url : str
112+ ) -> None | SiteUserData :
113+ query = (
114+ sqlalchemy
115+ .select (UserSession )
116+ .where (UserSession .session_id == session_id )
117+ )
118+ user_session = await db_session .scalar (query )
119+ if user_session is None :
120+ return None
121+
122+ query = (
123+ sqlalchemy
124+ .update (SiteUser )
125+ .where (SiteUser .computing_id == user_session .computing_id )
126+ .values (profile_picture_url = profile_picture_url )
127+ .returning (SiteUser ) # returns all columns of SiteUser
128+ )
129+ user = await db_session .scalar (query )
130+ if user is None :
131+ return None
132+
133+ return SiteUserData (
134+ user_session .computing_id ,
135+ user .first_logged_in .isoformat (),
136+ user .last_logged_in .isoformat (),
137+ user .profile_picture_url
138+ )
0 commit comments