77
88import aiofiles
99import aiohttp
10- from chainlit . context import context
10+
1111from chainlit .data .base import BaseDataLayer , BaseStorageClient
1212from chainlit .data .utils import queue_until_user_message
1313from chainlit .element import ElementDict
@@ -84,6 +84,9 @@ async def execute_sql(
8484 if result .returns_rows :
8585 json_result = [dict (row ._mapping ) for row in result .fetchall ()]
8686 clean_json_result = self .clean_result (json_result )
87+ assert isinstance (clean_json_result , list ) or isinstance (
88+ clean_json_result , int
89+ )
8790 return clean_json_result
8891 else :
8992 return result .rowcount
@@ -118,7 +121,47 @@ async def get_user(self, identifier: str) -> Optional[PersistedUser]:
118121 result = await self .execute_sql (query = query , parameters = parameters )
119122 if result and isinstance (result , list ):
120123 user_data = result [0 ]
121- return PersistedUser (** user_data )
124+
125+ # SQLite returns JSON as string, we most convert it. (#1137)
126+ metadata = user_data .get ("metadata" , {})
127+ if isinstance (metadata , str ):
128+ metadata = json .loads (metadata )
129+
130+ assert isinstance (metadata , dict )
131+ assert isinstance (user_data ["id" ], str )
132+ assert isinstance (user_data ["identifier" ], str )
133+ assert isinstance (user_data ["createdAt" ], str )
134+
135+ return PersistedUser (
136+ id = user_data ["id" ],
137+ identifier = user_data ["identifier" ],
138+ createdAt = user_data ["createdAt" ],
139+ metadata = metadata ,
140+ )
141+ return None
142+
143+ async def _get_user_identifer_by_id (self , user_id : str ) -> str :
144+ if self .show_logger :
145+ logger .info (f"SQLAlchemy: _get_user_identifer_by_id, user_id={ user_id } " )
146+ query = "SELECT identifier FROM users WHERE id = :user_id"
147+ parameters = {"user_id" : user_id }
148+ result = await self .execute_sql (query = query , parameters = parameters )
149+
150+ assert result
151+ assert isinstance (result , list )
152+
153+ return result [0 ]["identifier" ]
154+
155+ async def _get_user_id_by_thread (self , thread_id : str ) -> Optional [str ]:
156+ if self .show_logger :
157+ logger .info (f"SQLAlchemy: _get_user_id_by_thread, thread_id={ thread_id } " )
158+ query = "SELECT userId FROM threads WHERE id = :thread_id"
159+ parameters = {"thread_id" : thread_id }
160+ result = await self .execute_sql (query = query , parameters = parameters )
161+ if result :
162+ assert isinstance (result , list )
163+ return result [0 ]["userId" ]
164+
122165 return None
123166
124167 async def create_user (self , user : User ) -> Optional [PersistedUser ]:
@@ -179,10 +222,11 @@ async def update_thread(
179222 ):
180223 if self .show_logger :
181224 logger .info (f"SQLAlchemy: update_thread, thread_id={ thread_id } " )
182- if context .session .user is not None :
183- user_identifier = context .session .user .identifier
184- else :
185- raise ValueError ("User not found in session context" )
225+
226+ user_identifier = None
227+ if user_id :
228+ user_identifier = await self ._get_user_identifer_by_id (user_id )
229+
186230 data = {
187231 "id" : thread_id ,
188232 "createdAt" : (
@@ -294,8 +338,7 @@ async def list_threads(
294338 async def create_step (self , step_dict : "StepDict" ):
295339 if self .show_logger :
296340 logger .info (f"SQLAlchemy: create_step, step_id={ step_dict .get ('id' )} " )
297- if not getattr (context .session .user , "id" , None ):
298- raise ValueError ("No authenticated user in context" )
341+
299342 step_dict ["showInput" ] = (
300343 str (step_dict .get ("showInput" , "" )).lower ()
301344 if "showInput" in step_dict
@@ -373,12 +416,18 @@ async def delete_feedback(self, feedback_id: str) -> bool:
373416 return True
374417
375418 ###### Elements ######
376- async def get_element (self , thread_id : str , element_id : str ) -> Optional ["ElementDict" ]:
419+ async def get_element (
420+ self , thread_id : str , element_id : str
421+ ) -> Optional ["ElementDict" ]:
377422 if self .show_logger :
378- logger .info (f"SQLAlchemy: get_element, thread_id={ thread_id } , element_id={ element_id } " )
423+ logger .info (
424+ f"SQLAlchemy: get_element, thread_id={ thread_id } , element_id={ element_id } "
425+ )
379426 query = """SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id"""
380427 parameters = {"thread_id" : thread_id , "element_id" : element_id }
381- element : Union [List [Dict [str , Any ]], int , None ] = await self .execute_sql (query = query , parameters = parameters )
428+ element : Union [List [Dict [str , Any ]], int , None ] = await self .execute_sql (
429+ query = query , parameters = parameters
430+ )
382431 if isinstance (element , list ) and element :
383432 element_dict : Dict [str , Any ] = element [0 ]
384433 return ElementDict (
@@ -396,7 +445,7 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional["Elemen
396445 autoPlay = element_dict .get ("autoPlay" ),
397446 playerConfig = element_dict .get ("playerConfig" ),
398447 forId = element_dict .get ("forId" ),
399- mime = element_dict .get ("mime" )
448+ mime = element_dict .get ("mime" ),
400449 )
401450 else :
402451 return None
@@ -405,8 +454,7 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional["Elemen
405454 async def create_element (self , element : "Element" ):
406455 if self .show_logger :
407456 logger .info (f"SQLAlchemy: create_element, element_id = { element .id } " )
408- if not getattr (context .session .user , "id" , None ):
409- raise ValueError ("No authenticated user in context" )
457+
410458 if not self .storage_provider :
411459 logger .warn (
412460 "SQLAlchemy: create_element error. No blob_storage_client is configured!"
@@ -434,10 +482,8 @@ async def create_element(self, element: "Element"):
434482 if content is None :
435483 raise ValueError ("Content is None, cannot upload file" )
436484
437- context_user = context .session .user
438-
439- user_folder = getattr (context_user , "id" , "unknown" )
440- file_object_key = f"{ user_folder } /{ element .id } " + (
485+ user_id : str = await self ._get_user_id_by_thread (element .thread_id ) or "unknown"
486+ file_object_key = f"{ user_id } /{ element .id } " + (
441487 f"/{ element .name } " if element .name else ""
442488 )
443489
@@ -607,8 +653,9 @@ async def get_all_user_threads(
607653 tags = step_feedback .get ("step_tags" ),
608654 input = (
609655 step_feedback .get ("step_input" , "" )
610- if step_feedback .get ("step_showinput" ) not in [None , "false" ]
611- else None
656+ if step_feedback .get ("step_showinput" )
657+ not in [None , "false" ]
658+ else ""
612659 ),
613660 output = step_feedback .get ("step_output" , "" ),
614661 createdAt = step_feedback .get ("step_createdat" ),
0 commit comments