diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 30b81af60..efc907cc9 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -1,8 +1,10 @@ import os +from pydantic_settings import BaseSettings, SettingsConfigDict + from astrbot.core.config import AstrBotConfig from astrbot.core.config.default import DB_PATH -from astrbot.core.db.sqlite import SQLiteDatabase +from astrbot.core.db.sqlite import BaseDatabase from astrbot.core.file_token_service import FileTokenService from astrbot.core.utils.pip_installer import PipInstaller from astrbot.core.utils.shared_preferences import SharedPreferences @@ -14,13 +16,44 @@ # 初始化数据存储文件夹 os.makedirs(get_astrbot_data_path(), exist_ok=True) + +class AstrBotMySQLSettings(BaseSettings): + host: str = "localhost" + port: int = 3306 + user: str = "root" + password: str = "" + database: str = "astrbot" + charset: str = "utf8mb4" + + model_config = SettingsConfigDict(env_file=".env", env_prefix="ASTR_MYSQL_") + + +def get_db_helper() -> BaseDatabase: + db_type = os.getenv("ASTR_DB_TYPE", "sqlite") + match db_type: + case "sqlite": + from astrbot.core.db.sqlite import SQLiteDatabase + + return SQLiteDatabase(DB_PATH) + case "mysql": + from astrbot.core.db.mysql import MySQLDatabase + + mysql_settings = AstrBotMySQLSettings() + + return MySQLDatabase(**mysql_settings.model_dump()) + case _: + from astrbot.core.db.sqlite import SQLiteDatabase + + return SQLiteDatabase(DB_PATH) + + DEMO_MODE = os.getenv("DEMO_MODE", False) astrbot_config = AstrBotConfig() t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img") html_renderer = HtmlRenderer(t2i_base_url) logger = LogManager.GetLogger(log_name="astrbot") -db_helper = SQLiteDatabase(DB_PATH) +db_helper = get_db_helper() # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中 sp = SharedPreferences(db_helper=db_helper) # 文件令牌服务 diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 58d1c6a9c..67f9098e4 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -3,6 +3,7 @@ import typing as T from contextlib import asynccontextmanager from dataclasses import dataclass +from enum import Enum from deprecated import deprecated from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine @@ -20,11 +21,17 @@ ) +class DatabaseType(Enum): + SQLITE = "sqlite" + MYSQL = "mysql" + + @dataclass class BaseDatabase(abc.ABC): """数据库基类""" DATABASE_URL = "" + database_type: DatabaseType def __init__(self) -> None: self.engine = create_async_engine( @@ -83,7 +90,7 @@ async def insert_platform_stats( @abc.abstractmethod async def count_platform_stats(self) -> int: - """Count the number of platform statistics records.""" + """Sum the count of platform statistics records.""" ... @abc.abstractmethod @@ -91,6 +98,16 @@ async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat """Get platform statistics within the specified offset in seconds and group by platform_id.""" ... + @abc.abstractmethod + async def get_platform_stats_time_series( + self, offset_sec: int = 86400 + ) -> list[tuple[int, int]]: + """Get platform statistics time series data grouped by hour. + + Returns a list of tuples (hour_timestamp, count) sorted by timestamp ascending. + """ + ... + @abc.abstractmethod async def get_conversations( self, diff --git a/astrbot/core/db/migration/helper.py b/astrbot/core/db/migration/helper.py index d7bca3067..fd0746acc 100644 --- a/astrbot/core/db/migration/helper.py +++ b/astrbot/core/db/migration/helper.py @@ -2,7 +2,7 @@ from astrbot.api import logger, sp from astrbot.core.config import AstrBotConfig -from astrbot.core.db import BaseDatabase +from astrbot.core.db import BaseDatabase, DatabaseType from astrbot.core.utils.astrbot_path import get_astrbot_data_path from .migra_3_to_4 import ( @@ -24,6 +24,10 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool: if not os.path.exists(data_v3_db): return False + + if db_helper.database_type == DatabaseType.MYSQL: + return False + migration_done = await db_helper.get_preference( "global", "global", diff --git a/astrbot/core/db/migration/migra_webchat_session.py b/astrbot/core/db/migration/migra_webchat_session.py index ff0b5ca6f..4a0858e7d 100644 --- a/astrbot/core/db/migration/migra_webchat_session.py +++ b/astrbot/core/db/migration/migra_webchat_session.py @@ -38,7 +38,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase): query = ( select( col(PlatformMessageHistory.user_id), - col(PlatformMessageHistory.sender_name), + func.max(PlatformMessageHistory.sender_name).label("sender_name"), func.min(PlatformMessageHistory.created_at).label("earliest"), func.max(PlatformMessageHistory.updated_at).label("latest"), ) diff --git a/astrbot/core/db/mysql.py b/astrbot/core/db/mysql.py new file mode 100644 index 000000000..b4e229a83 --- /dev/null +++ b/astrbot/core/db/mysql.py @@ -0,0 +1,875 @@ +import asyncio +import typing as T +from contextlib import asynccontextmanager +from datetime import datetime, timedelta, timezone + +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlmodel import col, delete, desc, func, or_, select, text, update + +from astrbot.core.db import BaseDatabase, DatabaseType +from astrbot.core.db.po import ( + Attachment, + ConversationV2, + Persona, + PlatformMessageHistory, + PlatformSession, + PlatformStat, + Preference, + SQLModel, +) +from astrbot.core.db.po import Stats as DeprecatedStats + +NOT_GIVEN = T.TypeVar("NOT_GIVEN") + + +class MySQLDatabase(BaseDatabase): + """MySQL 数据库实现 + + 使用方式: + db = MySQLDatabase( + host="localhost", + port=3306, + user="root", + password="password", + database="astrbot" + ) + await db.initialize() + """ + + database_type = DatabaseType.MYSQL + + def __init__( + self, + host: str = "localhost", + port: int = 3306, + user: str = "root", + password: str = "", + database: str = "astrbot", + charset: str = "utf8mb4", + ) -> None: + self.host = host + self.port = port + self.user = user + self.password = password + self.database = database + self.charset = charset + self.DATABASE_URL = ( + f"mysql+aiomysql://{user}:{password}@{host}:{port}/{database}" + f"?charset={charset}" + ) + self.inited = False + self._current_loop: asyncio.AbstractEventLoop | None = None + super().__init__() + + def _recreate_engine(self) -> None: + """重新创建数据库引擎和会话工厂,用于处理事件循环切换的情况""" + self.engine = create_async_engine( + self.DATABASE_URL, + echo=False, + future=True, + ) + self.AsyncSessionLocal = sessionmaker( + self.engine, + class_=AsyncSession, + expire_on_commit=False, + ) + + @asynccontextmanager + async def get_db(self) -> T.AsyncGenerator[AsyncSession, None]: + """Get a database session. + + 此方法会检查当前事件循环,如果事件循环发生变化会重新创建引擎, + 以解决 aiomysql 的 "attached to a different loop" 问题。 + """ + try: + current_loop = asyncio.get_running_loop() + except RuntimeError: + current_loop = None + + # 检查事件循环是否变化,如果变化则重新创建引擎 + if current_loop is not None and self._current_loop != current_loop: + self._recreate_engine() + self._current_loop = current_loop + self.inited = False # 需要重新初始化 + + if not self.inited: + await self.initialize() + self.inited = True + + async with self.AsyncSessionLocal() as session: + yield session + + async def initialize(self) -> None: + """Initialize the database by creating tables if they do not exist.""" + async with self.engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + await conn.commit() + + # ==== + # Platform Statistics + # ==== + + async def insert_platform_stats( + self, + platform_id, + platform_type, + count=1, + timestamp=None, + ) -> None: + """Insert a new platform statistic record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + if timestamp is None: + timestamp = datetime.now().replace( + minute=0, + second=0, + microsecond=0, + ) + current_hour = timestamp + await session.execute( + text(""" + INSERT INTO platform_stats (timestamp, platform_id, platform_type, count) + VALUES (:timestamp, :platform_id, :platform_type, :count) + ON DUPLICATE KEY UPDATE + count = count + VALUES(count) + """), + { + "timestamp": current_hour, + "platform_id": platform_id, + "platform_type": platform_type, + "count": count, + }, + ) + + async def count_platform_stats(self) -> int: + """Count the number of platform statistics records.""" + async with self.get_db() as session: + session: AsyncSession + result = await session.execute( + select(func.sum(PlatformStat.count)).select_from(PlatformStat), + ) + count = result.scalar_one_or_none() + return count or 0 + + async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]: + """Get platform statistics within the specified offset in seconds and group by platform_id.""" + async with self.get_db() as session: + session: AsyncSession + now = datetime.now() + start_time = now - timedelta(seconds=offset_sec) + result = await session.execute( + text(""" + SELECT platform_id, platform_type, SUM(count) as total_count, MAX(timestamp) as latest_ts + FROM platform_stats + WHERE timestamp >= :start_time + GROUP BY platform_id, platform_type + ORDER BY latest_ts DESC + """), + {"start_time": start_time}, + ) + rows = result.fetchall() + return [ + PlatformStat( + id=0, + platform_id=row.platform_id, + platform_type=row.platform_type, + count=row.total_count, + timestamp=row.latest_ts, + ) + for row in rows + ] + + async def get_platform_stats_time_series( + self, offset_sec: int = 86400 + ) -> list[tuple[int, int]]: + """Get platform statistics time series data grouped by hour.""" + async with self.get_db() as session: + session: AsyncSession + now = datetime.now() + start_time = now - timedelta(seconds=offset_sec) + result = await session.execute( + text(""" + SELECT UNIX_TIMESTAMP(DATE_FORMAT(timestamp, '%Y-%m-%d %H:00:00')) as hour_ts, SUM(count) as total_count + FROM platform_stats + WHERE timestamp >= :start_time + GROUP BY hour_ts + ORDER BY hour_ts ASC + """), + {"start_time": start_time}, + ) + rows = result.fetchall() + return [(int(row.hour_ts), row.total_count) for row in rows] + + # ==== + # Conversation Management + # ==== + + async def get_conversations(self, user_id=None, platform_id=None): + async with self.get_db() as session: + session: AsyncSession + query = select(ConversationV2) + + if user_id: + query = query.where(ConversationV2.user_id == user_id) + if platform_id: + query = query.where(ConversationV2.platform_id == platform_id) + # order by + query = query.order_by(desc(ConversationV2.created_at)) + result = await session.execute(query) + + return result.scalars().all() + + async def get_conversation_by_id(self, cid): + async with self.get_db() as session: + session: AsyncSession + query = select(ConversationV2).where(ConversationV2.conversation_id == cid) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_all_conversations(self, page=1, page_size=20): + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + result = await session.execute( + select(ConversationV2) + .order_by(desc(ConversationV2.created_at)) + .offset(offset) + .limit(page_size), + ) + return result.scalars().all() + + async def get_filtered_conversations( + self, + page=1, + page_size=20, + platform_ids=None, + search_query="", + **kwargs, + ): + async with self.get_db() as session: + session: AsyncSession + # Build the base query with filters + base_query = select(ConversationV2) + + if platform_ids: + base_query = base_query.where( + col(ConversationV2.platform_id).in_(platform_ids), + ) + if search_query: + search_query = search_query.encode("unicode_escape").decode("utf-8") + base_query = base_query.where( + or_( + col(ConversationV2.title).ilike(f"%{search_query}%"), + col(ConversationV2.content).ilike(f"%{search_query}%"), + col(ConversationV2.user_id).ilike(f"%{search_query}%"), + col(ConversationV2.conversation_id).ilike(f"%{search_query}%"), + ), + ) + if "message_types" in kwargs and len(kwargs["message_types"]) > 0: + for msg_type in kwargs["message_types"]: + base_query = base_query.where( + col(ConversationV2.user_id).ilike(f"%:{msg_type}:%"), + ) + if "platforms" in kwargs and len(kwargs["platforms"]) > 0: + base_query = base_query.where( + col(ConversationV2.platform_id).in_(kwargs["platforms"]), + ) + + # Get total count matching the filters + count_query = select(func.count()).select_from(base_query.subquery()) + total_count = await session.execute(count_query) + total = total_count.scalar_one() + + # Get paginated results + offset = (page - 1) * page_size + result_query = ( + base_query.order_by(desc(ConversationV2.created_at)) + .offset(offset) + .limit(page_size) + ) + result = await session.execute(result_query) + conversations = result.scalars().all() + + return conversations, total + + async def create_conversation( + self, + user_id, + platform_id, + content=None, + title=None, + persona_id=None, + cid=None, + created_at=None, + updated_at=None, + ): + kwargs = {} + if cid: + kwargs["conversation_id"] = cid + if created_at: + kwargs["created_at"] = created_at + if updated_at: + kwargs["updated_at"] = updated_at + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_conversation = ConversationV2( + user_id=user_id, + content=content or [], + platform_id=platform_id, + title=title, + persona_id=persona_id, + **kwargs, + ) + session.add(new_conversation) + return new_conversation + + async def update_conversation(self, cid, title=None, persona_id=None, content=None): + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = update(ConversationV2).where( + col(ConversationV2.conversation_id) == cid, + ) + values = {} + if title is not None: + values["title"] = title + if persona_id is not None: + values["persona_id"] = persona_id + if content is not None: + values["content"] = content + if not values: + return None + query = query.values(**values) + await session.execute(query) + return await self.get_conversation_by_id(cid) + + async def delete_conversation(self, cid): + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(ConversationV2).where( + col(ConversationV2.conversation_id) == cid, + ), + ) + + async def delete_conversations_by_user_id(self, user_id: str) -> None: + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(ConversationV2).where( + col(ConversationV2.user_id) == user_id + ), + ) + + async def get_session_conversations( + self, + page=1, + page_size=20, + search_query=None, + platform=None, + ) -> tuple[list[dict], int]: + """Get paginated session conversations with joined conversation and persona details.""" + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + + # MySQL 使用 JSON_EXTRACT 函数(与 SQLite 的 json_extract 兼容) + base_query = ( + select( + col(Preference.scope_id).label("session_id"), + func.json_extract(Preference.value, "$.val").label( + "conversation_id", + ), # type: ignore + col(ConversationV2.persona_id).label("persona_id"), + col(ConversationV2.title).label("title"), + col(Persona.persona_id).label("persona_name"), + ) + .select_from(Preference) + .outerjoin( + ConversationV2, + func.json_extract(Preference.value, "$.val") + == ConversationV2.conversation_id, + ) + .outerjoin( + Persona, + col(ConversationV2.persona_id) == Persona.persona_id, + ) + .where(Preference.scope == "umo", Preference.key == "sel_conv_id") + ) + + # 搜索筛选 + if search_query: + search_pattern = f"%{search_query}%" + base_query = base_query.where( + or_( + col(Preference.scope_id).ilike(search_pattern), + col(ConversationV2.title).ilike(search_pattern), + col(Persona.persona_id).ilike(search_pattern), + ), + ) + + # 平台筛选 + if platform: + platform_pattern = f"{platform}:%" + base_query = base_query.where( + col(Preference.scope_id).like(platform_pattern), + ) + + # 排序 + base_query = base_query.order_by(Preference.scope_id) + + # 分页结果 + result_query = base_query.offset(offset).limit(page_size) + result = await session.execute(result_query) + rows = result.fetchall() + + # 查询总数(应用相同的筛选条件) + count_base_query = ( + select(func.count(col(Preference.scope_id))) + .select_from(Preference) + .outerjoin( + ConversationV2, + func.json_extract(Preference.value, "$.val") + == ConversationV2.conversation_id, + ) + .outerjoin( + Persona, + col(ConversationV2.persona_id) == Persona.persona_id, + ) + .where(Preference.scope == "umo", Preference.key == "sel_conv_id") + ) + + # 应用相同的搜索和平台筛选条件到计数查询 + if search_query: + search_pattern = f"%{search_query}%" + count_base_query = count_base_query.where( + or_( + col(Preference.scope_id).ilike(search_pattern), + col(ConversationV2.title).ilike(search_pattern), + col(Persona.persona_id).ilike(search_pattern), + ), + ) + + if platform: + platform_pattern = f"{platform}:%" + count_base_query = count_base_query.where( + col(Preference.scope_id).like(platform_pattern), + ) + + total_result = await session.execute(count_base_query) + total = total_result.scalar() or 0 + + sessions_data = [ + { + "session_id": row.session_id, + "conversation_id": row.conversation_id, + "persona_id": row.persona_id, + "title": row.title, + "persona_name": row.persona_name, + } + for row in rows + ] + return sessions_data, total + + async def insert_platform_message_history( + self, + platform_id, + user_id, + content, + sender_id=None, + sender_name=None, + ): + """Insert a new platform message history record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_history = PlatformMessageHistory( + platform_id=platform_id, + user_id=user_id, + content=content, + sender_id=sender_id, + sender_name=sender_name, + ) + session.add(new_history) + return new_history + + async def delete_platform_message_offset( + self, + platform_id, + user_id, + offset_sec=86400, + ): + """Delete platform message history records newer than the specified offset.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + now = datetime.now() + cutoff_time = now - timedelta(seconds=offset_sec) + await session.execute( + delete(PlatformMessageHistory).where( + col(PlatformMessageHistory.platform_id) == platform_id, + col(PlatformMessageHistory.user_id) == user_id, + col(PlatformMessageHistory.created_at) >= cutoff_time, + ), + ) + + async def get_platform_message_history( + self, + platform_id, + user_id, + page=1, + page_size=20, + ): + """Get platform message history records.""" + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + query = ( + select(PlatformMessageHistory) + .where( + PlatformMessageHistory.platform_id == platform_id, + PlatformMessageHistory.user_id == user_id, + ) + .order_by(desc(PlatformMessageHistory.created_at)) + ) + result = await session.execute(query.offset(offset).limit(page_size)) + return result.scalars().all() + + async def get_platform_message_history_by_id( + self, message_id: int + ) -> PlatformMessageHistory | None: + """Get a platform message history record by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(PlatformMessageHistory).where( + PlatformMessageHistory.id == message_id + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def insert_attachment(self, path, type, mime_type): + """Insert a new attachment record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_attachment = Attachment( + path=path, + type=type, + mime_type=mime_type, + ) + session.add(new_attachment) + return new_attachment + + async def get_attachment_by_id(self, attachment_id): + """Get an attachment by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Attachment).where(Attachment.attachment_id == attachment_id) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_attachments(self, attachment_ids: list[str]) -> list: + """Get multiple attachments by their IDs.""" + if not attachment_ids: + return [] + async with self.get_db() as session: + session: AsyncSession + query = select(Attachment).where( + Attachment.attachment_id.in_(attachment_ids) + ) + result = await session.execute(query) + return list(result.scalars().all()) + + async def delete_attachment(self, attachment_id: str) -> bool: + """Delete an attachment by its ID. + + Returns True if the attachment was deleted, False if it was not found. + """ + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = delete(Attachment).where( + col(Attachment.attachment_id) == attachment_id + ) + result = await session.execute(query) + return result.rowcount > 0 + + async def delete_attachments(self, attachment_ids: list[str]) -> int: + """Delete multiple attachments by their IDs. + + Returns the number of attachments deleted. + """ + if not attachment_ids: + return 0 + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = delete(Attachment).where( + col(Attachment.attachment_id).in_(attachment_ids) + ) + result = await session.execute(query) + return result.rowcount + + async def insert_persona( + self, + persona_id, + system_prompt, + begin_dialogs=None, + tools=None, + ): + """Insert a new persona record.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_persona = Persona( + persona_id=persona_id, + system_prompt=system_prompt, + begin_dialogs=begin_dialogs or [], + tools=tools, + ) + session.add(new_persona) + return new_persona + + async def get_persona_by_id(self, persona_id): + """Get a persona by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Persona).where(Persona.persona_id == persona_id) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_personas(self): + """Get all personas for a specific bot.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Persona) + result = await session.execute(query) + return result.scalars().all() + + async def update_persona( + self, + persona_id, + system_prompt=None, + begin_dialogs=None, + tools=NOT_GIVEN, + ): + """Update a persona's system prompt or begin dialogs.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = update(Persona).where(col(Persona.persona_id) == persona_id) + values = {} + if system_prompt is not None: + values["system_prompt"] = system_prompt + if begin_dialogs is not None: + values["begin_dialogs"] = begin_dialogs + if tools is not NOT_GIVEN: + values["tools"] = tools + if not values: + return None + query = query.values(**values) + await session.execute(query) + return await self.get_persona_by_id(persona_id) + + async def delete_persona(self, persona_id): + """Delete a persona by its ID.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(Persona).where(col(Persona.persona_id) == persona_id), + ) + + async def insert_preference_or_update(self, scope, scope_id, key, value): + """Insert a new preference record or update if it exists.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = select(Preference).where( + Preference.scope == scope, + Preference.scope_id == scope_id, + Preference.key == key, + ) + result = await session.execute(query) + existing_preference = result.scalar_one_or_none() + if existing_preference: + existing_preference.value = value + else: + new_preference = Preference( + scope=scope, + scope_id=scope_id, + key=key, + value=value, + ) + session.add(new_preference) + return existing_preference or new_preference + + async def get_preference(self, scope, scope_id, key): + """Get a preference by key.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Preference).where( + Preference.scope == scope, + Preference.scope_id == scope_id, + Preference.key == key, + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_preferences(self, scope, scope_id=None, key=None): + """Get all preferences for a specific scope ID or key.""" + async with self.get_db() as session: + session: AsyncSession + query = select(Preference).where(Preference.scope == scope) + if scope_id is not None: + query = query.where(Preference.scope_id == scope_id) + if key is not None: + query = query.where(Preference.key == key) + result = await session.execute(query) + return result.scalars().all() + + async def remove_preference(self, scope, scope_id, key): + """Remove a preference by scope ID and key.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(Preference).where( + col(Preference.scope) == scope, + col(Preference.scope_id) == scope_id, + col(Preference.key) == key, + ), + ) + await session.commit() + + async def clear_preferences(self, scope, scope_id): + """Clear all preferences for a specific scope ID.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(Preference).where( + col(Preference.scope) == scope, + col(Preference.scope_id) == scope_id, + ), + ) + await session.commit() + + # ==== + # Platform Session Management + # ==== + + async def create_platform_session( + self, + creator: str, + platform_id: str = "webchat", + session_id: str | None = None, + display_name: str | None = None, + is_group: int = 0, + ) -> PlatformSession: + """Create a new Platform session.""" + kwargs = {} + if session_id: + kwargs["session_id"] = session_id + + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + new_session = PlatformSession( + creator=creator, + platform_id=platform_id, + display_name=display_name, + is_group=is_group, + **kwargs, + ) + session.add(new_session) + await session.flush() + await session.refresh(new_session) + return new_session + + async def get_platform_session_by_id( + self, session_id: str + ) -> PlatformSession | None: + """Get a Platform session by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(PlatformSession).where( + PlatformSession.session_id == session_id, + ) + result = await session.execute(query) + return result.scalar_one_or_none() + + async def get_platform_sessions_by_creator( + self, + creator: str, + platform_id: str | None = None, + page: int = 1, + page_size: int = 20, + ) -> list[PlatformSession]: + """Get all Platform sessions for a specific creator (username) and optionally platform.""" + async with self.get_db() as session: + session: AsyncSession + offset = (page - 1) * page_size + query = select(PlatformSession).where(PlatformSession.creator == creator) + + if platform_id: + query = query.where(PlatformSession.platform_id == platform_id) + + query = ( + query.order_by(desc(PlatformSession.updated_at)) + .offset(offset) + .limit(page_size) + ) + result = await session.execute(query) + return list(result.scalars().all()) + + async def update_platform_session( + self, + session_id: str, + display_name: str | None = None, + ) -> None: + """Update a Platform session's updated_at timestamp and optionally display_name.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + values: dict[str, T.Any] = {"updated_at": datetime.now(timezone.utc)} + if display_name is not None: + values["display_name"] = display_name + + await session.execute( + update(PlatformSession) + .where(col(PlatformSession.session_id) == session_id) + .values(**values), + ) + + async def delete_platform_session(self, session_id: str) -> None: + """Delete a Platform session by its ID.""" + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + await session.execute( + delete(PlatformSession).where( + col(PlatformSession.session_id) == session_id, + ), + ) + + # ==== + # Deprecated Methods + # ==== + + def get_base_stats(self, offset_sec=86400): + """Get base statistics within the specified offset in seconds.""" + return DeprecatedStats() + + def get_total_message_count(self): + """Get the total message count from platform statistics.""" + return 0 + + def get_grouped_base_stats(self, offset_sec=86400): + # group by platform_id + return DeprecatedStats() diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index ef2c2ad54..cd60d4b38 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import col, delete, desc, func, or_, select, text, update -from astrbot.core.db import BaseDatabase +from astrbot.core.db import BaseDatabase, DatabaseType from astrbot.core.db.po import ( Attachment, ConversationV2, @@ -28,6 +28,8 @@ class SQLiteDatabase(BaseDatabase): + database_type = DatabaseType.SQLITE + def __init__(self, db_path: str) -> None: self.db_path = db_path self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}" @@ -88,12 +90,10 @@ async def count_platform_stats(self) -> int: async with self.get_db() as session: session: AsyncSession result = await session.execute( - select(func.count(col(PlatformStat.platform_id))).select_from( - PlatformStat, - ), + select(func.sum(PlatformStat.count)).select_from(PlatformStat), ) count = result.scalar_one_or_none() - return count if count is not None else 0 + return count or 0 async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]: """Get platform statistics within the specified offset in seconds and group by platform_id.""" @@ -103,14 +103,46 @@ async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat start_time = now - timedelta(seconds=offset_sec) result = await session.execute( text(""" - SELECT * FROM platform_stats + SELECT platform_id, platform_type, SUM(count) as total_count, MAX(timestamp) as latest_ts + FROM platform_stats WHERE timestamp >= :start_time - GROUP BY platform_id - ORDER BY timestamp DESC + GROUP BY platform_id, platform_type + ORDER BY latest_ts DESC """), {"start_time": start_time}, ) - return list(result.scalars().all()) + rows = result.fetchall() + return [ + PlatformStat( + id=0, + platform_id=row.platform_id, + platform_type=row.platform_type, + count=row.total_count, + timestamp=row.latest_ts, + ) + for row in rows + ] + + async def get_platform_stats_time_series( + self, offset_sec: int = 86400 + ) -> list[tuple[int, int]]: + """Get platform statistics time series data grouped by hour.""" + async with self.get_db() as session: + session: AsyncSession + now = datetime.now() + start_time = now - timedelta(seconds=offset_sec) + result = await session.execute( + text(""" + SELECT strftime('%s', datetime(timestamp, 'start of hour')) as hour_ts, SUM(count) as total_count + FROM platform_stats + WHERE timestamp >= :start_time + GROUP BY hour_ts + ORDER BY hour_ts ASC + """), + {"start_time": start_time}, + ) + rows = result.fetchall() + return [(int(row.hour_ts), row.total_count) for row in rows] # ==== # Conversation Management @@ -669,102 +701,6 @@ async def clear_preferences(self, scope, scope_id): ) await session.commit() - # ==== - # Deprecated Methods - # ==== - - def get_base_stats(self, offset_sec=86400): - """Get base statistics within the specified offset in seconds.""" - - async def _inner(): - async with self.get_db() as session: - session: AsyncSession - now = datetime.now() - start_time = now - timedelta(seconds=offset_sec) - result = await session.execute( - select(PlatformStat).where(PlatformStat.timestamp >= start_time), - ) - all_datas = result.scalars().all() - deprecated_stats = DeprecatedStats() - for data in all_datas: - deprecated_stats.platform.append( - DeprecatedPlatformStat( - name=data.platform_id, - count=data.count, - timestamp=int(data.timestamp.timestamp()), - ), - ) - return deprecated_stats - - result = None - - def runner(): - nonlocal result - result = asyncio.run(_inner()) - - t = threading.Thread(target=runner) - t.start() - t.join() - return result - - def get_total_message_count(self): - """Get the total message count from platform statistics.""" - - async def _inner(): - async with self.get_db() as session: - session: AsyncSession - result = await session.execute( - select(func.sum(PlatformStat.count)).select_from(PlatformStat), - ) - total_count = result.scalar_one_or_none() - return total_count if total_count is not None else 0 - - result = None - - def runner(): - nonlocal result - result = asyncio.run(_inner()) - - t = threading.Thread(target=runner) - t.start() - t.join() - return result - - def get_grouped_base_stats(self, offset_sec=86400): - # group by platform_id - async def _inner(): - async with self.get_db() as session: - session: AsyncSession - now = datetime.now() - start_time = now - timedelta(seconds=offset_sec) - result = await session.execute( - select(PlatformStat.platform_id, func.sum(PlatformStat.count)) - .where(PlatformStat.timestamp >= start_time) - .group_by(PlatformStat.platform_id), - ) - grouped_stats = result.all() - deprecated_stats = DeprecatedStats() - for platform_id, count in grouped_stats: - deprecated_stats.platform.append( - DeprecatedPlatformStat( - name=platform_id, - count=count, - timestamp=int(start_time.timestamp()), - ), - ) - return deprecated_stats - - result = None - - def runner(): - nonlocal result - result = asyncio.run(_inner()) - - t = threading.Thread(target=runner) - t.start() - t.join() - return result - # ==== # Platform Session Management # ==== @@ -862,3 +798,99 @@ async def delete_platform_session(self, session_id: str) -> None: col(PlatformSession.session_id) == session_id, ), ) + + # ==== + # Deprecated Methods + # ==== + + def get_base_stats(self, offset_sec=86400): + """Get base statistics within the specified offset in seconds.""" + + async def _inner(): + async with self.get_db() as session: + session: AsyncSession + now = datetime.now() + start_time = now - timedelta(seconds=offset_sec) + result = await session.execute( + select(PlatformStat).where(PlatformStat.timestamp >= start_time), + ) + all_datas = result.scalars().all() + deprecated_stats = DeprecatedStats() + for data in all_datas: + deprecated_stats.platform.append( + DeprecatedPlatformStat( + name=data.platform_id, + count=data.count, + timestamp=int(data.timestamp.timestamp()), + ), + ) + return deprecated_stats + + result = None + + def runner(): + nonlocal result + result = asyncio.run(_inner()) + + t = threading.Thread(target=runner) + t.start() + t.join() + return result + + def get_total_message_count(self): + """Get the total message count from platform statistics.""" + + async def _inner(): + async with self.get_db() as session: + session: AsyncSession + result = await session.execute( + select(func.sum(PlatformStat.count)).select_from(PlatformStat), + ) + total_count = result.scalar_one_or_none() + return total_count if total_count is not None else 0 + + result = None + + def runner(): + nonlocal result + result = asyncio.run(_inner()) + + t = threading.Thread(target=runner) + t.start() + t.join() + return result + + def get_grouped_base_stats(self, offset_sec=86400): + # group by platform_id + async def _inner(): + async with self.get_db() as session: + session: AsyncSession + now = datetime.now() + start_time = now - timedelta(seconds=offset_sec) + result = await session.execute( + select(PlatformStat.platform_id, func.sum(PlatformStat.count)) + .where(PlatformStat.timestamp >= start_time) + .group_by(PlatformStat.platform_id), + ) + grouped_stats = result.all() + deprecated_stats = DeprecatedStats() + for platform_id, count in grouped_stats: + deprecated_stats.platform.append( + DeprecatedPlatformStat( + name=platform_id, + count=count, + timestamp=int(start_time.timestamp()), + ), + ) + return deprecated_stats + + result = None + + def runner(): + nonlocal result + result = asyncio.run(_inner()) + + t = threading.Thread(target=runner) + t.start() + t.join() + return result diff --git a/astrbot/dashboard/routes/stat.py b/astrbot/dashboard/routes/stat.py index 8df690cc2..70139d983 100644 --- a/astrbot/dashboard/routes/stat.py +++ b/astrbot/dashboard/routes/stat.py @@ -84,23 +84,14 @@ async def get_stat(self): offset_sec = request.args.get("offset_sec", 86400) offset_sec = int(offset_sec) try: - stat = self.db_helper.get_base_stats(offset_sec) - now = int(time.time()) - start_time = now - offset_sec - message_time_based_stats = [] - - idx = 0 - for bucket_end in range(start_time, now, 3600): - cnt = 0 - while ( - idx < len(stat.platform) - and stat.platform[idx].timestamp < bucket_end - ): - cnt += stat.platform[idx].count - idx += 1 - message_time_based_stats.append([bucket_end, cnt]) - - stat_dict = stat.__dict__ + platform_stats = await self.db_helper.get_platform_stats(offset_sec) + + # 获取时间序列数据(按小时分桶) + time_series = await self.db_helper.get_platform_stats_time_series( + offset_sec + ) + message_time_based_stats = [[ts, cnt] for ts, cnt in time_series] + message_count = await self.db_helper.count_platform_stats() cpu_percent = psutil.cpu_percent(interval=0.5) thread_count = threading.active_count() @@ -121,28 +112,36 @@ async def get_stat(self): int(time.time()) - self.core_lifecycle.start_time, ) - stat_dict.update( + # 构建平台统计数据 + platform_data = [ { - "platform": self.db_helper.get_grouped_base_stats( - offset_sec, - ).platform, - "message_count": self.db_helper.get_total_message_count() or 0, - "platform_count": len( - self.core_lifecycle.platform_manager.get_insts(), - ), - "plugin_count": len(plugins), - "plugins": plugin_info, - "message_time_series": message_time_based_stats, - "running": running_time, # 现在返回时间组件而不是格式化的字符串 - "memory": { - "process": psutil.Process().memory_info().rss >> 20, - "system": psutil.virtual_memory().total >> 20, - }, - "cpu_percent": round(cpu_percent, 1), - "thread_count": thread_count, - "start_time": self.core_lifecycle.start_time, + "name": stat.platform_id, + "count": stat.count, + "timestamp": int(stat.timestamp.timestamp()) + if stat.timestamp + else 0, + } + for stat in platform_stats + ] + + stat_dict = { + "platform": platform_data, + "message_count": message_count, + "platform_count": len( + self.core_lifecycle.platform_manager.get_insts(), + ), + "plugin_count": len(plugins), + "plugins": plugin_info, + "message_time_series": message_time_based_stats, + "running": running_time, # 现在返回时间组件而不是格式化的字符串 + "memory": { + "process": psutil.Process().memory_info().rss >> 20, + "system": psutil.virtual_memory().total >> 20, }, - ) + "cpu_percent": round(cpu_percent, 1), + "thread_count": thread_count, + "start_time": self.core_lifecycle.start_time, + } return Response().ok(stat_dict).__dict__ except Exception as e: diff --git a/pyproject.toml b/pyproject.toml index 6490b65a3..4f22d5b7e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,6 +60,7 @@ dependencies = [ "markitdown-no-magika[docx,xls,xlsx]>=0.1.2", "xinference-client", "tenacity>=9.1.2", + "aiomysql>=0.3.2", ] [dependency-groups]