Skip to content

Commit 1964409

Browse files
Get rid of context from SQL Alchemy data layer (#1319), fix SQLite support (#1137).
* Add SQLite DB tests and fixtures * Get rid of context in SQL Alchemy data layer. --------- Signed-off-by: DanielAvdar <[email protected]> Co-authored-by: Mathijs de Bruin <[email protected]>
1 parent 2bdd541 commit 1964409

File tree

6 files changed

+1333
-1088
lines changed

6 files changed

+1333
-1088
lines changed

backend/chainlit/data/dynamodb.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def _update_item(self, key: Dict[str, Any], updates: Dict[str, Any]):
9393
ExpressionAttributeValues=self._serialize_item(expression_attribute_values),
9494
)
9595

96+
@property
97+
def context(self):
98+
return context
99+
96100
async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
97101
_logger.info("DynamoDB: get_user identifier=%s", identifier)
98102

@@ -241,7 +245,7 @@ async def create_element(self, element: "Element"):
241245
if not element.mime:
242246
element.mime = "application/octet-stream"
243247

244-
context_user = context.session.user
248+
context_user = self.context.session.user
245249
user_folder = getattr(context_user, "id", "unknown")
246250
file_object_key = f"{user_folder}/{element.thread_id}/{element.id}"
247251

@@ -293,7 +297,7 @@ async def get_element(
293297

294298
@queue_until_user_message()
295299
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
296-
thread_id = context.session.thread_id
300+
thread_id = self.context.session.thread_id
297301
_logger.info(
298302
"DynamoDB: delete_element thread=%s element=%s", thread_id, element_id
299303
)
@@ -349,7 +353,7 @@ async def update_step(self, step_dict: "StepDict"):
349353

350354
@queue_until_user_message()
351355
async def delete_step(self, step_id: str):
352-
thread_id = context.session.thread_id
356+
thread_id = self.context.session.thread_id
353357
_logger.info("DynamoDB: delete_feedback thread=%s step=%s", thread_id, step_id)
354358

355359
self.client.delete_item(

backend/chainlit/data/sql_alchemy.py

Lines changed: 67 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import aiofiles
99
import aiohttp
10-
from chainlit.context import context
10+
1111
from chainlit.data.base import BaseDataLayer, BaseStorageClient
1212
from chainlit.data.utils import queue_until_user_message
1313
from chainlit.element import ElementDict
@@ -84,6 +84,9 @@ async def execute_sql(
8484
if result.returns_rows:
8585
json_result = [dict(row._mapping) for row in result.fetchall()]
8686
clean_json_result = self.clean_result(json_result)
87+
assert isinstance(clean_json_result, list) or isinstance(
88+
clean_json_result, int
89+
)
8790
return clean_json_result
8891
else:
8992
return result.rowcount
@@ -118,7 +121,47 @@ async def get_user(self, identifier: str) -> Optional[PersistedUser]:
118121
result = await self.execute_sql(query=query, parameters=parameters)
119122
if result and isinstance(result, list):
120123
user_data = result[0]
121-
return PersistedUser(**user_data)
124+
125+
# SQLite returns JSON as string, we most convert it. (#1137)
126+
metadata = user_data.get("metadata", {})
127+
if isinstance(metadata, str):
128+
metadata = json.loads(metadata)
129+
130+
assert isinstance(metadata, dict)
131+
assert isinstance(user_data["id"], str)
132+
assert isinstance(user_data["identifier"], str)
133+
assert isinstance(user_data["createdAt"], str)
134+
135+
return PersistedUser(
136+
id=user_data["id"],
137+
identifier=user_data["identifier"],
138+
createdAt=user_data["createdAt"],
139+
metadata=metadata,
140+
)
141+
return None
142+
143+
async def _get_user_identifer_by_id(self, user_id: str) -> str:
144+
if self.show_logger:
145+
logger.info(f"SQLAlchemy: _get_user_identifer_by_id, user_id={user_id}")
146+
query = "SELECT identifier FROM users WHERE id = :user_id"
147+
parameters = {"user_id": user_id}
148+
result = await self.execute_sql(query=query, parameters=parameters)
149+
150+
assert result
151+
assert isinstance(result, list)
152+
153+
return result[0]["identifier"]
154+
155+
async def _get_user_id_by_thread(self, thread_id: str) -> Optional[str]:
156+
if self.show_logger:
157+
logger.info(f"SQLAlchemy: _get_user_id_by_thread, thread_id={thread_id}")
158+
query = "SELECT userId FROM threads WHERE id = :thread_id"
159+
parameters = {"thread_id": thread_id}
160+
result = await self.execute_sql(query=query, parameters=parameters)
161+
if result:
162+
assert isinstance(result, list)
163+
return result[0]["userId"]
164+
122165
return None
123166

124167
async def create_user(self, user: User) -> Optional[PersistedUser]:
@@ -179,10 +222,11 @@ async def update_thread(
179222
):
180223
if self.show_logger:
181224
logger.info(f"SQLAlchemy: update_thread, thread_id={thread_id}")
182-
if context.session.user is not None:
183-
user_identifier = context.session.user.identifier
184-
else:
185-
raise ValueError("User not found in session context")
225+
226+
user_identifier = None
227+
if user_id:
228+
user_identifier = await self._get_user_identifer_by_id(user_id)
229+
186230
data = {
187231
"id": thread_id,
188232
"createdAt": (
@@ -294,8 +338,7 @@ async def list_threads(
294338
async def create_step(self, step_dict: "StepDict"):
295339
if self.show_logger:
296340
logger.info(f"SQLAlchemy: create_step, step_id={step_dict.get('id')}")
297-
if not getattr(context.session.user, "id", None):
298-
raise ValueError("No authenticated user in context")
341+
299342
step_dict["showInput"] = (
300343
str(step_dict.get("showInput", "")).lower()
301344
if "showInput" in step_dict
@@ -373,12 +416,18 @@ async def delete_feedback(self, feedback_id: str) -> bool:
373416
return True
374417

375418
###### Elements ######
376-
async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]:
419+
async def get_element(
420+
self, thread_id: str, element_id: str
421+
) -> Optional["ElementDict"]:
377422
if self.show_logger:
378-
logger.info(f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}")
423+
logger.info(
424+
f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}"
425+
)
379426
query = """SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id"""
380427
parameters = {"thread_id": thread_id, "element_id": element_id}
381-
element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(query=query, parameters=parameters)
428+
element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(
429+
query=query, parameters=parameters
430+
)
382431
if isinstance(element, list) and element:
383432
element_dict: Dict[str, Any] = element[0]
384433
return ElementDict(
@@ -396,7 +445,7 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional["Elemen
396445
autoPlay=element_dict.get("autoPlay"),
397446
playerConfig=element_dict.get("playerConfig"),
398447
forId=element_dict.get("forId"),
399-
mime=element_dict.get("mime")
448+
mime=element_dict.get("mime"),
400449
)
401450
else:
402451
return None
@@ -405,8 +454,7 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional["Elemen
405454
async def create_element(self, element: "Element"):
406455
if self.show_logger:
407456
logger.info(f"SQLAlchemy: create_element, element_id = {element.id}")
408-
if not getattr(context.session.user, "id", None):
409-
raise ValueError("No authenticated user in context")
457+
410458
if not self.storage_provider:
411459
logger.warn(
412460
"SQLAlchemy: create_element error. No blob_storage_client is configured!"
@@ -434,10 +482,8 @@ async def create_element(self, element: "Element"):
434482
if content is None:
435483
raise ValueError("Content is None, cannot upload file")
436484

437-
context_user = context.session.user
438-
439-
user_folder = getattr(context_user, "id", "unknown")
440-
file_object_key = f"{user_folder}/{element.id}" + (
485+
user_id: str = await self._get_user_id_by_thread(element.thread_id) or "unknown"
486+
file_object_key = f"{user_id}/{element.id}" + (
441487
f"/{element.name}" if element.name else ""
442488
)
443489

@@ -607,8 +653,9 @@ async def get_all_user_threads(
607653
tags=step_feedback.get("step_tags"),
608654
input=(
609655
step_feedback.get("step_input", "")
610-
if step_feedback.get("step_showinput") not in [None, "false"]
611-
else None
656+
if step_feedback.get("step_showinput")
657+
not in [None, "false"]
658+
else ""
612659
),
613660
output=step_feedback.get("step_output", ""),
614661
createdAt=step_feedback.get("step_createdat"),

0 commit comments

Comments
 (0)