diff --git a/.gitignore b/.gitignore
index f2272e91eb..288d0b25ee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
 build
 dist
 *.egg-info
+*.ipynb
 .env
diff --git a/backend/chainlit/cli/__init__.py b/backend/chainlit/cli/__init__.py
index 5a03136e67..aab637a78d 100644
--- a/backend/chainlit/cli/__init__.py
+++ b/backend/chainlit/cli/__init__.py
@@ -47,6 +47,15 @@ def cli():
 # Define the function to run Chainlit with provided options
 def run_chainlit(target: str):
+    import os
+    import subprocess
+
+    def auto_run_alembic_upgrade():
+        try:
+            subprocess.run(["alembic", "upgrade", "head"], check=True)
+            logger.info("Alembic migrations applied (upgrade to head).")
+        except Exception as e:
+            logger.error(f"Failed to run Alembic migrations: {e}")
     host = os.environ.get("CHAINLIT_HOST", DEFAULT_HOST)
     port = int(os.environ.get("CHAINLIT_PORT", DEFAULT_PORT))
     root_path = os.environ.get("CHAINLIT_ROOT_PATH", DEFAULT_ROOT_PATH)
@@ -76,6 +85,28 @@ def run_chainlit(target: str):
         config.run.module_name = target
     load_module(config.run.module_name)
 
+    # Check whether the SQLModel data layer is used and handle Alembic migrations
+    data_layer_func = getattr(config.code, "data_layer", None)
+    if data_layer_func:
+        try:
+            dl_instance = data_layer_func()
+            from chainlit.data.sql_data_layer import SQLDataLayer
+            if isinstance(dl_instance, SQLDataLayer):
+                # Get current version
+                try:
+                    from chainlit.version import __version__
+                except Exception:
+                    __version__ = "unknown"
+                logger.info(f"SQLDataLayer detected. Chainlit version: {__version__}.")
+                auto_migrate = os.environ.get("CHAINLIT_AUTO_MIGRATE", "false").lower() in ["true", "1", "yes"]
+                if auto_migrate:
+                    logger.info("Auto-migration enabled. Running Alembic migrations...")
+                    auto_run_alembic_upgrade()
+                else:
+                    logger.info("Auto-migration disabled. Run 'alembic upgrade head' after updating models or upgrading Chainlit.")
+        except Exception as e:
+            logger.warning(f"Could not check data layer type: {e}")
+
     ensure_jwt_secret()
 
     assert_app()
diff --git a/backend/chainlit/data/alembic/README.md b/backend/chainlit/data/alembic/README.md
new file mode 100644
index 0000000000..5f504a1b51
--- /dev/null
+++ b/backend/chainlit/data/alembic/README.md
@@ -0,0 +1,59 @@
+# Alembic Migrations for the Chainlit SQLModel Data Layer
+
+This directory contains Alembic migration scripts for the SQLModel-based data layer.
+
+## Best Practices
+
+- **Do not use `SQLModel.metadata.create_all()` in production.**
+- Always manage schema changes with Alembic migrations.
+- Keep migration scripts in version control.
+- Run migrations before starting the app, or enable auto-migration with `CHAINLIT_AUTO_MIGRATE=true`.
+
+## Usage
+
+1. **Configure your database URL** in `alembic.ini`:
+   ```ini
+   sqlalchemy.url = 
+   ```
+
+2. **Autogenerate a migration** (after changing models):
+   ```bash
+   alembic revision --autogenerate -m "Initial tables"
+   ```
+
+3. **Apply migrations**:
+   ```bash
+   alembic upgrade head
+   ```
+
+## Initial Migration
+
+For a fresh database, autogenerate a first migration that creates all tables defined in `chainlit.models`.
+The bundled `0001_create_tables` revision instead renames legacy camelCase columns to snake_case for databases created by the previous data layer.
+
+## env.py
+
+Alembic is configured to use `SQLModel.metadata` from `chainlit.models`.
+
+---
+
+For more details, see the [Alembic documentation](https://alembic.sqlalchemy.org/en/latest/).
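+
+## Example: registering the data layer
+
+A minimal sketch of wiring the SQLModel data layer into a Chainlit app via the `@cl.data_layer` hook (the connection string below is a placeholder; any driver listed in `ALLOWED_ASYNC_DRIVERS` works):
+
+```python
+# app.py
+import chainlit as cl
+
+from chainlit.data.sql_data_layer import SQLDataLayer
+
+
+@cl.data_layer
+def data_layer():
+    # Any supported async driver works, e.g. sqlite+aiosqlite for local development.
+    return SQLDataLayer(conninfo="postgresql+asyncpg://user:password@localhost:5432/chainlit")
+```
+
+With `CHAINLIT_AUTO_MIGRATE=true` set, `chainlit run` applies pending migrations on startup; otherwise run `alembic upgrade head` yourself.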
diff --git a/backend/chainlit/data/alembic/env.py b/backend/chainlit/data/alembic/env.py
new file mode 100644
index 0000000000..0435faa2af
--- /dev/null
+++ b/backend/chainlit/data/alembic/env.py
@@ -0,0 +1,52 @@
+import sys
+import os
+from logging.config import fileConfig
+from sqlalchemy import engine_from_config, pool
+from alembic import context
+
+# Add the backend directory (three levels up) to sys.path so `chainlit.models` is importable
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))
+from chainlit.models import SQLModel
+
+# this is the Alembic Config object, which provides
+# access to the values within the .ini file in use.
+config = context.config
+
+# Interpret the config file for Python logging.
+# This line sets up loggers.
+fileConfig(config.config_file_name)
+
+target_metadata = SQLModel.metadata
+
+def run_migrations_offline():
+    """
+    Run migrations in 'offline' mode.
+    """
+    url = config.get_main_option("sqlalchemy.url")
+    context.configure(
+        url=url, target_metadata=target_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"}
+    )
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+def run_migrations_online():
+    """
+    Run migrations in 'online' mode.
+    """
+    connectable = engine_from_config(
+        config.get_section(config.config_ini_section),
+        prefix="sqlalchemy.",
+        poolclass=pool.NullPool,
+    )
+
+    with connectable.connect() as connection:
+        context.configure(connection=connection, target_metadata=target_metadata)
+
+        with context.begin_transaction():
+            context.run_migrations()
+
+if context.is_offline_mode():
+    run_migrations_offline()
+else:
+    run_migrations_online()
diff --git a/backend/chainlit/data/alembic/versions/0001_create_tables.py b/backend/chainlit/data/alembic/versions/0001_create_tables.py
new file mode 100644
index 0000000000..ca1067b11c
--- /dev/null
+++ b/backend/chainlit/data/alembic/versions/0001_create_tables.py
@@ -0,0 +1,50 @@
+"""
+Initial migration: rename camelCase columns to snake_case for the SQLModel data layer
+"""
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = '0001_create_tables'
+down_revision = None
+branch_labels = None
+depends_on = None
+
+def upgrade():
+    # Thread table: rename camelCase columns to snake_case
+    with op.batch_alter_table('threads') as batch_op:
+        batch_op.rename_column('createdAt', 'created_at')
+        batch_op.rename_column('userId', 'user_id')
+        batch_op.rename_column('userIdentifier', 'user_identifier')
+        # tags and metadata are already snake_case or compatible
+    # Repeat for the other tables (users, steps, elements, feedbacks)
+    with op.batch_alter_table('users') as batch_op:
+        batch_op.rename_column('createdAt', 'created_at')
+    with op.batch_alter_table('steps') as batch_op:
+        batch_op.rename_column('threadId', 'thread_id')
+        batch_op.rename_column('parentId', 'parent_id')
+        batch_op.rename_column('createdAt', 'created_at')
+    with op.batch_alter_table('elements') as batch_op:
+        batch_op.rename_column('threadId', 'thread_id')
+        batch_op.rename_column('objectKey', 'object_key')
+    with op.batch_alter_table('feedbacks') as batch_op:
+        batch_op.rename_column('forId', 'for_id')
+    # If the tables do not exist, Alembic will error; run this only once, when migrating an existing database.
+
+def downgrade():
+    # Reverse the renames for downgrade
+    with op.batch_alter_table('threads') as batch_op:
+        batch_op.rename_column('created_at', 'createdAt')
+        batch_op.rename_column('user_id', 'userId')
+        batch_op.rename_column('user_identifier', 'userIdentifier')
+    with op.batch_alter_table('users') as batch_op:
+        batch_op.rename_column('created_at', 'createdAt')
+    with op.batch_alter_table('steps') as batch_op:
+        batch_op.rename_column('thread_id', 'threadId')
+        batch_op.rename_column('parent_id', 'parentId')
+        batch_op.rename_column('created_at', 'createdAt')
+    with op.batch_alter_table('elements') as batch_op:
+        batch_op.rename_column('thread_id', 'threadId')
+        batch_op.rename_column('object_key', 'objectKey')
+    with op.batch_alter_table('feedbacks') as batch_op:
+        batch_op.rename_column('for_id', 'forId')
diff --git a/backend/chainlit/data/sql_data_layer.py b/backend/chainlit/data/sql_data_layer.py
new file mode 100644
index 0000000000..4e00c56f73
--- /dev/null
+++ b/backend/chainlit/data/sql_data_layer.py
@@ -0,0 +1,396 @@
+from sqlmodel import SQLModel, select
+from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine, async_sessionmaker
+from chainlit.data.base import BaseDataLayer
+from chainlit.data.storage_clients.base import BaseStorageClient
+from chainlit.data.utils import queue_until_user_message
+from typing import Optional, Any, Dict
+from pydantic import ValidationError
+from chainlit.models import PersistedUser, User, Feedback, Thread, Element, Step
+import json
+import ssl
+import uuid
+from chainlit.logger import logger
+from chainlit.types import (
+    PaginatedResponse,
+    Pagination,
+    ThreadFilter,
+    PageInfo
+)
+from sqlalchemy.engine import make_url
+
+ALLOWED_ASYNC_DRIVERS = {
+    "postgresql+asyncpg",
+    "postgresql+psycopg",  # psycopg3 async
+    "sqlite+aiosqlite",
+    "mysql+aiomysql",
+    "mysql+asyncmy",
+    "mariadb+aiomysql",
+    "mariadb+asyncmy",
+    "mssql+aioodbc",
+}
+
+class SQLDataLayer(BaseDataLayer):
+    def __init__(
+        self,
+        conninfo: str,
+        connect_args: Optional[dict[str, Any]] = None,
+        ssl_require: bool = False,
+        storage_provider: Optional[BaseStorageClient] = None,
+        user_thread_limit: Optional[int] = 1000,
+        show_logger: Optional[bool] = False,
+    ):
+        self._conninfo = conninfo
+        self.user_thread_limit = user_thread_limit
+        self.show_logger = bool(show_logger)
+
+        connect_args = dict(connect_args or {})
+
+        # Validate that the URL uses a supported async driver
+        url = make_url(self._conninfo)
+        driver = url.drivername  # e.g., "postgresql+asyncpg"
+        if driver not in ALLOWED_ASYNC_DRIVERS:
+            raise ValueError(f"Connection URL must use an async driver. Got '{driver}'. Use one of: {ALLOWED_ASYNC_DRIVERS}")
+
+        if ssl_require:
+            # Create an SSL context to require an SSL connection
+            ssl_context = ssl.create_default_context()
+            ssl_context.check_hostname = True
+            ssl_context.verify_mode = ssl.CERT_REQUIRED
+            connect_args.setdefault("ssl", ssl_context)
+        self.engine: AsyncEngine = create_async_engine(
+            self._conninfo,
+            connect_args=connect_args,
+            echo=self.show_logger,
+        )
+        self.async_session = async_sessionmaker(
+            bind=self.engine, expire_on_commit=False, class_=AsyncSession
+        )
+        if storage_provider:
+            self.storage_provider: Optional[BaseStorageClient] = storage_provider
+            if self.show_logger:
+                logger.info("SQLDataLayer storage client initialized")
+        else:
+            self.storage_provider = None
+            logger.warning("SQLDataLayer storage client is not initialized and elements will not be persisted!")
+
+    async def init_db(self):
+        """
+        Explicitly create tables for development or testing only.
+        In production, use Alembic migrations!
+        """
+        logger.warning("init_db should only be used for local development or tests. Use Alembic for production migrations.")
+        async with self.engine.begin() as conn:
+            # await conn.run_sync(SQLModel.metadata.drop_all)  # Uncomment to drop tables
+            await conn.run_sync(SQLModel.metadata.create_all)
+
+    async def aclose(self) -> None:
+        await self.engine.dispose()
+
+    async def create_user(self, user: User) -> Optional[PersistedUser]:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(PersistedUser).where(PersistedUser.identifier == user.identifier))
+            existing = result.scalar_one_or_none()
+            if existing:
+                return existing
+            db_user = PersistedUser(
+                identifier=user.identifier,
+                metadata_=user.metadata,
+            )
+            session.add(db_user)
+            return db_user
+
+    async def get_user(self, identifier: str) -> Optional[PersistedUser]:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(PersistedUser).where(PersistedUser.identifier == identifier))
+            user = result.scalar_one_or_none()
+            return user
+
+    async def update_user(self, identifier: str, metadata: Optional[dict] = None) -> Optional[PersistedUser]:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(PersistedUser).where(PersistedUser.identifier == identifier))
+            user = result.scalar_one_or_none()
+            if user:
+                if metadata is not None:
+                    user.metadata_ = metadata
+                await session.flush()
+                return user
+            return None
+
+    async def delete_user(self, identifier: str) -> bool:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(PersistedUser).where(PersistedUser.identifier == identifier))
+            user = result.scalar_one_or_none()
+            if user:
+                await session.delete(user)
+                return True
+            return False
+
+    async def create_thread(self, thread_data: dict) -> Optional[Dict]:
+        try:
+            thread = Thread.model_validate(thread_data)
+        except ValidationError as e:
+            logger.error(f"Thread data validation error: {e}")
+            return None
+        async with self.async_session.begin() as session:
+            session.add(thread)
+            await session.flush()
+            return thread.to_dict()
+
+    async def get_thread(self, thread_id: str) -> Optional[Dict]:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(Thread).where(Thread.id == thread_id))
+            thread = result.scalar_one_or_none()
+            if thread:
+                return thread.to_dict()
+            return None
+
+    async def get_thread_author(self, thread_id: str) -> str:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(Thread).where(Thread.id == thread_id))
+            thread: Optional[Thread] = result.scalar_one_or_none()
+            if thread and thread.user_identifier:
+                return thread.user_identifier
+            return ""
+
+    async def update_thread(self, thread_id: str, **kwargs) -> Optional[Dict]:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(Thread).where(Thread.id == thread_id))
+            thread = result.scalar_one_or_none()
+            if thread:
+                for k, v in kwargs.items():
+                    setattr(thread, k, v)
+                await session.flush()
+                return thread.to_dict()
+            return None
+
+    async def delete_thread(self, thread_id: str) -> bool:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(Thread).where(Thread.id == thread_id))
+            thread = result.scalar_one_or_none()
+            if thread:
+                await session.delete(thread)
+                return True
+            return False
+
+    @queue_until_user_message()
+    async def create_step(self, step_data: dict) -> Optional[Dict]:
+        try:
+            step = Step.model_validate(step_data)
+        except ValidationError as e:
+            logger.error(f"Step data validation error: {e}")
+            return None
+        async with self.async_session.begin() as session:
+            session.add(step)
+            await session.flush()
+            return step.to_dict()
+
+    async def get_step(self, step_id: str) -> Optional[Dict]:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(Step).where(Step.id == step_id))
+            step = result.scalar_one_or_none()
+            if step:
+                return step.to_dict()
+            return None
+
+    @queue_until_user_message()
+    async def update_step(self, step_id: str, **kwargs) -> Optional[Dict]:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(Step).where(Step.id == step_id))
+            step = result.scalar_one_or_none()
+            if step:
+                for k, v in kwargs.items():
+                    setattr(step, k, v)
+                await session.flush()
+                return step.to_dict()
+            return None
+
+    @queue_until_user_message()
+    async def delete_step(self, step_id: str) -> bool:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(Step).where(Step.id == step_id))
+            step = result.scalar_one_or_none()
+            if step:
+                await session.delete(step)
+                return True
+            return False
+
+    async def upsert_feedback(self, feedback: Feedback) -> str:
+        feedback_id = feedback.id or str(uuid.uuid4())
+        feedback_dict = feedback.model_dump()
+        feedback_dict["id"] = feedback_id
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(Feedback).where(Feedback.id == feedback_id))
+            db_feedback = result.scalar_one_or_none()
+            if db_feedback:
+                for k, v in feedback_dict.items():
+                    setattr(db_feedback, k, v)
+            else:
+                db_feedback = Feedback.model_validate(feedback_dict)
+                session.add(db_feedback)
+            await session.flush()
+            return db_feedback.id
+
+    async def get_feedback(self, feedback_id: str) -> Optional[Dict]:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(Feedback).where(Feedback.id == feedback_id))
+            feedback = result.scalar_one_or_none()
+            if feedback:
+                return feedback.to_dict()
+            return None
+
+    async def delete_feedback(self, feedback_id: str) -> bool:
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(Feedback).where(Feedback.id == feedback_id))
+            feedback = result.scalar_one_or_none()
+            if feedback:
+                await session.delete(feedback)
+                return True
+            return False
+
+    async def get_element(self, thread_id: str, element_id: str) -> Optional[Dict]:
+        async with self.async_session.begin() as session:
+            result = await session.execute(
+                select(Element).where(Element.thread_id == thread_id, Element.id == element_id)
+            )
+            element = result.scalar_one_or_none()
+            if element:
+                # props should be deserialized if stored as JSON string
+                props = element.props
+                if isinstance(props, str):
+                    props = json.loads(props)
+                return {
+                    **element.to_dict(),
+                    "props": props,
+                }
+            return None
+
+    async def _get_user_id_by_thread(self, thread_id: Optional[str]) -> Optional[str]:
+        # Resolve the owning user's id for a thread (used to build the storage object key)
+        if not thread_id:
+            return None
+        async with self.async_session.begin() as session:
+            result = await session.execute(select(Thread).where(Thread.id == thread_id))
+            thread = result.scalar_one_or_none()
+            return thread.user_id if thread else None
+
+    @queue_until_user_message()
+    async def create_element(self, element: "Element"):
+        if self.show_logger:
+            logger.info(f"SQLDataLayer: create_element, element_id = {element.id}")
+
+        if not self.storage_provider:
+            logger.warning("SQLDataLayer: create_element error. No blob_storage_client is configured!")
+            return
+        if not element.for_id:
+            return
+
+        content: Optional[bytes] = None
+        if element.path:
+            import aiofiles
+            async with aiofiles.open(element.path, "rb") as f:
+                content = await f.read()
+        elif element.url:
+            import aiohttp
+            async with aiohttp.ClientSession() as session_http:
+                async with session_http.get(element.url) as response:
+                    if response.status == 200:
+                        content = await response.read()
+                    else:
+                        content = None
+        elif element.content:
+            content = element.content
+        else:
+            raise ValueError("Element url, path or content must be provided")
+        if content is None:
+            raise ValueError("Content is None, cannot upload file")
+
+        user_id: str = await self._get_user_id_by_thread(element.thread_id) or "unknown"
+        file_object_key = f"{user_id}/{element.id}" + (f"/{element.name}" if element.name else "")
+
+        if not element.mime:
+            element.mime = "application/octet-stream"
+
+        uploaded_file = await self.storage_provider.upload_file(
+            object_key=file_object_key, data=content, mime=element.mime, overwrite=True
+        )
+        if not uploaded_file:
+            raise ValueError("SQLModel Error: create_element, Failed to persist data in storage_provider")
+
+        element_dict = element.to_dict()
+        element_dict["url"] = uploaded_file.get("url")
+        element_dict["objectKey"] = uploaded_file.get("object_key")
+        element_dict_cleaned = {k: v for k, v in element_dict.items() if v is not None}
+        if "props" in element_dict_cleaned:
+            element_dict_cleaned["props"] = json.dumps(element_dict_cleaned["props"])
+
+        async with self.async_session.begin() as session:
+            db_element = Element.model_validate(element_dict_cleaned)
+            session.add(db_element)
+            await session.flush()
+            return db_element.to_dict()
+
+    @queue_until_user_message()
+    async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
+        if self.show_logger:
+            logger.info(f"SQLDataLayer: delete_element, element_id={element_id}")
+
+        async with self.async_session.begin() as session:
+            query = select(Element).where(Element.id == element_id)
+            if thread_id:
+                query = query.where(Element.thread_id == thread_id)
+            result = await session.execute(query)
+            element = result.scalar_one_or_none()
+            element_dict = element.to_dict() if element else None
+            if (
+                self.storage_provider is not None
+                and element_dict is not None
+                and element_dict.get("objectKey")
+            ):
+                await self.storage_provider.delete_file(object_key=element_dict["objectKey"])
+            if element:
+                await session.delete(element)
+
+    async def build_debug_url(self) -> str:
+        # Implement as needed, or return empty string for now
+        return ""
+
+    async def list_threads(
+        self, pagination: Pagination, filters: ThreadFilter
+    ) -> PaginatedResponse[Dict]:
+        # Fetch threads for a user, apply pagination and filters
+        async with self.async_session.begin() as session:
+            query = select(Thread)
+            if filters.userId:
+                query = query.where(Thread.user_id == filters.userId)
+            result = await session.execute(query)
+            threads = result.scalars().all()
+            # Apply search filter
+            if filters.search:
+                threads = [t for t in threads if filters.search.lower() in (t.name or '').lower()]
+            # Apply feedback filter (if present)
+            if filters.feedback is not None:
+                # This requires joining with Feedback, so for now, skip or implement as needed
+                pass
+            # Pagination
+            start = 0
+            if pagination.cursor:
+                for i, t in enumerate(threads):
+                    if t.id == pagination.cursor:
+                        start = i + 1
+                        break
+            end = start + pagination.first
+            paginated_threads = threads[start:end]
+            has_next_page = len(threads) > end
+            start_cursor = paginated_threads[0].id if paginated_threads else None
+            end_cursor = paginated_threads[-1].id if paginated_threads else None
+            # Convert to dicts
+            data = [t.to_dict() for t in paginated_threads]
+            # Build PaginatedResponse
+            return PaginatedResponse(
+                pageInfo=PageInfo(
+                    hasNextPage=has_next_page,
+                    startCursor=start_cursor,
+                    endCursor=end_cursor,
+                ),
+                data=data,
+            )
\ No newline at end of file
diff --git a/backend/chainlit/models/__init__.py b/backend/chainlit/models/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/backend/chainlit/models/__init__.py
@@ -0,0 +1,8 @@
+from sqlmodel import SQLModel
+
+# Re-export models so `from chainlit.models import ...` resolves across the codebase
+from chainlit.models.element import Element, ElementBase
+from chainlit.models.feedback import Feedback
+from chainlit.models.thread import Thread, ThreadBase
+from chainlit.models.user import PersistedUser, User
+from chainlit.models.step import Step, StepBase
diff --git a/backend/chainlit/models/element.py b/backend/chainlit/models/element.py
new file mode 100644
index 0000000000..b992204bf5
--- /dev/null
+++ b/backend/chainlit/models/element.py
@@ -0,0 +1,394 @@
+from typing import Optional, Dict, List, Union, TypeVar, Any, Literal, get_args
+from sqlmodel import SQLModel, Field
+import uuid
+from pydantic import ConfigDict
+from pydantic.functional_validators import field_validator
+from pydantic import PrivateAttr
+from pydantic.alias_generators import to_camel
+import asyncio
+import filetype
+from chainlit.context import context
+from chainlit.data import get_data_layer
+from chainlit.logger import logger
+from chainlit.element import Task
+import json
+from sqlalchemy import Column, JSON, ForeignKey, String
+
+APPLICATION_JSON = "application/json"
+
+ElementType = Literal[
+    "image",
+    "text",
+    "pdf",
+    "tasklist",
+    "audio",
+    "video",
+    "file",
+    "plotly",
+    "dataframe",
+    "custom",
+]
+
+ElementDisplay = Literal["inline", "side", "page"]
+ElementSize = Literal["small", "medium", "large"]
+
+mime_types: Dict[str, str] = {
+    "text": "text/plain",
+    "tasklist": APPLICATION_JSON,
+    "plotly": APPLICATION_JSON,
+}
+class ElementBase(SQLModel):
+    # Runtime identity and linkage; the table model overrides these with FK columns
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+    thread_id: Optional[str] = None
+    for_id: Optional[str] = None
+    type: ElementType
+    name: str = ""
+    url: Optional[str] = None
+    path: Optional[str] = None
+    object_key: Optional[str] = None
+    chainlit_key: Optional[str] = None
+    display: ElementDisplay = "inline"
+    size: Optional[ElementSize] = None
+    language: Optional[str] = None
+    mime: Optional[str] = None
+    page: Optional[int] = None
+    props: Optional[dict] = None
+    auto_play: Optional[bool] = None
+    player_config: Optional[dict] = None
+    # runtime-only
+    _content: Optional[Union[str, bytes]] = PrivateAttr(default=None)
+    _persisted: bool = PrivateAttr(default=False)
+    _updatable: bool = PrivateAttr(default=False)
+    _bg_task: Any = PrivateAttr(default=None)
+
+    model_config = ConfigDict(
+        alias_generator=to_camel,
+        populate_by_name=True,
+    )
+
+    @property
+    def content(self) -> Optional[Union[str, bytes]]:
+        return self._content
+
+    @content.setter
+    def content(self, value: Optional[Union[str, bytes]]):
+        self._content = value
+
+    def to_dict(self):
+        return 
self.model_dump(by_alias=True) + + @classmethod + def from_dict(cls, **kwargs): + # Default to file if missing + t = kwargs.get("type", "file") + if t not in TYPE_MAP: + t = "file" + kwargs["type"] = t + model = TYPE_MAP.get(t, File) + return model.model_validate(kwargs) + + @classmethod + def infer_type_from_mime(cls, mime_type: str): + """Infer the element type from a mime type. Useful to know which element to instantiate from a file upload.""" + if "image" in mime_type: + return "image" + + elif mime_type == "application/pdf": + return "pdf" + + elif "audio" in mime_type: + return "audio" + + elif "video" in mime_type: + return "video" + + else: + return "file" + + def _resolve_mime(self) -> None: + # Resolve MIME if needed + if self.mime: + return + key = self.type + if isinstance(key, str) and key in mime_types: + self.mime = mime_types[key] + elif self.path or isinstance(self.content, (bytes, bytearray)): + file_type = filetype.guess(self.path or self.content) + if file_type: + self.mime = file_type.mime + elif self.url: + import mimetypes + self.mime = mimetypes.guess_type(self.url)[0] + + async def _persist_file_if_needed(self) -> None: + # Persist file if needed + if self.url: + return + if not self.chainlit_key or getattr(self, "updatable", False) or self._updatable: + file_dict = await context.session.persist_file( + name=self.name, + path=self.path, + content=self.content, + mime=self.mime or "", + ) + self.chainlit_key = file_dict["id"] + + async def _create(self, persist: bool = True, for_id: Optional[str] = None) -> None: + if self._persisted and not (getattr(self, "updatable", False) or self._updatable): + return None + + self._resolve_mime() + await self._persist_file_if_needed() + + data_layer = get_data_layer() + if data_layer and persist: + try: + # Map to DB element and persist + db_elem = Element.from_base(self, for_id=for_id) + self._bg_task = asyncio.create_task(data_layer.create_element(db_elem)) + except Exception as e: + logger.error(f"Failed to create element: {e!s}") + + self._persisted = True + return None + + async def remove(self): + data_layer = get_data_layer() + if data_layer: + await data_layer.delete_element(self.id, self.thread_id) + await context.emitter.emit("remove_element", {"id": self.id}) + + async def send(self, for_id: str, persist: bool = True): + await self._create(persist=persist, for_id=for_id) + + if not self.url and not self.chainlit_key: + raise ValueError("Must provide url or chainlit key to send element") + + await context.emitter.send_element(self.to_dict()) + +ElementBased = TypeVar("ElementBased", bound=ElementBase) + +# Subclasses for runtime logic (not DB tables) +class Image(ElementBase): + type: Literal["image"] = "image" + size: Optional[ElementSize] = "medium" + +class Text(ElementBase): + type: Literal["text"] = "text" + language: Optional[str] = None + +class Pdf(ElementBase): + type: Literal["pdf"] = "pdf" + mime: str = "application/pdf" + page: Optional[int] = None + +class Pyplot(ElementBase): + """Useful to send a pyplot to the UI.""" + type: Literal["image"] = "image" + size: Optional[ElementSize] = "medium" + figure: Any = Field(default=None, exclude=True) + + def model_post_init(self, __context) -> None: + if hasattr(self, "figure") and self.figure is not None: + from matplotlib.figure import Figure + from io import BytesIO + if not isinstance(self.figure, Figure): + raise TypeError("figure must be a matplotlib.figure.Figure") + image = BytesIO() + self.figure.savefig( + image, dpi=200, bbox_inches="tight", 
backend="Agg", format="png" + ) + self.content = image.getvalue() + super().model_post_init(__context) + +class TaskList(ElementBase): + type: Literal["tasklist"] = "tasklist" + tasks: List[Task] = Field(default_factory=list, exclude=True) + status: str = "Ready" + name: str = "tasklist" + + def model_post_init(self, __context) -> None: + super().model_post_init(__context) + self._updatable = True + setattr(self, "updatable", True) + + async def add_task(self, task: Task): + self.tasks.append(task) + + async def update(self): + await self.send(for_id=self.for_id or "") + + async def send(self, for_id: str, persist: bool = True): + await self.preprocess_content() + await super().send(for_id=for_id, persist=persist) + + async def preprocess_content(self): + # serialize enum + tasks = [ + {"title": task.title, "status": task.status.value, "forId": task.forId} + for task in self.tasks + ] + # store stringified json in content so that it's correctly stored in the database + self.content = json.dumps( + { + "status": self.status, + "tasks": tasks, + }, + indent=4, + ensure_ascii=False, + ) + +class Audio(ElementBase): + type: Literal["audio"] = "audio" + auto_play: bool = False + +class Video(ElementBase): + type: Literal["video"] = "video" + size: Optional[ElementSize] = "medium" + +class File(ElementBase): + type: Literal["file"] = "file" + +class Plotly(ElementBase): + type: Literal["plotly"] = "plotly" + size: Optional[ElementSize] = "medium" + figure: Any = Field(default=None, exclude=True) + + def model_post_init(self, __context) -> None: + if hasattr(self, "figure") and self.figure is not None: + from plotly import graph_objects as go, io as pio + if not isinstance(self.figure, go.Figure): + raise TypeError("figure must be a plotly.graph_objects.Figure") + self.figure.layout.autosize = True + self.figure.layout.width = None + self.figure.layout.height = None + self.content = pio.to_json(self.figure, validate=True) + self.mime = APPLICATION_JSON + super().model_post_init(__context) + +class Dataframe(ElementBase): + type: Literal["dataframe"] = "dataframe" + size: Optional[ElementSize] = "large" + data: Any = Field(default=None, exclude=True) + + def model_post_init(self, __context) -> None: + if hasattr(self, "data") and self.data is not None: + from pandas import DataFrame + if not isinstance(self.data, DataFrame): + raise TypeError("data must be a pandas.DataFrame") + self.content = self.data.to_json(orient="split", date_format="iso") + super().model_post_init(__context) + +class CustomElement(ElementBase): + """Useful to send a custom element to the UI.""" + type: Literal["custom"] = "custom" + mime: str = APPLICATION_JSON + + def model_post_init(self, __context) -> None: + self.content = json.dumps(self.props) + super().model_post_init(__context) + self._updatable = True + setattr(self, "updatable", True) + + async def update(self): + await super().send(for_id="") + +# DB model with table=True +class Element(ElementBase, table=True): + __tablename__ = "elements" + + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + thread_id: Optional[str] = Field( + default=None, + sa_column=Column(String, ForeignKey("threads.id", ondelete="CASCADE"), nullable=True), + ) + for_id: Optional[str] = Field( + default=None, + sa_column=Column(String, ForeignKey("steps.id", ondelete="CASCADE"), nullable=True), + ) + # Override Literal fields with DB-mappable types + type: str = Field(..., nullable=False) + display: str = Field(..., nullable=False) + size: Optional[str] = None + 
props: Optional[dict] = Field(default_factory=dict, sa_type=JSON) + player_config: Optional[dict] = Field(default_factory=dict, sa_type=JSON) + + # Strict validation of DB fields using runtime Literal definitions + @field_validator("type", mode="before") + @classmethod + def _validate_type(cls, v: Any) -> str: + if v is None: + raise ValueError("type is required") + v_str = str(v) + if v_str not in get_args(ElementType): + raise ValueError(f"Invalid type: {v_str}") + return v_str + + @field_validator("display", mode="before") + @classmethod + def _validate_display(cls, v: Any) -> str: + if v is None: + raise ValueError("display is required") + v_str = str(v) + if v_str not in get_args(ElementDisplay): + raise ValueError(f"Invalid display: {v_str}") + return v_str + + @field_validator("size", mode="before") + @classmethod + def _validate_size(cls, v: Any) -> Optional[str]: + if v is None or v == "None": + return None + v_str = str(v) + if v_str not in get_args(ElementSize): + raise ValueError(f"Invalid size: {v_str}") + return v_str + + @classmethod + def from_base(cls, base: ElementBase, for_id: Optional[str] = None) -> "Element": + return cls( + type=str(base.type), + name=base.name, + url=base.url, + path=base.path, + object_key=base.object_key, + chainlit_key=base.chainlit_key, + display=str(base.display), + size=str(base.size) if base.size is not None else None, + language=base.language, + mime=base.mime, + page=base.page, + props=base.props or {}, + auto_play=base.auto_play, + player_config=base.player_config or {}, + for_id=for_id, + ) + + # Validators to enforce allowed values on the DB model + @classmethod + def _allowed(cls, lit) -> List[str]: + return list(get_args(lit)) + + @classmethod + def _validate_choice(cls, value: Optional[str], lit) -> Optional[str]: + if value is None: + return value + allowed = cls._allowed(lit) + if value not in allowed: + raise ValueError(f"Invalid value: {value}. Must be one of: {allowed}") + return value + +# Simple mapping for type discrimination (Pyplot shares "image", so not included) +TYPE_MAP: Dict[str, Any] = { + "image": Image, + "text": Text, + "pdf": Pdf, + "tasklist": TaskList, + "audio": Audio, + "video": Video, + "file": File, + "plotly": Plotly, + "dataframe": Dataframe, + "custom": CustomElement, +} \ No newline at end of file diff --git a/backend/chainlit/models/feedback.py b/backend/chainlit/models/feedback.py new file mode 100644 index 0000000000..229ed8d120 --- /dev/null +++ b/backend/chainlit/models/feedback.py @@ -0,0 +1,52 @@ +from typing import Optional, Literal +from sqlmodel import SQLModel, Field +from pydantic import BaseModel, field_validator, ConfigDict +from pydantic.alias_generators import to_camel +import uuid +from sqlalchemy import Column, ForeignKey, String + +FeedbackStrategy = Literal["BINARY"] + + +class FeedbackBase(SQLModel): + for_id: str + value: int + thread_id: Optional[str] = None + comment: Optional[str] = None + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + + @field_validator("value", mode="before") + @classmethod + def validate_value(cls, v): + allowed = [0, 1] + if v not in allowed: + raise ValueError(f"Invalid value: {v}. 
Must be one of: {allowed}")
+        return v
+
+    def to_dict(self):
+        data = self.model_dump(by_alias=True)
+        return data
+
+
+class Feedback(FeedbackBase, table=True):
+    __tablename__ = "feedbacks"
+
+    id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True)
+    for_id: str = Field(
+        sa_column=Column(String, ForeignKey("steps.id", ondelete="CASCADE"))
+    )
+    value: int = Field(..., ge=0, le=1)
+    comment: Optional[str] = None
+
+
+class UpdateFeedbackRequest(BaseModel):
+    feedback: Feedback
+    session_id: str
+
+
+class DeleteFeedbackRequest(BaseModel):
+    feedbackId: str
diff --git a/backend/chainlit/models/step.py b/backend/chainlit/models/step.py
new file mode 100644
index 0000000000..5b24e335d5
--- /dev/null
+++ b/backend/chainlit/models/step.py
@@ -0,0 +1,494 @@
+import asyncio
+import inspect
+import json
+import uuid
+from copy import deepcopy
+from functools import wraps
+from typing import Callable, Dict, List, Optional, Union, Literal, Any, get_args
+
+from sqlmodel import SQLModel, Field
+from sqlalchemy import Column, JSON, ForeignKey, String
+from pydantic import PrivateAttr
+from pydantic import field_validator
+from pydantic import ConfigDict
+from pydantic.alias_generators import to_camel
+
+from chainlit.config import config
+from chainlit.context import CL_RUN_NAMES, context, local_steps
+from chainlit.data import get_data_layer
+from chainlit.logger import logger
+from chainlit.utils import utc_now
+
+# Import the Element runtime class via models init to avoid circular import
+try:
+    from chainlit.models import Element  # type: ignore
+except Exception:  # pragma: no cover - optional during partial refactors
+    Element = Any  # fallback for type hints
+
+TrueStepType = Literal[
+    "run", "tool", "llm", "embedding", "retrieval", "rerank", "undefined"
+]
+
+MessageStepType = Literal["user_message", "assistant_message", "system_message"]
+
+StepType = Union[TrueStepType, MessageStepType]
+
+
+class StepBase(SQLModel):
+    """Runtime Step model. DB fields overridden in Step(table=True)."""
+
+    # Core fields (runtime view). The DB model will override types as str with validators.
+ name: str = Field(default="") + type: StepType = Field(default="undefined") + + # Optional linkage; DB model defines FKs + thread_id: Optional[str] = None + parent_id: Optional[str] = None + + # Rendering/behavior + disable_feedback: bool = Field(default=False) + streaming: bool = Field(default=False) + wait_for_answer: Optional[bool] = None + is_error: Optional[bool] = None + + # Payload and metadata + input: Optional[str] = None + output: Optional[str] = None + created_at: Optional[str] = None + start: Optional[str] = None + end: Optional[str] = None + generation: Optional[dict] = None + show_input: Union[bool, str] = Field(default="json") + language: Optional[str] = None + indent: Optional[int] = None + tags: Optional[List[str]] = None + default_open: Optional[bool] = Field(default=False) + metadata_: Optional[dict] = Field( + default_factory=dict, + alias="metadata", + sa_column=Column("metadata", JSON), + schema_extra={"serialization_alias": "metadata"}, + ) + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + + # Private attributes for business logic (not persisted) + _elements: List[Any] = PrivateAttr(default_factory=list) + _fail_on_persist_error: bool = PrivateAttr(default=False) + _input: str = PrivateAttr(default="") + _output: str = PrivateAttr(default="") + _persisted: bool = PrivateAttr(default=False) + + # Convenience properties + @property + def persisted(self) -> bool: + return self._persisted + + @persisted.setter + def persisted(self, v: bool) -> None: + self._persisted = bool(v) + + @property + def elements(self) -> List[Any]: + return self._elements + + @property + def fail_on_persist_error(self) -> bool: + return self._fail_on_persist_error + + @fail_on_persist_error.setter + def fail_on_persist_error(self, v: bool) -> None: + self._fail_on_persist_error = bool(v) + + @field_validator("type", mode="before") + @classmethod + def _validate_type(cls, v: Any) -> Any: + # Accept literals on base; DB class enforces strict string values + allowed = [ + value + for arg in get_args(StepType) + for value in (get_args(arg) if hasattr(arg, "__args__") else [arg]) + ] + if v not in allowed: + raise ValueError(f"Invalid type: {v}. 
Must be one of: {allowed}") + return v + + @property + def input_value(self): + return self._input + + @input_value.setter + def input_value(self, content: Union[Dict, str]): + self._input = self._process_content(content, set_language=False) + self.input = self._input + + @property + def output_value(self): + return self._output + + @output_value.setter + def output_value(self, content: Union[Dict, str]): + self._output = self._process_content(content, set_language=True) + self.output = self._output + + def _clean_content(self, content): + def handle_bytes(item): + if isinstance(item, bytes): + return "STRIPPED_BINARY_DATA" + elif isinstance(item, dict): + return {k: handle_bytes(v) for k, v in item.items()} + elif isinstance(item, list): + return [handle_bytes(i) for i in item] + elif isinstance(item, tuple): + return tuple(handle_bytes(i) for i in item) + return item + + return handle_bytes(content) + + def _process_content(self, content, set_language=False): + if content is None: + return "" + content = self._clean_content(content) + if isinstance(content, (dict, list, tuple)): + try: + processed_content = json.dumps(content, indent=4, ensure_ascii=False) + if set_language: + self.language = "json" + except TypeError: + processed_content = str(content).replace("\\n", "\n") + if set_language: + self.language = "text" + elif isinstance(content, str): + processed_content = content + else: + processed_content = str(content).replace("\\n", "\n") + if set_language: + self.language = "text" + return processed_content + + def to_dict(self): + return self.model_dump(by_alias=True) + + # Context manager support + async def __aenter__(self): + self.start = utc_now() + previous_steps = local_steps.get() or [] + parent_step = previous_steps[-1] if previous_steps else None + + if not self.parent_id and parent_step: + self.parent_id = parent_step.id + local_steps.set(previous_steps + [self]) + await self.send() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + self.end = utc_now() + + if exc_type: + self.output_value = str(exc_val) + self.is_error = True + + current_steps = local_steps.get() + if current_steps and self in current_steps: + current_steps.remove(self) + local_steps.set(current_steps) + + await self.update() + + def __enter__(self): + self.start = utc_now() + + previous_steps = local_steps.get() or [] + parent_step = previous_steps[-1] if previous_steps else None + + if not self.parent_id and parent_step: + self.parent_id = parent_step.id + local_steps.set(previous_steps + [self]) + + asyncio.create_task(self.send()) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end = utc_now() + + if exc_type: + self.output_value = str(exc_val) + self.is_error = True + + current_steps = local_steps.get() + if current_steps and self in current_steps: + current_steps.remove(self) + local_steps.set(current_steps) + + asyncio.create_task(self.update()) + + async def update(self): + if self.streaming: + self.streaming = False + + step_dict = self.to_dict() + data_layer = get_data_layer() + + if data_layer: + try: + asyncio.create_task(data_layer.update_step(step_dict.copy())) + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist step update: {e!s}") + + tasks = [el.send(for_id=self.id) for el in getattr(self, 'elements', [])] + await asyncio.gather(*tasks) + + from chainlit.context import check_add_step_in_cot, stub_step + if not check_add_step_in_cot(self): + await context.emitter.update_step(stub_step(self)) + 
else: + await context.emitter.update_step(step_dict) + + return True + + async def remove(self): + step_dict = self.to_dict() + data_layer = get_data_layer() + + if data_layer: + try: + asyncio.create_task(data_layer.delete_step(self.id)) + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist step deletion: {e!s}") + + await context.emitter.delete_step(step_dict) + return True + + async def send(self): + if self.persisted: + return self + + if getattr(config.code, "author_rename", None): + self.name = await config.code.author_rename(self.name) + + if self.streaming: + self.streaming = False + + step_dict = self.to_dict() + data_layer = get_data_layer() + + if data_layer: + try: + asyncio.create_task(data_layer.create_step(step_dict.copy())) + self.persisted = True + except Exception as e: + if self.fail_on_persist_error: + raise e + logger.error(f"Failed to persist step creation: {e!s}") + + tasks = [el.send(for_id=self.id) for el in getattr(self, 'elements', [])] + await asyncio.gather(*tasks) + + from chainlit.context import check_add_step_in_cot + if not check_add_step_in_cot(self): + await context.emitter.send_step(self.to_dict()) + else: + await context.emitter.send_step(step_dict) + + return self + + async def stream_token(self, token: str, is_sequence=False, is_input=False): + if not token: + return + + from chainlit.context import check_add_step_in_cot, stub_step + + if is_sequence: + if is_input: + self.input_value = token + else: + self.output_value = token + else: + if is_input: + self.input_value += token + else: + self.output_value += token + + assert self.id + + if not check_add_step_in_cot(self): + await context.emitter.send_step(stub_step(self)) + return + + if not self.streaming: + self.streaming = True + step_dict = self.to_dict() + await context.emitter.stream_start(step_dict) + else: + await context.emitter.send_token( + id=self.id, token=token, is_sequence=is_sequence, is_input=is_input + ) + + +class Step(StepBase, table=True): + __tablename__ = "steps" + + # DB identity and relations + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + thread_id: Optional[str] = Field( + default=None, + sa_column=Column(String, ForeignKey("threads.id", ondelete="CASCADE"), nullable=True), + ) + parent_id: Optional[str] = Field( + default=None, + sa_column=Column(String, ForeignKey("steps.id", ondelete="CASCADE"), nullable=True), + ) + + # Override Literal and complex fields with DB-compatible types/columns + type: str = Field(..., nullable=False) + tags: Optional[List[str]] = Field(default_factory=list, sa_column=Column(JSON)) + metadata_: Optional[dict] = Field( + default_factory=dict, + sa_column=Column("metadata", JSON), + alias="metadata", + schema_extra={"serialization_alias": "metadata"}, + ) + generation: Optional[dict] = Field( + default_factory=dict, + sa_column=Column("generation", JSON), + alias="generation", + ) + show_input: str + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + + @field_validator("type", mode="before") + @classmethod + def _validate_type_db(cls, v: Any) -> str: + if v is None: + raise ValueError("type is required") + v_str = str(v) + allowed = [ + value + for arg in get_args(StepType) + for value in (get_args(arg) if hasattr(arg, "__args__") else [arg]) + ] + if v_str not in allowed: + raise ValueError(f"Invalid type: {v}. 
Must be one of: {allowed}") + return v_str + + @classmethod + def from_base(cls, base: "StepBase") -> "Step": + data = base.model_dump(by_alias=True) + # Map runtime metadata -> metadata_ + if "metadata" in data and data.get("metadata") is not None: + data["metadata_"] = data.pop("metadata") + return cls.model_validate(data) + + +def flatten_args_kwargs(func, args, kwargs): + signature = inspect.signature(func) + bound_arguments = signature.bind(*args, **kwargs) + bound_arguments.apply_defaults() + return {k: deepcopy(v) for k, v in bound_arguments.arguments.items()} + + +def check_add_step_in_cot(step: StepBase): + is_message = step.type in [ + "user_message", + "assistant_message", + ] + is_cl_run = step.name in CL_RUN_NAMES and step.type == "run" + if config.ui.cot == "hidden" and not is_message and not is_cl_run: + return False + return True + + +def step( + original_function: Optional[Callable] = None, + *, + name: Optional[str] = "", + type: Optional[str] = "undefined", + id: Optional[str] = None, + parent_id: Optional[str] = None, + tags: Optional[List[str]] = None, + metadata: Optional[Dict] = None, + language: Optional[str] = None, + show_input: Union[bool, str] = "json", + default_open: bool = False, +) -> Callable: + """Decorator to wrap functions in a Step context.""" + + def wrapper(func: Callable): + nonlocal name + if not name: + name = func.__name__ + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + async with Step( + type=type, + name=name, + id=id, + parent_id=parent_id, + tags=tags, + language=language, + show_input=show_input, + default_open=default_open, + metadata=metadata, + ) as step_obj: + try: + step_obj.input = flatten_args_kwargs(func, args, kwargs) + except Exception: + pass + result = await func(*args, **kwargs) + try: + if result and not step_obj.output: + step_obj.output = result + except Exception: + step_obj.is_error = True + step_obj.output = str(result) + return result + + return async_wrapper + else: + + @wraps(func) + def sync_wrapper(*args, **kwargs): + with Step( + type=type, + name=name, + id=id, + parent_id=parent_id, + tags=tags, + language=language, + show_input=show_input, + default_open=default_open, + metadata=metadata, + ) as step_obj: + try: + step_obj.input = flatten_args_kwargs(func, args, kwargs) + except Exception: + pass + result = func(*args, **kwargs) + try: + if result and not step_obj.output: + step_obj.output = result + except Exception: + step_obj.is_error = True + step_obj.output = str(result) + return result + + return sync_wrapper + + func = original_function + if not func: + return wrapper + else: + return wrapper(func) diff --git a/backend/chainlit/models/thread.py b/backend/chainlit/models/thread.py new file mode 100644 index 0000000000..860be718c6 --- /dev/null +++ b/backend/chainlit/models/thread.py @@ -0,0 +1,110 @@ + +from typing import Dict, Generic, List, Optional, TypeVar, Self +from sqlmodel import SQLModel, Field +from pydantic import PrivateAttr, BaseModel +import uuid +from pydantic import ConfigDict +from pydantic.alias_generators import to_camel +from sqlalchemy import Column, JSON, ForeignKey, String + + +class ThreadBase(SQLModel): + created_at: Optional[str] = None + name: Optional[str] = None + user_id: Optional[str] = None + user_identifier: Optional[str] = None + tags: Optional[List[str]] = None + # Persisted as JSON column named "metadata", but exposed as `metadata` in the API + metadata_: Optional[dict] = Field( + default_factory=dict, + 
alias="metadata", + sa_column=Column("metadata", JSON), + schema_extra={"serialization_alias": "metadata"}, + ) + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + + # Private runtime attributes + _steps: Optional[List] = None + _elements: Optional[List] = None + _runtime_state: dict = PrivateAttr(default_factory=dict) + + def add_tag(self, tag: str): + if self.tags is None: + self.tags = [] + if tag not in self.tags: + self.tags.append(tag) + + def to_dict(self): + return self.model_dump(by_alias=True) + + @classmethod + def from_dict(cls, **kwargs) -> Self: + return cls.model_validate(**kwargs) + + +class Thread(ThreadBase, table=True): + __tablename__ = "threads" + + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + user_id: Optional[str] = Field( + default=None, + sa_column=Column(String, ForeignKey("users.id", ondelete="CASCADE"), nullable=True), + ) + tags: Optional[List[str]] = Field(default_factory=list, sa_column=Column(JSON)) + + +# Pagination and ThreadFilter +class Pagination(BaseModel): + first: int + cursor: Optional[str] = None + +class ThreadFilter(BaseModel): + feedback: Optional[int] = None + user_id: Optional[str] = None + search: Optional[str] = None + + +class PageInfo(BaseModel): + hasNextPage: bool + startCursor: Optional[str] + endCursor: Optional[str] + + def to_dict(self): + return self.model_dump() + + @classmethod + def from_dict(cls, page_info_dict: Dict) -> Self: + return cls(**page_info_dict) + +T = TypeVar("T", covariant=True) +class PaginatedResponse(BaseModel, Generic[T]): + page_info: PageInfo + data: List[T] + + def to_dict(self): + return self.model_dump() + + @classmethod + def from_dict( + cls, paginated_response_dict: Dict + ) -> "PaginatedResponse[T]": + page_info = PageInfo.from_dict(paginated_response_dict.get("page_info", {})) + # Without runtime type info for T, return data as-is + data_list = paginated_response_dict.get("data", []) + return cls(page_info=page_info, data=data_list) + +# Thread requests/responses +class UpdateThreadRequest(BaseModel): + thread_id: str + name: str + +class DeleteThreadRequest(BaseModel): + thread_id: str + +class GetThreadsRequest(BaseModel): + pagination: Pagination + filter: ThreadFilter diff --git a/backend/chainlit/models/user.py b/backend/chainlit/models/user.py new file mode 100644 index 0000000000..03549e179e --- /dev/null +++ b/backend/chainlit/models/user.py @@ -0,0 +1,46 @@ +from typing import Dict, Optional, Literal +from sqlmodel import SQLModel, Field +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic.alias_generators import to_camel +from sqlalchemy import Column, JSON +import uuid +from chainlit.utils import utc_now + +Provider = Literal[ + "credentials", + "header", + "github", + "google", + "azure-ad", + "azure-ad-hybrid", + "okta", + "auth0", + "descope", +] + +# Non-persisted user (for runtime/session use) +class User(BaseModel): + identifier: str + display_name: Optional[str] = None + metadata: Dict = Field(default_factory=dict) + + +class PersistedUser(SQLModel, table=True): + __tablename__ = "users" + + id: str = Field(default_factory=lambda: str(uuid.uuid4()), primary_key=True) + identifier: str + display_name: Optional[str] = None + metadata_: Optional[dict] = Field( + default_factory=dict, + sa_column=Column("metadata", JSON), + alias="metadata", + schema_extra={"serialization_alias": "metadata"}, + ) + created_at: str = Field(default_factory=utc_now) + + model_config = ConfigDict( + 
alias_generator=to_camel, + populate_by_name=True, + ) \ No newline at end of file diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 08a8fb9093..38f85abd70 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -55,7 +55,9 @@ dependencies = [ "python-multipart>=0.0.18,<1.0.0", "pyjwt>=2.8.0,<3.0.0", "audioop-lts>=0.2.1,<0.3.0; python_version>='3.13'", - "pydantic-settings>=2.10.1" + "pydantic-settings>=2.10.1", + "sqlmodel>=0.0.24", + "alembic>=1.16.5" ] [project.urls] diff --git a/backend/uv.lock b/backend/uv.lock index 2983388a36..9beeab9cf7 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -170,6 +170,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f5/10/6c25ed6de94c49f88a91fa5018cb4c0f3625f31d5be9f771ebe5cc7cd506/aiosqlite-0.21.0-py3-none-any.whl", hash = "sha256:2549cf4057f95f53dcba16f2b64e8e2791d7e1adedb13197dd8ed77bb226d7d0", size = 15792, upload-time = "2025-02-03T07:30:13.6Z" }, ] +[[package]] +name = "alembic" +version = "1.16.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9a/ca/4dc52902cf3491892d464f5265a81e9dff094692c8a049a3ed6a05fe7ee8/alembic-1.16.5.tar.gz", hash = "sha256:a88bb7f6e513bd4301ecf4c7f2206fe93f9913f9b48dac3b78babde2d6fe765e", size = 1969868, upload-time = "2025-08-27T18:02:05.668Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/39/4a/4c61d4c84cfd9befb6fa08a702535b27b21fff08c946bc2f6139decbf7f7/alembic-1.16.5-py3-none-any.whl", hash = "sha256:e845dfe090c5ffa7b92593ae6687c5cb1a101e91fa53868497dbd79847f9dbe3", size = 247355, upload-time = "2025-08-27T18:02:07.37Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -679,10 +694,11 @@ wheels = [ [[package]] name = "chainlit" -version = "2.7.1.1" +version = "2.7.2" source = { editable = "." 
} dependencies = [ { name = "aiofiles" }, + { name = "alembic" }, { name = "asyncer" }, { name = "audioop-lts", marker = "python_full_version >= '3.13'" }, { name = "click" }, @@ -701,6 +717,7 @@ dependencies = [ { name = "python-dotenv" }, { name = "python-multipart" }, { name = "python-socketio" }, + { name = "sqlmodel" }, { name = "starlette" }, { name = "syncer" }, { name = "tomli" }, @@ -752,6 +769,7 @@ tests = [ requires-dist = [ { name = "aiofiles", specifier = ">=23.1.0,<25.0.0" }, { name = "aiosqlite", marker = "extra == 'tests'", specifier = ">=0.20.0,<1.0.0" }, + { name = "alembic", specifier = ">=1.16.5" }, { name = "asyncer", specifier = ">=0.0.8,<0.1.0" }, { name = "asyncpg", marker = "extra == 'custom-data'", specifier = ">=0.30.0,<1.0.0" }, { name = "audioop-lts", marker = "python_full_version >= '3.13'", specifier = ">=0.2.1,<0.3.0" }, @@ -795,6 +813,7 @@ requires-dist = [ { name = "semantic-kernel", marker = "extra == 'tests'", specifier = ">=1.24.0,<2.0.0" }, { name = "slack-bolt", marker = "extra == 'tests'", specifier = ">=1.18.1,<2.0.0" }, { name = "sqlalchemy", marker = "extra == 'custom-data'", specifier = ">=2.0.28,<3.0.0" }, + { name = "sqlmodel", specifier = ">=0.0.24" }, { name = "starlette", specifier = ">=0.47.2" }, { name = "syncer", specifier = ">=2.0.3,<3.0.0" }, { name = "tenacity", marker = "extra == 'tests'", specifier = ">=8.4.1,<9.0.0" }, @@ -2499,6 +2518,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/05/50/c5ccd2a50daa0a10c7f3f7d4e6992392454198cd8a7d99fcb96cb60d0686/llama_parse-0.6.54-py3-none-any.whl", hash = "sha256:c66c8d51cf6f29a44eaa8595a595de5d2598afc86e5a33a4cebe5fe228036920", size = 4879, upload-time = "2025-08-01T20:09:22.651Z" }, ] +[[package]] +name = "mako" +version = "1.3.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, +] + [[package]] name = "markupsafe" version = "3.0.2" @@ -5314,6 +5345,19 @@ asyncio = [ { name = "greenlet" }, ] +[[package]] +name = "sqlmodel" +version = "0.0.24" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "sqlalchemy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/4b/c2ad0496f5bdc6073d9b4cef52be9c04f2b37a5773441cc6600b1857648b/sqlmodel-0.0.24.tar.gz", hash = "sha256:cc5c7613c1a5533c9c7867e1aab2fd489a76c9e8a061984da11b4e613c182423", size = 116780, upload-time = "2025-03-07T05:43:32.887Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/16/91/484cd2d05569892b7fef7f5ceab3bc89fb0f8a8c0cde1030d383dbc5449c/sqlmodel-0.0.24-py3-none-any.whl", hash = "sha256:6778852f09370908985b667d6a3ab92910d0d5ec88adcaf23dbc242715ff7193", size = 28622, upload-time = "2025-03-07T05:43:30.37Z" }, +] + [[package]] name = "sse-starlette" version = "3.0.2"