Skip to content

Commit bc7c536

Browse files
DanielAvdardokterbob
authored andcommitted
Add SQLite DB tests and fixtures
Introduce SQLite database tests and fixtures for pytest. This includes database schema definitions, temporary file handling, and data layer context management, along with updating dependencies in pyproject.toml and poetry.lock. Signed-off-by: DanielAvdar <[email protected]>
1 parent 90b61e7 commit bc7c536

File tree

7 files changed

+1457
-1035
lines changed

7 files changed

+1457
-1035
lines changed

backend/chainlit/data/suites/__init__.py

Whitespace-only changes.
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import abc
2+
import dataclasses as dc
3+
4+
import pytest
5+
from chainlit.context import ChainlitContext
6+
7+
from chainlit import User
8+
9+
max_size = 10
10+
11+
12+
@dc.dataclass
13+
class DummyChainlitContext(ChainlitContext):
14+
user: User
15+
16+
def __post_init__(self):
17+
pass
18+
19+
@property
20+
def session(self):
21+
return self
22+
23+
24+
class CheckDB:
25+
@pytest.fixture()
26+
@abc.abstractmethod
27+
def data_layer(self, tmp_files_folder):
28+
pass
29+
30+
@pytest.mark.asyncio
31+
async def test_get_current_timestamp(self, data_layer):
32+
timestamp = await data_layer.get_current_timestamp()
33+
assert isinstance(timestamp, str)
34+
35+
@pytest.mark.asyncio
36+
async def test_get_user(self, data_layer):
37+
result = await data_layer.get_user("test_id")
38+
assert result is None
39+
40+
@pytest.mark.asyncio
41+
async def test_create_user(self, data_layer):
42+
user = User("test_user")
43+
result = await data_layer.create_user(user)
44+
assert result is not None
45+
46+
@pytest.mark.asyncio
47+
async def test_get_thread_author(self, data_layer):
48+
_ = await data_layer.update_thread("test_thread", "test_user")
49+
author = await data_layer.get_thread_author("test_thread")
50+
assert author == "test_user"
51+
52+
@pytest.mark.asyncio
53+
async def test_get_thread(self, data_layer):
54+
result = await data_layer.get_thread("test_thread")
55+
assert result is None
56+
57+
@pytest.mark.asyncio
58+
async def test_update_thread(self, data_layer):
59+
await data_layer.update_thread("test_thread", "test_user")
60+
assert True
61+
62+
@pytest.mark.asyncio
63+
async def test_delete_thread(self, data_layer):
64+
await data_layer.update_thread("test_thread", "test_user")
65+
await data_layer.delete_thread("test_thread")
66+
thread = await data_layer.get_thread("test_thread")
67+
assert thread is None
68+
assert True

backend/poetry.lock

Lines changed: 1142 additions & 1035 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/tests/test_db/__init__.py

Whitespace-only changes.

backend/tests/test_db/builder.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import dataclasses as dc
2+
import json
3+
from typing import Any, Dict, List, Optional, Union
4+
5+
from chainlit.context import ChainlitContext
6+
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
7+
from chainlit.element import ElementDict
8+
from chainlit.step import StepDict
9+
from sqlalchemy import (
10+
JSON,
11+
UUID,
12+
Boolean,
13+
Column,
14+
DateTime,
15+
ForeignKey,
16+
Integer,
17+
MetaData,
18+
String,
19+
Table,
20+
)
21+
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
22+
from sqlalchemy.orm import sessionmaker
23+
24+
import chainlit as cl
25+
from chainlit import logger
26+
27+
28+
async def build_db(conninfo="sqlite+aiosqlite:///database.db"):
29+
engine = create_async_engine(conninfo)
30+
31+
metadata_obj = MetaData()
32+
33+
# Create 'users' table
34+
Table(
35+
"users",
36+
metadata_obj,
37+
Column("id", UUID(as_uuid=True), primary_key=True),
38+
Column("identifier", String, nullable=False, unique=True),
39+
Column("metadata", JSON, nullable=False),
40+
Column(
41+
"createdAt",
42+
String,
43+
),
44+
keep_existing=True,
45+
)
46+
47+
# Create 'threads' table
48+
Table(
49+
"threads",
50+
metadata_obj,
51+
Column("id", UUID(as_uuid=True), primary_key=True),
52+
Column("createdAt", String),
53+
Column("name", String),
54+
Column(
55+
"userId", UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE")
56+
),
57+
Column("userIdentifier", String),
58+
Column(
59+
"tags", String
60+
), # Changed from ARRAY(String) as SQLite doesn't support array types
61+
Column("metadata", JSON),
62+
keep_existing=True,
63+
)
64+
65+
Table(
66+
"steps",
67+
metadata_obj,
68+
Column("id", UUID(as_uuid=True), primary_key=True),
69+
Column("name", String, nullable=False),
70+
Column("type", String, nullable=False),
71+
Column("threadId", UUID(as_uuid=True), nullable=False),
72+
Column("parentId", UUID(as_uuid=True)),
73+
Column("disableFeedback", Boolean, nullable=False),
74+
Column("streaming", Boolean, nullable=False),
75+
Column("waitForAnswer", Boolean),
76+
Column("isError", Boolean),
77+
Column("metadata", JSON),
78+
Column(
79+
"tags", String
80+
), # Changed from ARRAY(String) as SQLite doesn't support array types
81+
Column("input", String),
82+
Column("output", String),
83+
Column("createdAt", String),
84+
Column("start", DateTime),
85+
Column("end", DateTime),
86+
Column("generation", JSON),
87+
Column("showInput", String),
88+
Column("language", String),
89+
Column("indent", Integer),
90+
keep_existing=True,
91+
)
92+
Table(
93+
"elements",
94+
metadata_obj,
95+
Column("id", UUID(as_uuid=True), primary_key=True),
96+
Column("threadId", UUID(as_uuid=True)),
97+
Column("type", String),
98+
Column("url", String),
99+
Column("chainlitKey", String),
100+
Column("name", String, nullable=False),
101+
Column("display", String),
102+
Column("objectKey", String),
103+
Column("size", String),
104+
Column("page", Integer),
105+
Column("language", String),
106+
Column("forId", UUID(as_uuid=True)),
107+
Column("mime", String),
108+
keep_existing=True,
109+
)
110+
Table(
111+
"feedbacks",
112+
metadata_obj,
113+
Column("id", UUID(as_uuid=True), primary_key=True),
114+
Column("forId", UUID(as_uuid=True), nullable=False),
115+
Column("threadId", UUID(as_uuid=True), nullable=False),
116+
Column("value", Integer, nullable=False),
117+
Column("comment", String),
118+
keep_existing=True,
119+
)
120+
121+
async with engine.begin() as conn:
122+
await conn.run_sync(metadata_obj.create_all)
123+
124+
125+
@dc.dataclass
126+
class DummyChainlitContext(ChainlitContext):
127+
user: cl.User
128+
129+
def __post_init__(self):
130+
pass
131+
132+
@property
133+
def session(self):
134+
return self
135+
136+
137+
class CustomDataLayer(SQLAlchemyDataLayer):
138+
def __init__(
139+
self,
140+
conninfo: str,
141+
context: ChainlitContext,
142+
# ssl_require: bool = False,
143+
# storage_provider: Optional[BaseStorageClient] = None,
144+
user_thread_limit: Optional[int] = 1000,
145+
show_logger: Optional[bool] = False,
146+
):
147+
self._conninfo = conninfo
148+
self.user_thread_limit = user_thread_limit
149+
self.show_logger = show_logger
150+
self._context = context
151+
ssl_args = {} # type: ignore
152+
153+
self.engine: AsyncEngine = create_async_engine(
154+
self._conninfo, connect_args=ssl_args
155+
)
156+
self.async_session = sessionmaker(bind=self.engine, expire_on_commit=False, class_=AsyncSession) # type: ignore
157+
158+
@property
159+
def context(self):
160+
return self._context
161+
162+
@context.setter
163+
def context(self, context: ChainlitContext):
164+
self._context = context
165+
166+
async def execute_sql(
167+
self, query: str, parameters: dict
168+
) -> Union[List[Dict[str, Any]], int, None]:
169+
require_metadata = "metadata" in query or "*" in query
170+
171+
res = await super().execute_sql(query=query, parameters=parameters)
172+
if not require_metadata or not isinstance(res, list):
173+
return res
174+
175+
for r in res:
176+
for key in r.keys():
177+
if "metadata" in key:
178+
r[key] = json.loads(r[key]) if r[key] is not None else None
179+
180+
return res
181+
182+
async def get_thread_author(self, thread_id: str) -> str:
183+
if self.show_logger:
184+
logger.info(f"SQLAlchemy: get_thread_author, thread_id={thread_id}")
185+
query = """SELECT "userIdentifier" FROM threads WHERE "id" = :id"""
186+
parameters = {"id": thread_id}
187+
result = await self.execute_sql(query=query, parameters=parameters)
188+
if isinstance(result, list) and result:
189+
author_identifier = result[0].get("userIdentifier")
190+
if author_identifier is not None:
191+
return author_identifier
192+
raise ValueError(f"Author not found for thread_id {thread_id}")
193+
194+
async def create_step(self, step_dict: "StepDict"):
195+
if "disableFeedback" not in step_dict.keys():
196+
step_dict["disableFeedback"] = False # type: ignore
197+
return await super().create_step(step_dict)
198+
199+
async def get_element(
200+
self, thread_id: str, element_id: str
201+
) -> Optional["ElementDict"]:
202+
pass

backend/tests/test_db/conftest.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import shutil
2+
from pathlib import Path
3+
4+
import pytest
5+
6+
7+
@pytest.fixture
8+
def tmp_files_folder(test_cleaner) -> Path:
9+
folder_path = Path(__file__).parent / "tmp_folder_for_tests"
10+
if folder_path.exists():
11+
shutil.rmtree(folder_path)
12+
folder_path.mkdir()
13+
test_cleaner(lambda: shutil.rmtree(folder_path))
14+
return folder_path
15+
16+
17+
@pytest.fixture()
18+
def test_cleaner():
19+
funcs = []
20+
21+
def add_func(func):
22+
funcs.append(func)
23+
24+
yield add_func
25+
for func in funcs:
26+
func()

backend/tests/test_db/test_db.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import asyncio
2+
3+
import pytest
4+
from chainlit.data.suites.checks_suites import CheckDB, DummyChainlitContext
5+
6+
import chainlit as cl
7+
8+
from .builder import CustomDataLayer, build_db
9+
10+
11+
class TestSQLiteDB(CheckDB):
12+
@pytest.fixture()
13+
def data_layer(self, tmp_files_folder):
14+
db_path = tmp_files_folder / "test.db"
15+
asyncio.run(build_db(f"sqlite+aiosqlite:///{db_path.as_posix()}"))
16+
return CustomDataLayer(
17+
f"sqlite+aiosqlite:///{db_path.as_posix()}",
18+
context=DummyChainlitContext(user=cl.User(identifier="test_user")),
19+
)

0 commit comments

Comments
 (0)