Skip to content

Commit 7a72954

Browse files
committed
Merge pre-existing with newly added tests, solves Chainlit#1137.
1 parent bc7c536 commit 7a72954

File tree

10 files changed

+143
-343
lines changed

10 files changed

+143
-343
lines changed

backend/chainlit/data/sql_alchemy.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,23 @@ async def get_user(self, identifier: str) -> Optional[PersistedUser]:
122122
result = await self.execute_sql(query=query, parameters=parameters)
123123
if result and isinstance(result, list):
124124
user_data = result[0]
125-
return PersistedUser(**user_data)
125+
126+
# SQLite returns JSON as string, we most convert it. (#1137)
127+
metadata = user_data.get("metadata", {})
128+
if isinstance(metadata, str):
129+
metadata = json.loads(metadata)
130+
131+
assert isinstance(metadata, dict)
132+
assert isinstance(user_data["id"], str)
133+
assert isinstance(user_data["identifier"], str)
134+
assert isinstance(user_data["createdAt"], str)
135+
136+
return PersistedUser(
137+
id=user_data["id"],
138+
identifier=user_data["identifier"],
139+
createdAt=user_data["createdAt"],
140+
metadata=metadata,
141+
)
126142
return None
127143

128144
async def create_user(self, user: User) -> Optional[PersistedUser]:
@@ -377,12 +393,18 @@ async def delete_feedback(self, feedback_id: str) -> bool:
377393
return True
378394

379395
###### Elements ######
380-
async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]:
396+
async def get_element(
397+
self, thread_id: str, element_id: str
398+
) -> Optional["ElementDict"]:
381399
if self.show_logger:
382-
logger.info(f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}")
400+
logger.info(
401+
f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}"
402+
)
383403
query = """SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id"""
384404
parameters = {"thread_id": thread_id, "element_id": element_id}
385-
element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(query=query, parameters=parameters)
405+
element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(
406+
query=query, parameters=parameters
407+
)
386408
if isinstance(element, list) and element:
387409
element_dict: Dict[str, Any] = element[0]
388410
return ElementDict(
@@ -400,7 +422,7 @@ async def get_element(self, thread_id: str, element_id: str) -> Optional["Elemen
400422
autoPlay=element_dict.get("autoPlay"),
401423
playerConfig=element_dict.get("playerConfig"),
402424
forId=element_dict.get("forId"),
403-
mime=element_dict.get("mime")
425+
mime=element_dict.get("mime"),
404426
)
405427
else:
406428
return None
@@ -611,7 +633,8 @@ async def get_all_user_threads(
611633
tags=step_feedback.get("step_tags"),
612634
input=(
613635
step_feedback.get("step_input", "")
614-
if step_feedback.get("step_showinput") not in [None, "false"]
636+
if step_feedback.get("step_showinput")
637+
not in [None, "false"]
615638
else None
616639
),
617640
output=step_feedback.get("step_output", ""),

backend/chainlit/data/suites/__init__.py

Whitespace-only changes.

backend/chainlit/data/suites/checks_suites.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

backend/tests/conftest.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,23 +9,32 @@
99
from chainlit.user_session import UserSession
1010

1111

12-
@asynccontextmanager
13-
async def create_chainlit_context():
14-
mock_session = Mock(spec=WebsocketSession)
15-
mock_session.id = "test_session_id"
16-
mock_session.user_env = {"test_env": "value"}
17-
mock_session.chat_settings = {}
18-
mock_user = Mock(spec=PersistedUser)
19-
mock_user.id = "test_user_id"
20-
mock_session.user = mock_user
21-
mock_session.chat_profile = None
22-
mock_session.http_referer = None
23-
mock_session.client_type = "webapp"
24-
mock_session.languages = ["en"]
25-
mock_session.thread_id = "test_thread_id"
26-
mock_session.emit = AsyncMock()
27-
mock_session.has_first_interaction = True
12+
@pytest.fixture
13+
def mock_persisted_user():
14+
mock = Mock(spec=PersistedUser)
15+
mock.id = "test_user_id"
16+
return mock
17+
18+
19+
@pytest.fixture
20+
def mock_session():
21+
mock = Mock(spec=WebsocketSession)
22+
mock.id = "test_session_id"
23+
mock.user_env = {"test_env": "value"}
24+
mock.chat_settings = {}
25+
mock.chat_profile = None
26+
mock.http_referer = None
27+
mock.client_type = "webapp"
28+
mock.languages = ["en"]
29+
mock.thread_id = "test_thread_id"
30+
mock.emit = AsyncMock()
31+
mock.has_first_interaction = True
32+
33+
return mock
2834

35+
36+
@asynccontextmanager
37+
async def create_chainlit_context(mock_session):
2938
context = ChainlitContext(mock_session)
3039
token = context_var.set(context)
3140
try:
@@ -35,8 +44,9 @@ async def create_chainlit_context():
3544

3645

3746
@pytest_asyncio.fixture
38-
async def mock_chainlit_context():
39-
return create_chainlit_context()
47+
async def mock_chainlit_context(mock_persisted_user, mock_session):
48+
mock_session.user = mock_persisted_user
49+
return create_chainlit_context(mock_session)
4050

4151

4252
@pytest.fixture

backend/tests/data/test_sql_alchemy.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
1+
from unittest.mock import Mock
12
import uuid
23
from pathlib import Path
34

45
import pytest
5-
6+
import pytest_asyncio
67
from sqlalchemy.ext.asyncio import create_async_engine
78
from sqlalchemy import text
89

9-
from chainlit.data.base import BaseStorageClient
10+
from chainlit.data.base import BaseDataLayer, BaseStorageClient
1011
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
1112
from chainlit.element import Text
13+
from chainlit import User
14+
from chainlit.user import PersistedUser
1215

1316

1417
@pytest.fixture
@@ -113,7 +116,11 @@ async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
113116
yield data_layer
114117

115118

116-
@pytest.mark.asyncio
119+
@pytest.fixture
120+
def test_user() -> User:
121+
return User(identifier="test_user_id", metadata={"test": 1})
122+
123+
117124
async def test_create_and_get_element(
118125
mock_chainlit_context, data_layer: SQLAlchemyDataLayer
119126
):
@@ -136,3 +143,79 @@ async def test_create_and_get_element(
136143
assert retrieved_element["name"] == text_element.name
137144
assert retrieved_element["mime"] == text_element.mime
138145
# The 'content' field is not part of the ElementDict, so we remove this assertion
146+
147+
148+
async def test_get_current_timestamp(
149+
mock_chainlit_context, data_layer: SQLAlchemyDataLayer
150+
):
151+
async with mock_chainlit_context:
152+
timestamp = await data_layer.get_current_timestamp()
153+
assert isinstance(timestamp, str)
154+
155+
156+
async def test_get_user(
157+
mock_chainlit_context, test_user: User, data_layer: SQLAlchemyDataLayer
158+
):
159+
async with mock_chainlit_context:
160+
result = await data_layer.get_user(test_user.identifier)
161+
assert result is None
162+
163+
164+
async def test_create_user(
165+
mock_chainlit_context, test_user: User, data_layer: SQLAlchemyDataLayer
166+
):
167+
async with mock_chainlit_context:
168+
persisted_user = await data_layer.create_user(test_user)
169+
170+
assert persisted_user
171+
assert persisted_user.identifier == test_user.identifier
172+
assert persisted_user.id
173+
assert persisted_user.createdAt
174+
175+
176+
async def test_get_thread_author(
177+
mock_chainlit_context, test_user: User, data_layer: SQLAlchemyDataLayer
178+
):
179+
async with mock_chainlit_context:
180+
persisted_user = await data_layer.create_user(test_user)
181+
assert persisted_user
182+
183+
await data_layer.update_thread("test_thread", persisted_user.identifier)
184+
author = await data_layer.get_thread_author("test_thread")
185+
assert author == persisted_user.identifier
186+
187+
188+
async def test_get_thread(
189+
mock_chainlit_context, test_user: User, data_layer: SQLAlchemyDataLayer
190+
):
191+
async with mock_chainlit_context:
192+
persisted_user = await data_layer.create_user(test_user)
193+
assert persisted_user
194+
195+
result = await data_layer.get_thread("test_thread")
196+
assert result is None
197+
198+
199+
async def test_update_thread(
200+
mock_chainlit_context, test_user: User, data_layer: SQLAlchemyDataLayer
201+
):
202+
async with mock_chainlit_context:
203+
persisted_user = await data_layer.create_user(test_user)
204+
assert persisted_user
205+
206+
await data_layer.update_thread("test_thread", persisted_user.identifier)
207+
assert True
208+
209+
210+
async def test_delete_thread(
211+
mock_chainlit_context, test_user: User, data_layer: SQLAlchemyDataLayer
212+
):
213+
async with mock_chainlit_context:
214+
persisted_user = await data_layer.create_user(test_user)
215+
assert persisted_user
216+
217+
await data_layer.update_thread("test_thread", "test_user")
218+
await data_layer.delete_thread("test_thread")
219+
thread = await data_layer.get_thread("test_thread")
220+
assert thread is None
221+
assert True

backend/tests/test_callbacks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ async def auth_func(
109109

110110
async def test_on_message(mock_chainlit_context, test_config):
111111
from chainlit.callbacks import on_message
112-
from chainlit.config import config
113112
from chainlit.message import Message
114113

115114
async with mock_chainlit_context as context:

backend/tests/test_db/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)