77
88import aiofiles
99import aiohttp
10- from chainlit .context import context
11- from chainlit .data import BaseDataLayer , BaseStorageClient , queue_until_user_message
10+
11+ from chainlit .data .base import BaseDataLayer , BaseStorageClient
12+ from chainlit .data .utils import queue_until_user_message
1213from chainlit .element import ElementDict
1314from chainlit .logger import logger
1415from chainlit .step import StepDict
2627from sqlalchemy .exc import SQLAlchemyError
2728from sqlalchemy .ext .asyncio import AsyncEngine , AsyncSession , create_async_engine
2829from sqlalchemy .orm import sessionmaker
29- import chainlit as cl
30- from literalai .helper import utc_now
31- now = utc_now ()
3230
3331if TYPE_CHECKING :
3432 from chainlit .element import Element , ElementDict
@@ -57,7 +55,9 @@ def __init__(
5755 self .engine : AsyncEngine = create_async_engine (
5856 self ._conninfo , connect_args = ssl_args
5957 )
60- self .async_session = sessionmaker (bind = self .engine , expire_on_commit = False , class_ = AsyncSession ) # type: ignore
58+ self .async_session = sessionmaker (
59+ bind = self .engine , expire_on_commit = False , class_ = AsyncSession
60+ ) # type: ignore
6161 if storage_provider :
6262 self .storage_provider : Optional [BaseStorageClient ] = storage_provider
6363 if self .show_logger :
@@ -84,6 +84,9 @@ async def execute_sql(
8484 if result .returns_rows :
8585 json_result = [dict (row ._mapping ) for row in result .fetchall ()]
8686 clean_json_result = self .clean_result (json_result )
87+ assert isinstance (clean_json_result , list ) or isinstance (
88+ clean_json_result , int
89+ )
8790 return clean_json_result
8891 else :
8992 return result .rowcount
@@ -111,21 +114,57 @@ def clean_result(self, obj):
111114
112115 ###### User ######
113116 async def get_user (self , identifier : str ) -> Optional [PersistedUser ]:
114- logger .debug (f"Getting user: { identifier } " )
115- return cl .PersistedUser (id = "test" , createdAt = now , identifier = identifier )
116117 if self .show_logger :
117118 logger .info (f"SQLAlchemy: get_user, identifier={ identifier } " )
118119 query = "SELECT * FROM users WHERE identifier = :identifier"
119120 parameters = {"identifier" : identifier }
120121 result = await self .execute_sql (query = query , parameters = parameters )
121122 if result and isinstance (result , list ):
122123 user_data = result [0 ]
123- return PersistedUser (** user_data )
124+
125+ # SQLite returns JSON as string, we most convert it. (#1137)
126+ metadata = user_data .get ("metadata" , {})
127+ if isinstance (metadata , str ):
128+ metadata = json .loads (metadata )
129+
130+ assert isinstance (metadata , dict )
131+ assert isinstance (user_data ["id" ], str )
132+ assert isinstance (user_data ["identifier" ], str )
133+ assert isinstance (user_data ["createdAt" ], str )
134+
135+ return PersistedUser (
136+ id = user_data ["id" ],
137+ identifier = user_data ["identifier" ],
138+ createdAt = user_data ["createdAt" ],
139+ metadata = metadata ,
140+ )
141+ return None
142+
143+ async def _get_user_identifer_by_id (self , user_id : str ) -> str :
144+ if self .show_logger :
145+ logger .info (f"SQLAlchemy: _get_user_identifer_by_id, user_id={ user_id } " )
146+ query = "SELECT identifier FROM users WHERE id = :user_id"
147+ parameters = {"user_id" : user_id }
148+ result = await self .execute_sql (query = query , parameters = parameters )
149+
150+ assert result
151+ assert isinstance (result , list )
152+
153+ return result [0 ]["identifier" ]
154+
155+ async def _get_user_id_by_thread (self , thread_id : str ) -> Optional [str ]:
156+ if self .show_logger :
157+ logger .info (f"SQLAlchemy: _get_user_id_by_thread, thread_id={ thread_id } " )
158+ query = "SELECT userId FROM threads WHERE id = :thread_id"
159+ parameters = {"thread_id" : thread_id }
160+ result = await self .execute_sql (query = query , parameters = parameters )
161+ if result :
162+ assert isinstance (result , list )
163+ return result [0 ]["userId" ]
164+
124165 return None
125166
126167 async def create_user (self , user : User ) -> Optional [PersistedUser ]:
127- logger .debug (f"Creating user: { user .identifier } " )
128- return cl .PersistedUser (id = "test" , createdAt = now , identifier = user .identifier )
129168 if self .show_logger :
130169 logger .info (f"SQLAlchemy: create_user, user_identifier={ user .identifier } " )
131170 existing_user : Optional ["PersistedUser" ] = await self .get_user (user .identifier )
@@ -151,8 +190,6 @@ async def create_user(self, user: User) -> Optional[PersistedUser]:
151190
152191 ###### Threads ######
153192 async def get_thread_author (self , thread_id : str ) -> str :
154- logger .debug (f"Getting thread author: { thread_id } " )
155- return "admin"
156193 if self .show_logger :
157194 logger .info (f"SQLAlchemy: get_thread_author, thread_id={ thread_id } " )
158195 query = """SELECT "userIdentifier" FROM threads WHERE "id" = :id"""
@@ -171,16 +208,7 @@ async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
171208 thread_id = thread_id
172209 )
173210 if user_threads :
174- thread = user_threads [0 ]
175- # Parse the metadata here
176- if isinstance (thread ['metadata' ], str ):
177- try :
178- thread ['metadata' ] = json .loads (thread ['metadata' ])
179- except json .JSONDecodeError :
180- thread ['metadata' ] = {}
181- elif thread ['metadata' ] is None :
182- thread ['metadata' ] = {}
183- return thread
211+ return user_threads [0 ]
184212 else :
185213 return None
186214
@@ -194,10 +222,11 @@ async def update_thread(
194222 ):
195223 if self .show_logger :
196224 logger .info (f"SQLAlchemy: update_thread, thread_id={ thread_id } " )
197- if context .session .user is not None :
198- user_identifier = context .session .user .identifier
199- else :
200- raise ValueError ("User not found in session context" )
225+
226+ user_identifier = None
227+ if user_id :
228+ user_identifier = await self ._get_user_identifer_by_id (user_id )
229+
201230 data = {
202231 "id" : thread_id ,
203232 "createdAt" : (
@@ -309,8 +338,7 @@ async def list_threads(
309338 async def create_step (self , step_dict : "StepDict" ):
310339 if self .show_logger :
311340 logger .info (f"SQLAlchemy: create_step, step_id={ step_dict .get ('id' )} " )
312- if not getattr (context .session .user , "id" , None ):
313- raise ValueError ("No authenticated user in context" )
341+
314342 step_dict ["showInput" ] = (
315343 str (step_dict .get ("showInput" , "" )).lower ()
316344 if "showInput" in step_dict
@@ -388,15 +416,48 @@ async def delete_feedback(self, feedback_id: str) -> bool:
388416 return True
389417
390418 ###### Elements ######
419+ async def get_element (
420+ self , thread_id : str , element_id : str
421+ ) -> Optional ["ElementDict" ]:
422+ if self .show_logger :
423+ logger .info (
424+ f"SQLAlchemy: get_element, thread_id={ thread_id } , element_id={ element_id } "
425+ )
426+ query = """SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id"""
427+ parameters = {"thread_id" : thread_id , "element_id" : element_id }
428+ element : Union [List [Dict [str , Any ]], int , None ] = await self .execute_sql (
429+ query = query , parameters = parameters
430+ )
431+ if isinstance (element , list ) and element :
432+ element_dict : Dict [str , Any ] = element [0 ]
433+ return ElementDict (
434+ id = element_dict ["id" ],
435+ threadId = element_dict .get ("threadId" ),
436+ type = element_dict ["type" ],
437+ chainlitKey = element_dict .get ("chainlitKey" ),
438+ url = element_dict .get ("url" ),
439+ objectKey = element_dict .get ("objectKey" ),
440+ name = element_dict ["name" ],
441+ display = element_dict ["display" ],
442+ size = element_dict .get ("size" ),
443+ language = element_dict .get ("language" ),
444+ page = element_dict .get ("page" ),
445+ autoPlay = element_dict .get ("autoPlay" ),
446+ playerConfig = element_dict .get ("playerConfig" ),
447+ forId = element_dict .get ("forId" ),
448+ mime = element_dict .get ("mime" ),
449+ )
450+ else :
451+ return None
452+
391453 @queue_until_user_message ()
392454 async def create_element (self , element : "Element" ):
393455 if self .show_logger :
394456 logger .info (f"SQLAlchemy: create_element, element_id = { element .id } " )
395- if not getattr (context .session .user , "id" , None ):
396- raise ValueError ("No authenticated user in context" )
457+
397458 if not self .storage_provider :
398459 logger .warn (
399- f "SQLAlchemy: create_element error. No blob_storage_client is configured!"
460+ "SQLAlchemy: create_element error. No blob_storage_client is configured!"
400461 )
401462 return
402463 if not element .for_id :
@@ -421,10 +482,8 @@ async def create_element(self, element: "Element"):
421482 if content is None :
422483 raise ValueError ("Content is None, cannot upload file" )
423484
424- context_user = context .session .user
425-
426- user_folder = getattr (context_user , "id" , "unknown" )
427- file_object_key = f"{ user_folder } /{ element .id } " + (
485+ user_id : str = await self ._get_user_id_by_thread (element .thread_id ) or "unknown"
486+ file_object_key = f"{ user_id } /{ element .id } " + (
428487 f"/{ element .name } " if element .name else ""
429488 )
430489
@@ -458,15 +517,12 @@ async def delete_element(self, element_id: str, thread_id: Optional[str] = None)
458517 parameters = {"id" : element_id }
459518 await self .execute_sql (query = query , parameters = parameters )
460519
461- async def delete_user_session (self , id : str ) -> bool :
462- return False # Not sure why documentation wants this
463-
464520 async def get_all_user_threads (
465521 self , user_id : Optional [str ] = None , thread_id : Optional [str ] = None
466522 ) -> Optional [List [ThreadDict ]]:
467523 """Fetch all user threads up to self.user_thread_limit, or one thread by id if thread_id is provided."""
468524 if self .show_logger :
469- logger .info (f "SQLAlchemy: get_all_user_threads" )
525+ logger .info ("SQLAlchemy: get_all_user_threads" )
470526 user_threads_query = """
471527 SELECT
472528 "id" AS thread_id,
@@ -522,7 +578,8 @@ async def get_all_user_threads(
522578 s."language" AS step_language,
523579 s."indent" AS step_indent,
524580 f."value" AS feedback_value,
525- f."comment" AS feedback_comment
581+ f."comment" AS feedback_comment,
582+ f."id" AS feedback_id
526583 FROM steps s LEFT JOIN feedbacks f ON s."id" = f."forId"
527584 WHERE s."threadId" IN { thread_ids }
528585 ORDER BY s."createdAt" ASC
@@ -596,8 +653,9 @@ async def get_all_user_threads(
596653 tags = step_feedback .get ("step_tags" ),
597654 input = (
598655 step_feedback .get ("step_input" , "" )
599- if step_feedback ["step_showinput" ] == "true"
600- else None
656+ if step_feedback .get ("step_showinput" )
657+ not in [None , "false" ]
658+ else ""
601659 ),
602660 output = step_feedback .get ("step_output" , "" ),
603661 createdAt = step_feedback .get ("step_createdat" ),
0 commit comments