Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions astrbot/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os

from pydantic_settings import BaseSettings, SettingsConfigDict

from astrbot.core.config import AstrBotConfig
from astrbot.core.config.default import DB_PATH
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.db.sqlite import BaseDatabase
from astrbot.core.file_token_service import FileTokenService
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.utils.shared_preferences import SharedPreferences
Expand All @@ -14,13 +16,44 @@
# 初始化数据存储文件夹
os.makedirs(get_astrbot_data_path(), exist_ok=True)


class AstrBotMySQLSettings(BaseSettings):
host: str = "localhost"
port: int = 3306
user: str = "root"
password: str = ""
database: str = "astrbot"
charset: str = "utf8mb4"

model_config = SettingsConfigDict(env_file=".env", env_prefix="ASTR_MYSQL_")


def get_db_helper() -> BaseDatabase:
db_type = os.getenv("ASTR_DB_TYPE", "sqlite")
match db_type:
case "sqlite":
from astrbot.core.db.sqlite import SQLiteDatabase

return SQLiteDatabase(DB_PATH)
case "mysql":
from astrbot.core.db.mysql import MySQLDatabase

mysql_settings = AstrBotMySQLSettings()

return MySQLDatabase(**mysql_settings.model_dump())
case _:
from astrbot.core.db.sqlite import SQLiteDatabase

return SQLiteDatabase(DB_PATH)


DEMO_MODE = os.getenv("DEMO_MODE", False)

astrbot_config = AstrBotConfig()
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
html_renderer = HtmlRenderer(t2i_base_url)
logger = LogManager.GetLogger(log_name="astrbot")
db_helper = SQLiteDatabase(DB_PATH)
db_helper = get_db_helper()
# 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
sp = SharedPreferences(db_helper=db_helper)
# 文件令牌服务
Expand Down
19 changes: 18 additions & 1 deletion astrbot/core/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing as T
from contextlib import asynccontextmanager
from dataclasses import dataclass
from enum import Enum

from deprecated import deprecated
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
Expand All @@ -20,11 +21,17 @@
)


class DatabaseType(Enum):
SQLITE = "sqlite"
MYSQL = "mysql"


@dataclass
class BaseDatabase(abc.ABC):
"""数据库基类"""

DATABASE_URL = ""
database_type: DatabaseType

def __init__(self) -> None:
self.engine = create_async_engine(
Expand Down Expand Up @@ -83,14 +90,24 @@ async def insert_platform_stats(

@abc.abstractmethod
async def count_platform_stats(self) -> int:
"""Count the number of platform statistics records."""
"""Sum the count of platform statistics records."""
...

@abc.abstractmethod
async def get_platform_stats(self, offset_sec: int = 86400) -> list[PlatformStat]:
"""Get platform statistics within the specified offset in seconds and group by platform_id."""
...

@abc.abstractmethod
async def get_platform_stats_time_series(
self, offset_sec: int = 86400
) -> list[tuple[int, int]]:
"""Get platform statistics time series data grouped by hour.

Returns a list of tuples (hour_timestamp, count) sorted by timestamp ascending.
"""
...

@abc.abstractmethod
async def get_conversations(
self,
Expand Down
6 changes: 5 additions & 1 deletion astrbot/core/db/migration/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from astrbot.api import logger, sp
from astrbot.core.config import AstrBotConfig
from astrbot.core.db import BaseDatabase
from astrbot.core.db import BaseDatabase, DatabaseType
from astrbot.core.utils.astrbot_path import get_astrbot_data_path

from .migra_3_to_4 import (
Expand All @@ -24,6 +24,10 @@ async def check_migration_needed_v4(db_helper: BaseDatabase) -> bool:

if not os.path.exists(data_v3_db):
return False

if db_helper.database_type == DatabaseType.MYSQL:
return False

migration_done = await db_helper.get_preference(
"global",
"global",
Expand Down
2 changes: 1 addition & 1 deletion astrbot/core/db/migration/migra_webchat_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ async def migrate_webchat_session(db_helper: BaseDatabase):
query = (
select(
col(PlatformMessageHistory.user_id),
col(PlatformMessageHistory.sender_name),
func.max(PlatformMessageHistory.sender_name).label("sender_name"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question (bug_risk): 使用 MAX 来聚合 sender_name 能解决 SQL 分组问题,但会产生一个任意的名字;请确认这是否符合预期语义。

func.max 可以满足 ONLY_FULL_GROUP_BY 的要求,但它会选择字典序中最大的那个名字,在一个分组内这基本上是任意的。如果被选中的 sender_name 在语义或 UI 上有意义,建议考虑基于特定行(例如最早消息或最新消息)来派生,而不是依赖 MAX

Original comment in English

question (bug_risk): Aggregating sender_name with MAX fixes SQL grouping but produces an arbitrary name; confirm this matches the intended semantics.

func.max satisfies ONLY_FULL_GROUP_BY but picks the lexicographically greatest name, which is essentially arbitrary within the group. If the chosen sender_name has semantic or UI significance, consider deriving it from a specific row (e.g., earliest or latest message) rather than relying on MAX.

func.min(PlatformMessageHistory.created_at).label("earliest"),
func.max(PlatformMessageHistory.updated_at).label("latest"),
)
Expand Down
Loading