Skip to content

Commit 2bdd541

Browse files
hayescodedokterbob
andauthored
Add get_element() and test infra for sql_alchemy.py (#1346)
* Add get_element() to sql_alchemy.py * Add test for create_element and get_element in SQLAlchemyDataLayer. * Add aiosqlite test dep. * Add missing attribute to mocked WebsocketSession object * Add mocked user to ChainlitContext in test --------- Co-authored-by: Mathijs de Bruin (aider) <[email protected]>
1 parent b86fa05 commit 2bdd541

File tree

7 files changed

+207
-2
lines changed

7 files changed

+207
-2
lines changed

backend/chainlit/data/sql_alchemy.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,34 @@ async def delete_feedback(self, feedback_id: str) -> bool:
373373
return True
374374

375375
###### Elements ######
376+
async def get_element(self, thread_id: str, element_id: str) -> Optional["ElementDict"]:
377+
if self.show_logger:
378+
logger.info(f"SQLAlchemy: get_element, thread_id={thread_id}, element_id={element_id}")
379+
query = """SELECT * FROM elements WHERE "threadId" = :thread_id AND "id" = :element_id"""
380+
parameters = {"thread_id": thread_id, "element_id": element_id}
381+
element: Union[List[Dict[str, Any]], int, None] = await self.execute_sql(query=query, parameters=parameters)
382+
if isinstance(element, list) and element:
383+
element_dict: Dict[str, Any] = element[0]
384+
return ElementDict(
385+
id=element_dict["id"],
386+
threadId=element_dict.get("threadId"),
387+
type=element_dict["type"],
388+
chainlitKey=element_dict.get("chainlitKey"),
389+
url=element_dict.get("url"),
390+
objectKey=element_dict.get("objectKey"),
391+
name=element_dict["name"],
392+
display=element_dict["display"],
393+
size=element_dict.get("size"),
394+
language=element_dict.get("language"),
395+
page=element_dict.get("page"),
396+
autoPlay=element_dict.get("autoPlay"),
397+
playerConfig=element_dict.get("playerConfig"),
398+
forId=element_dict.get("forId"),
399+
mime=element_dict.get("mime")
400+
)
401+
else:
402+
return None
403+
376404
@queue_until_user_message()
377405
async def create_element(self, element: "Element"):
378406
if self.show_logger:

backend/poetry.lock

Lines changed: 19 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ plotly = "^5.18.0"
6868
slack_bolt = "^1.18.1"
6969
discord = "^2.3.2"
7070
botbuilder-core = "^4.15.0"
71+
aiosqlite = "^0.20.0"
7172

7273
[tool.poetry.group.dev.dependencies]
7374
black = "^24.8.0"
@@ -106,6 +107,7 @@ ignore_missing_imports = true
106107

107108

108109

110+
109111
[tool.poetry.group.custom-data]
110112
optional = true
111113

backend/tests/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import pytest_asyncio
66
from chainlit.context import ChainlitContext, context_var
77
from chainlit.session import HTTPSession, WebsocketSession
8+
from chainlit.user import PersistedUser
89
from chainlit.user_session import UserSession
910

1011

@@ -14,13 +15,16 @@ async def create_chainlit_context():
1415
mock_session.id = "test_session_id"
1516
mock_session.user_env = {"test_env": "value"}
1617
mock_session.chat_settings = {}
17-
mock_session.user = None
18+
mock_user = Mock(spec=PersistedUser)
19+
mock_user.id = "test_user_id"
20+
mock_session.user = mock_user
1821
mock_session.chat_profile = None
1922
mock_session.http_referer = None
2023
mock_session.client_type = "webapp"
2124
mock_session.languages = ["en"]
2225
mock_session.thread_id = "test_thread_id"
2326
mock_session.emit = AsyncMock()
27+
mock_session.has_first_interaction = True
2428

2529
context = ChainlitContext(mock_session)
2630
token = context_var.set(context)

backend/tests/data/__init__.py

Whitespace-only changes.

backend/tests/data/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
from unittest.mock import AsyncMock
4+
5+
from chainlit.data.base import BaseStorageClient
6+
7+
8+
@pytest.fixture
9+
def mock_storage_client():
10+
mock_client = AsyncMock(spec=BaseStorageClient)
11+
mock_client.upload_file.return_value = {
12+
"url": "https://example.com/test.txt",
13+
"object_key": "test_user/test_element/test.txt",
14+
}
15+
return mock_client
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
import uuid
2+
from pathlib import Path
3+
4+
import pytest
5+
6+
from sqlalchemy.ext.asyncio import create_async_engine
7+
from sqlalchemy import text
8+
9+
from chainlit.data.base import BaseStorageClient
10+
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
11+
from chainlit.element import Text
12+
13+
14+
@pytest.fixture
15+
async def data_layer(mock_storage_client: BaseStorageClient, tmp_path: Path):
16+
db_file = tmp_path / "test_db.sqlite"
17+
conninfo = f"sqlite+aiosqlite:///{db_file}"
18+
19+
# Create async engine
20+
engine = create_async_engine(conninfo)
21+
22+
# Execute initialization statements
23+
# Ref: https://docs.chainlit.io/data-persistence/custom#sql-alchemy-data-layer
24+
async with engine.begin() as conn:
25+
await conn.execute(
26+
text("""
27+
CREATE TABLE users (
28+
"id" UUID PRIMARY KEY,
29+
"identifier" TEXT NOT NULL UNIQUE,
30+
"metadata" JSONB NOT NULL,
31+
"createdAt" TEXT
32+
);
33+
""")
34+
)
35+
36+
await conn.execute(
37+
text("""
38+
CREATE TABLE IF NOT EXISTS threads (
39+
"id" UUID PRIMARY KEY,
40+
"createdAt" TEXT,
41+
"name" TEXT,
42+
"userId" UUID,
43+
"userIdentifier" TEXT,
44+
"tags" TEXT[],
45+
"metadata" JSONB,
46+
FOREIGN KEY ("userId") REFERENCES users("id") ON DELETE CASCADE
47+
);
48+
""")
49+
)
50+
51+
await conn.execute(
52+
text("""
53+
CREATE TABLE IF NOT EXISTS steps (
54+
"id" UUID PRIMARY KEY,
55+
"name" TEXT NOT NULL,
56+
"type" TEXT NOT NULL,
57+
"threadId" UUID NOT NULL,
58+
"parentId" UUID,
59+
"disableFeedback" BOOLEAN NOT NULL,
60+
"streaming" BOOLEAN NOT NULL,
61+
"waitForAnswer" BOOLEAN,
62+
"isError" BOOLEAN,
63+
"metadata" JSONB,
64+
"tags" TEXT[],
65+
"input" TEXT,
66+
"output" TEXT,
67+
"createdAt" TEXT,
68+
"start" TEXT,
69+
"end" TEXT,
70+
"generation" JSONB,
71+
"showInput" TEXT,
72+
"language" TEXT,
73+
"indent" INT
74+
);
75+
""")
76+
)
77+
78+
await conn.execute(
79+
text("""
80+
CREATE TABLE IF NOT EXISTS elements (
81+
"id" UUID PRIMARY KEY,
82+
"threadId" UUID,
83+
"type" TEXT,
84+
"url" TEXT,
85+
"chainlitKey" TEXT,
86+
"name" TEXT NOT NULL,
87+
"display" TEXT,
88+
"objectKey" TEXT,
89+
"size" TEXT,
90+
"page" INT,
91+
"language" TEXT,
92+
"forId" UUID,
93+
"mime" TEXT
94+
);
95+
""")
96+
)
97+
98+
await conn.execute(
99+
text("""
100+
CREATE TABLE IF NOT EXISTS feedbacks (
101+
"id" UUID PRIMARY KEY,
102+
"forId" UUID NOT NULL,
103+
"threadId" UUID NOT NULL,
104+
"value" INT NOT NULL,
105+
"comment" TEXT
106+
);
107+
""")
108+
)
109+
110+
# Create SQLAlchemyDataLayer instance
111+
data_layer = SQLAlchemyDataLayer(conninfo, storage_provider=mock_storage_client)
112+
113+
yield data_layer
114+
115+
116+
@pytest.mark.asyncio
117+
async def test_create_and_get_element(
118+
mock_chainlit_context, data_layer: SQLAlchemyDataLayer
119+
):
120+
async with mock_chainlit_context:
121+
text_element = Text(
122+
id=str(uuid.uuid4()),
123+
name="test.txt",
124+
mime="text/plain",
125+
content="test content",
126+
for_id="test_step_id",
127+
)
128+
129+
await data_layer.create_element(text_element)
130+
131+
retrieved_element = await data_layer.get_element(
132+
text_element.thread_id, text_element.id
133+
)
134+
assert retrieved_element is not None
135+
assert retrieved_element["id"] == text_element.id
136+
assert retrieved_element["name"] == text_element.name
137+
assert retrieved_element["mime"] == text_element.mime
138+
# The 'content' field is not part of the ElementDict, so we remove this assertion

0 commit comments

Comments
 (0)