diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 6f3c813eb..1cf572aa8 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -97,7 +97,6 @@ async def step(self): llm_resp_result = None async for llm_response in self._iter_llm_responses(): - assert isinstance(llm_response, LLMResponse) if llm_response.is_chunk: if llm_response.result_chain: yield AgentResponse( diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index b60088609..7f30f44ef 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,4 +1,4 @@ -from collections.abc import Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any, Generic import jsonschema @@ -7,6 +7,8 @@ from pydantic import Field, model_validator from pydantic.dataclasses import dataclass +from astrbot.core.message.message_event_result import MessageEventResult + from .run_context import ContextWrapper, TContext ParametersType = dict[str, Any] @@ -38,7 +40,10 @@ def validate_parameters(self) -> "ToolSchema": class FunctionTool(ToolSchema, Generic[TContext]): """A callable tool, for function calling.""" - handler: Callable[..., Awaitable[Any]] | None = None + handler: ( + Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]] + | None + ) = None """a callable that implements the tool's functionality. It should be an async function.""" handler_module_path: str | None = None diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 440dea2d1..ed08e90a9 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -185,7 +185,11 @@ async def _execute_mcp( async def call_local_llm_tool( context: ContextWrapper[AstrAgentContext], - handler: T.Callable[..., T.Awaitable[T.Any]], + handler: T.Callable[ + ..., + T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None] + | T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None], + ], method_name: str, *args, **kwargs, diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 786d29c81..9477eabaa 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -24,6 +24,10 @@ class AstrBotConfig(dict): - 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。 """ + config_path: str + default_config: dict + schema: dict | None + def __init__( self, config_path: str = ASTRBOT_CONFIG_PATH, diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index e8241f85a..5a8672837 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -197,7 +197,7 @@ def _load(self) -> None: # 把插件中注册的所有协程函数注册到事件总线中并执行 extra_tasks = [] for task in self.star_context._register_tasks: - extra_tasks.append(asyncio.create_task(task, name=task.__name__)) + extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore tasks_ = [event_bus_task, *extra_tasks] for task in tasks_: diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 58d1c6a9c..44c69b209 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -5,8 +5,7 @@ from dataclasses import dataclass from deprecated import deprecated -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from astrbot.core.db.po import ( Attachment, @@ -32,7 +31,7 @@ def __init__(self) -> None: echo=False, future=True, ) - self.AsyncSessionLocal = sessionmaker( + self.AsyncSessionLocal = async_sessionmaker( self.engine, class_=AsyncSession, expire_on_commit=False, diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index a75c60a1b..66b72d5cb 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -70,6 +70,7 @@ async def migration_conversation_table( logger.info( f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", ) + continue if ":" not in conv.user_id: continue session = MessageSesion.from_str(session_str=conv.user_id) @@ -207,6 +208,7 @@ async def migration_webchat_data( logger.info( f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", ) + continue if ":" in conv.user_id: continue platform_id = "webchat" diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index a301028d1..b1a780d48 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -127,7 +127,7 @@ def _get_conn(self, db_path: str) -> sqlite3.Connection: conn.text_factory = str return conn - def _exec_sql(self, sql: str, params: tuple = None): + def _exec_sql(self, sql: str, params: tuple | None = None): conn = self.conn try: c = self.conn.cursor() @@ -224,9 +224,11 @@ def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats: c.close() - return Stats(platform, [], []) + return Stats(platform) - def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: + def get_conversation_by_user_id( + self, user_id: str, cid: str + ) -> Conversation | None: try: c = self.conn.cursor() except sqlite3.ProgrammingError: @@ -258,7 +260,7 @@ def new_conversation(self, user_id: str, cid: str): (user_id, cid, history, updated_at, created_at), ) - def get_conversations(self, user_id: str) -> tuple: + def get_conversations(self, user_id: str) -> list[Conversation]: try: c = self.conn.cursor() except sqlite3.ProgrammingError: diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index 3d9947413..34b301c92 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True): Note: In astrbot v4, we moved `platform` table to here. """ - __tablename__ = "platform_stats" # type: ignore + __tablename__: str = "platform_stats" id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) timestamp: datetime = Field(nullable=False) @@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True): class ConversationV2(SQLModel, table=True): - __tablename__ = "conversations" # type: ignore + __tablename__: str = "conversations" - inner_conversation_id: int = Field( + inner_conversation_id: int | None = Field( + default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}, ) @@ -68,7 +69,7 @@ class Persona(SQLModel, table=True): It can be used to customize the behavior of LLMs. """ - __tablename__ = "personas" # type: ignore + __tablename__: str = "personas" id: int | None = Field( primary_key=True, @@ -98,7 +99,7 @@ class Persona(SQLModel, table=True): class Preference(SQLModel, table=True): """This class represents preferences for bots.""" - __tablename__ = "preferences" # type: ignore + __tablename__: str = "preferences" id: int | None = Field( default=None, @@ -134,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True): or platform-specific messages. """ - __tablename__ = "platform_message_history" # type: ignore + __tablename__: str = "platform_message_history" id: int | None = Field( primary_key=True, @@ -162,7 +163,7 @@ class PlatformSession(SQLModel, table=True): Each session can have multiple conversations (对话) associated with it. """ - __tablename__ = "platform_sessions" # type: ignore + __tablename__: str = "platform_sessions" inner_id: int | None = Field( primary_key=True, @@ -203,7 +204,7 @@ class Attachment(SQLModel, table=True): Attachments can be images, files, or other media types. """ - __tablename__ = "attachments" # type: ignore + __tablename__: str = "attachments" inner_attachment_id: int | None = Field( primary_key=True, @@ -261,17 +262,17 @@ class Personality(TypedDict): 在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。 """ - prompt: str = "" - name: str = "" - begin_dialogs: list[str] = [] - mood_imitation_dialogs: list[str] = [] + prompt: str + name: str + begin_dialogs: list[str] + mood_imitation_dialogs: list[str] """情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。""" - tools: list[str] | None = None + tools: list[str] | None """工具列表。None 表示使用所有工具,空列表表示不使用任何工具""" # cache - _begin_dialogs_processed: list[dict] = [] - _mood_imitation_dialogs_processed: str = "" + _begin_dialogs_processed: list[dict] + _mood_imitation_dialogs_processed: str # ==== diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index ef2c2ad54..033d076c8 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -3,6 +3,7 @@ import typing as T from datetime import datetime, timedelta, timezone +from sqlalchemy import CursorResult from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import col, delete, desc, func, or_, select, text, update @@ -489,7 +490,7 @@ async def get_attachments(self, attachment_ids: list[str]) -> list: async with self.get_db() as session: session: AsyncSession query = select(Attachment).where( - Attachment.attachment_id.in_(attachment_ids) + col(Attachment.attachment_id).in_(attachment_ids) ) result = await session.execute(query) return list(result.scalars().all()) @@ -505,7 +506,7 @@ async def delete_attachment(self, attachment_id: str) -> bool: query = delete(Attachment).where( col(Attachment.attachment_id) == attachment_id ) - result = await session.execute(query) + result = T.cast(CursorResult, await session.execute(query)) return result.rowcount > 0 async def delete_attachments(self, attachment_ids: list[str]) -> int: @@ -521,7 +522,7 @@ async def delete_attachments(self, attachment_ids: list[str]) -> int: query = delete(Attachment).where( col(Attachment.attachment_id).in_(attachment_ids) ) - result = await session.execute(query) + result = T.cast(CursorResult, await session.execute(query)) return result.rowcount async def insert_persona( diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index 24f1c323c..564454cb1 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -90,4 +90,6 @@ async def save_index(self): path (str): 保存索引的路径 """ + if self.index is None: + return faiss.write_index(self.index, self.path) diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 749df753e..0017e65fa 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -27,7 +27,7 @@ def __init__( self, event_queue: Queue, pipeline_scheduler_mapping: dict[str, PipelineScheduler], - astrbot_config_mgr: AstrBotConfigManager = None, + astrbot_config_mgr: AstrBotConfigManager, ): self.event_queue = event_queue # 事件队列 # abconf uuid -> scheduler @@ -40,6 +40,11 @@ async def dispatch(self): conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) self._print_event(event, conf_info["name"]) scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"]) + if not scheduler: + logger.error( + f"PipelineScheduler not found for id: {conf_info['id']}, event ignored." + ) + continue asyncio.create_task(scheduler.execute(event)) def _print_event(self, event: AstrMessageEvent, conf_name: str): diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 9a42cd6cd..746406e90 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -166,7 +166,11 @@ async def retrieve( # 5. Rerank first_rerank = None for kb_id in kb_ids: - vec_db: FaissVecDB = kb_options[kb_id]["vec_db"] + vec_db = kb_options[kb_id]["vec_db"] + if not isinstance(vec_db, FaissVecDB): + logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB") + continue + rerank_pi = kb_options[kb_id]["rerank_provider_id"] if ( vec_db diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 47d6ff781..0e7b3bab6 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -66,6 +66,9 @@ class ComponentType(str, Enum): class BaseMessageComponent(BaseModel): type: ComponentType + def __init__(self, **kwargs): + super().__init__(**kwargs) + def toDict(self): data = {} for k, v in self.__dict__.items(): @@ -551,7 +554,7 @@ class Node(BaseMessageComponent): id: int | None = 0 # 忽略 name: str | None = "" # qq昵称 uin: str | None = "0" # qq号 - content: list[BaseMessageComponent] | None = [] + content: list[BaseMessageComponent] = [] seq: str | list | None = "" # 忽略 time: int | None = 0 # 忽略 @@ -615,7 +618,7 @@ def toDict(self): ret["messages"].append(d) return ret - async def to_dict(self): + async def to_dict(self) -> dict: """将 Nodes 转换为字典格式,适用于 OneBot JSON 格式""" ret = {"messages": []} for node in self.nodes: @@ -714,12 +717,15 @@ async def get_file(self, allow_return_url: bool = False) -> str: if self.url: await self._download_file() - return os.path.abspath(self.file_) + if self.file_: + return os.path.abspath(self.file_) return "" async def _download_file(self): """下载文件""" + if not self.url: + raise ValueError("Download failed: No URL provided in File component.") download_dir = os.path.join(get_astrbot_data_path(), "temp") os.makedirs(download_dir, exist_ok=True) if self.name: diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py index 5d1743ab9..b2d2c6be1 100644 --- a/astrbot/core/persona_mgr.py +++ b/astrbot/core/persona_mgr.py @@ -98,8 +98,8 @@ async def create_persona( self, persona_id: str, system_prompt: str, - begin_dialogs: list[str] = None, - tools: list[str] = None, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, ) -> Persona: """创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具""" if await self.db.get_persona_by_id(persona_id): diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index c477cc23a..b089c48e0 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -24,7 +24,7 @@ async def process( self, event: AstrMessageEvent, check_text: str | None = None, - ) -> None | AsyncGenerator[None, None]: + ) -> AsyncGenerator[None, None]: """检查内容安全""" text = check_text if check_text else event.get_message_str() ok, info = self.strategy_selector.check(text) diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 73d28c5d1..1f5ba43a0 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -11,7 +11,7 @@ async def call_handler( event: AstrMessageEvent, - handler: T.Callable[..., T.Awaitable[T.Any]], + handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]], *args, **kwargs, ) -> T.AsyncGenerator[T.Any, None]: @@ -91,6 +91,7 @@ async def call_event_hook( ) for handler in handlers: try: + assert inspect.iscoroutinefunction(handler.handler) logger.debug( f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", ) diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 56d305de4..00a89f9ab 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -24,7 +24,7 @@ async def initialize(self, ctx: PipelineContext) -> None: async def process( self, event: AstrMessageEvent, - ) -> AsyncGenerator[None, None]: + ) -> AsyncGenerator[Any, None]: activated_handlers: list[StarHandlerMetadata] = event.get_extra( "activated_handlers", ) diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index e19b8dc18..076f7f12a 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -60,7 +60,7 @@ async def process( ): # 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀 if ( - event.get_result() and not event.get_result().is_stopped() + event.get_result() and not event.is_stopped() ) or not event.get_result(): async for _ in self.agent_sub_stage.process(event): yield diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 760649563..8f1b87efc 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -117,7 +117,9 @@ def is_seg_reply_required(self, event: AstrMessageEvent) -> bool: if not self.enable_seg: return False - if self.only_llm_result and not event.get_result().is_llm_result(): + if (result := event.get_result()) is None: + return False + if self.only_llm_result and result.is_llm_result(): return False if event.get_platform_name() in [ @@ -185,7 +187,7 @@ async def process( if isinstance(component, Comp.File) and component.file: # 支持 File 消息段的路径映射。 component.file = path_Mapping(mappings, component.file) - event.get_result().chain[idx] = component + result.chain[idx] = component # 检查消息链是否为空 try: diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index ef394edcf..2d2e15d69 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -6,6 +6,7 @@ from astrbot.core import file_token_service, html_renderer, logger from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply from astrbot.core.message.message_event_result import ResultContentType +from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.message_type import MessageType from astrbot.core.star.session_llm_manager import SessionServiceManager @@ -93,11 +94,13 @@ async def process( for comp in result.chain: if isinstance(comp, Plain): text += comp.text - async for _ in self.content_safe_check_stage.process( - event, - check_text=text, - ): - yield + + if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage): + async for _ in self.content_safe_check_stage.process( + event, + check_text=text, + ): + yield # 发送消息前事件钩子 handlers = star_handlers_registry.get_handlers_by_event_type( @@ -114,7 +117,8 @@ async def process( "启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作", ) await handler.handler(event) - if event.get_result() is None or not event.get_result().chain: + + if (result := event.get_result()) is None or not result.chain: logger.debug( f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。", ) diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 5c461a1e1..5fb3034f5 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -2,6 +2,10 @@ from astrbot.core import logger from astrbot.core.platform import AstrMessageEvent +from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent +from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import ( + WecomAIBotMessageEvent, +) from . import STAGES_ORDER from .context import PipelineContext @@ -78,7 +82,7 @@ async def execute(self, event: AstrMessageEvent): await self._process_stages(event) # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 - if event.get_platform_name() in ["webchat", "wecom_ai_bot"]: + if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)): await event.send(None) logger.debug("pipeline 执行完毕。") diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 6402aeaed..f6eda07a9 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -153,7 +153,9 @@ def get_sender_id(self) -> str: def get_sender_name(self) -> str: """获取消息发送者的名称。(可能会返回空字符串)""" - return self.message_obj.sender.nickname + if isinstance(self.message_obj.sender.nickname, str): + return self.message_obj.sender.nickname + return "" def set_extra(self, key, value): """设置额外的信息。""" @@ -270,7 +272,7 @@ def should_call_llm(self, call_llm: bool): """ self.call_llm = call_llm - def get_result(self) -> MessageEventResult: + def get_result(self) -> MessageEventResult | None: """获取消息事件的结果。""" return self._result @@ -320,7 +322,7 @@ def request_llm( self, prompt: str, func_tool_manager=None, - session_id: str = None, + session_id: str = "", image_urls: list[str] | None = None, contexts: list | None = None, system_prompt: str = "", diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 0ada18506..253963322 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -54,7 +54,7 @@ class AstrBotMessage: self_id: str # 机器人的识别id session_id: str # 会话id。取决于 unique_session 的设置。 message_id: str # 消息id - group: Group # 群组 + group: Group | None # 群组 sender: MessageMember # 发送者 message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 message_str: str # 最直观的纯文本消息字符串 @@ -78,7 +78,7 @@ def group_id(self) -> str: return "" @group_id.setter - def group_id(self, value: str): + def group_id(self, value: str | None): """设置 group_id""" if value: if self.group: diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index 2b3a87d8b..c139b8bd7 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -1,7 +1,7 @@ import abc import uuid from asyncio import Queue -from collections.abc import Awaitable +from collections.abc import Coroutine from dataclasses import dataclass, field from datetime import datetime from enum import Enum @@ -100,7 +100,7 @@ def get_stats(self) -> dict: } @abc.abstractmethod - def run(self) -> Awaitable[Any]: + def run(self) -> Coroutine[Any, Any, None]: """得到一个平台的运行实例,需要返回一个协程对象。""" raise NotImplementedError @@ -116,7 +116,7 @@ async def send_by_session( self, session: MessageSesion, message_chain: MessageChain, - ): + ) -> None: """通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。 异步方法。 diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py index c63bd82b1..06455aac4 100644 --- a/astrbot/core/platform/platform_metadata.py +++ b/astrbot/core/platform/platform_metadata.py @@ -7,7 +7,7 @@ class PlatformMetadata: """平台的名称,即平台的类型,如 aiocqhttp, discord, slack""" description: str """平台的描述""" - id: str | None = None + id: str """平台的唯一标识符,用于配置中识别特定平台""" default_config_tmpl: dict | None = None diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py index c1721c5c5..5f550ecd1 100644 --- a/astrbot/core/platform/register.py +++ b/astrbot/core/platform/register.py @@ -40,6 +40,7 @@ def decorator(cls): pm = PlatformMetadata( name=adapter_name, description=desc, + id=adapter_name, default_config_tmpl=default_config_tmpl, adapter_display_name=adapter_display_name, logo_path=logo_path, diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index ce8fd56df..293b462d3 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -70,16 +70,18 @@ async def _dispatch_send( bot: CQHttp, event: Event | None, is_group: bool, - session_id: str, + session_id: str | None, messages: list[dict], ): # session_id 必须是纯数字字符串 - session_id = int(session_id) if session_id.isdigit() else None + session_id_int = ( + int(session_id) if session_id and session_id.isdigit() else None + ) - if is_group and isinstance(session_id, int): - await bot.send_group_msg(group_id=session_id, message=messages) - elif not is_group and isinstance(session_id, int): - await bot.send_private_msg(user_id=session_id, message=messages) + if is_group and isinstance(session_id_int, int): + await bot.send_group_msg(group_id=session_id_int, message=messages) + elif not is_group and isinstance(session_id_int, int): + await bot.send_private_msg(user_id=session_id_int, message=messages) elif isinstance(event, Event): # 最后兜底 await bot.send(event=event, message=messages) else: diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index bfefa2f68..b3c2229ab 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -4,7 +4,7 @@ import time import uuid from collections.abc import Awaitable -from typing import Any +from typing import Any, cast from aiocqhttp import CQHttp, Event from aiocqhttp.exceptions import ActionFailed @@ -48,7 +48,7 @@ def __init__( self.metadata = PlatformMetadata( name="aiocqhttp", description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), support_streaming_message=False, ) @@ -127,7 +127,9 @@ async def _convert_handle_request_event(self, event: Event) -> AstrBotMessage: """OneBot V11 请求类事件""" abm = AstrBotMessage() abm.self_id = str(event.self_id) - abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id) + abm.sender = MessageMember( + user_id=str(event.user_id), nickname=str(event.user_id) + ) abm.type = MessageType.OTHER_MESSAGE if event.get("group_id"): abm.type = MessageType.GROUP_MESSAGE @@ -194,6 +196,7 @@ async def _convert_handle_message_event( @param event: 事件对象 @param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。 """ + assert event.sender is not None abm = AstrBotMessage() abm.self_id = str(event.self_id) abm.sender = MessageMember( @@ -203,6 +206,7 @@ async def _convert_handle_message_event( if event["message_type"] == "group": abm.type = MessageType.GROUP_MESSAGE abm.group_id = str(event.group_id) + abm.group = Group(str(event.group_id)) abm.group.group_name = event.get("group_name", "N/A") elif event["message_type"] == "private": abm.type = MessageType.FRIEND_MESSAGE @@ -228,7 +232,7 @@ async def _convert_handle_message_event( await self.bot.send(event, err) except BaseException as e: logger.error(f"回复消息失败: {e}") - return None + raise ValueError(err) # 按消息段类型类型适配 for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]): diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 8ccbf8b9a..8905698a5 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -2,6 +2,7 @@ import os import threading import uuid +from typing import cast import aiohttp import dingtalk_stream @@ -54,12 +55,14 @@ def __init__( self.client_id = platform_config["client_id"] self.client_secret = platform_config["client_secret"] + outer_self = self + class AstrCallbackClient(dingtalk_stream.ChatbotHandler): - async def process(self_, message: dingtalk_stream.CallbackMessage): + async def process(self, message: dingtalk_stream.CallbackMessage): logger.debug(f"dingtalk: {message.data}") im = dingtalk_stream.ChatbotMessage.from_dict(message.data) - abm = await self.convert_msg(im) - await self.handle_msg(abm) + abm = await outer_self.convert_msg(im) + await outer_self.handle_msg(abm) return AckMessage.STATUS_OK, "OK" @@ -73,6 +76,7 @@ async def process(self_, message: dingtalk_stream.CallbackMessage): self.client, ) self.client_ = client # 用于 websockets 的 client + self._shutdown_event: threading.Event | None = None def _id_to_sid(self, dingtalk_id: str | None) -> str: if not dingtalk_id: @@ -93,7 +97,7 @@ def meta(self) -> PlatformMetadata: return PlatformMetadata( name="dingtalk", description="钉钉机器人官方 API 适配器", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), support_streaming_message=False, ) @@ -104,7 +108,7 @@ async def convert_msg( abm = AstrBotMessage() abm.message = [] abm.message_str = "" - abm.timestamp = int(message.create_at / 1000) + abm.timestamp = int(cast(int, message.create_at) / 1000) abm.type = ( MessageType.GROUP_MESSAGE if message.conversation_type == "2" @@ -115,7 +119,7 @@ async def convert_msg( nickname=message.sender_nick, ) abm.self_id = self._id_to_sid(message.chatbot_user_id) - abm.message_id = message.message_id + abm.message_id = cast(str, message.message_id) abm.raw_message = message if abm.type == MessageType.GROUP_MESSAGE: @@ -132,14 +136,16 @@ async def convert_msg( else: abm.session_id = abm.sender.user_id - message_type: str = message.message_type + message_type: str = cast(str, message.message_type) match message_type: case "text": abm.message_str = message.text.content.strip() abm.message.append(Plain(abm.message_str)) case "richText": - rtc: dingtalk_stream.RichTextContent = message.rich_text_content - contents: list[dict] = rtc.rich_text_list + rtc: dingtalk_stream.RichTextContent = cast( + dingtalk_stream.RichTextContent, message.rich_text_content + ) + contents: list[dict] = cast(list[dict], rtc.rich_text_list) for content in contents: plains = "" if "text" in content: @@ -148,7 +154,7 @@ async def convert_msg( elif "type" in content and content["type"] == "picture": f_path = await self.download_ding_file( content["downloadCode"], - message.robot_code, + cast(str, message.robot_code), "jpg", ) abm.message.append(Image.fromFileSystem(f_path)) @@ -193,7 +199,7 @@ async def download_ding_file( logger.error( f"下载钉钉文件失败: {resp.status}, {await resp.text()}", ) - return None + return "" resp_data = await resp.json() download_url = resp_data["data"]["downloadUrl"] await download_file(download_url, f_path) @@ -213,7 +219,7 @@ async def get_access_token(self) -> str: logger.error( f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}", ) - return None + return "" return (await resp.json())["data"]["accessToken"] async def handle_msg(self, abm: AstrBotMessage): @@ -250,9 +256,11 @@ async def terminate(self): def monkey_patch_close(): raise KeyboardInterrupt("Graceful shutdown") - self.client_.open_connection = monkey_patch_close - await self.client_.websocket.close(code=1000, reason="Graceful shutdown") - self._shutdown_event.set() + if self.client_.websocket is not None: + self.client_.open_connection = monkey_patch_close + await self.client_.websocket.close(code=1000, reason="Graceful shutdown") + if self._shutdown_event is not None: + self._shutdown_event.set() def get_client(self): return self.client diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index a1cd9c1aa..d520189d8 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -1,4 +1,5 @@ import asyncio +from typing import cast import dingtalk_stream @@ -32,7 +33,7 @@ async def send_with_client( client.reply_markdown, segment.text, segment.text, - self.message_obj.raw_message, + cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message), ) elif isinstance(segment, Comp.Image): markdown_str = "" @@ -53,7 +54,9 @@ async def send_with_client( client.reply_markdown, "😄", markdown_str, - self.message_obj.raw_message, + cast( + dingtalk_stream.ChatbotMessage, self.message_obj.raw_message + ), ) logger.debug(f"send image: {ret}") diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index 5d29e3429..ac0610f2a 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -1,4 +1,5 @@ import sys +from collections.abc import Awaitable, Callable import discord @@ -27,13 +28,16 @@ def __init__(self, token: str, proxy: str | None = None): super().__init__(intents=intents, proxy=proxy) # 回调函数 - self.on_message_received = None - self.on_ready_once_callback = None + self.on_message_received: Callable[[dict], Awaitable[None]] | None = None + self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None self._ready_once_fired = False - @override async def on_ready(self): """当机器人成功连接并准备就绪时触发""" + if self.user is None: + logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)") + return + logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录") logger.info("[Discord] 客户端已准备就绪。") @@ -49,6 +53,9 @@ async def on_ready(self): def _create_message_data(self, message: discord.Message) -> dict: """从 discord.Message 创建数据字典""" + if self.user is None: + raise RuntimeError("Bot is not ready: self.user is None") + is_mentioned = self.user in message.mentions return { "message": message, @@ -66,6 +73,12 @@ def _create_message_data(self, message: discord.Message) -> dict: def _create_interaction_data(self, interaction: discord.Interaction) -> dict: """从 discord.Interaction 创建数据字典""" + if self.user is None: + raise RuntimeError("Bot is not ready: self.user is None") + + if interaction.user is None: + raise ValueError("Interaction received without a valid user") + return { "interaction": interaction, "bot_id": str(self.user.id), @@ -80,7 +93,6 @@ def _create_interaction_data(self, interaction: discord.Interaction) -> dict: "type": "interaction", } - @override async def on_message(self, message: discord.Message): """当接收到消息时触发""" if message.author.bot: diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index d3e69e763..f875652a0 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent): def __init__( self, - components: list[BaseMessageComponent] = None, - timeout: float = None, + components: list[BaseMessageComponent] | None = None, + timeout: float | None = None, ): self.components = components or [] self.timeout = timeout diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 17002c06f..50aa0fe6f 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -1,10 +1,10 @@ import asyncio import re import sys -from typing import Any +from typing import Any, cast import discord -from discord.abc import Messageable +from discord.abc import GuildChannel, Messageable, PrivateChannel from discord.channel import DMChannel from astrbot import logger @@ -46,7 +46,7 @@ def __init__( ) -> None: super().__init__(platform_config, event_queue) self.settings = platform_settings - self.client_self_id = None + self.client_self_id: str | None = None self.registered_handlers = [] # 指令注册相关 self.enable_command_register = self.config.get("discord_command_register", True) @@ -62,6 +62,12 @@ async def send_by_session( message_chain: MessageChain, ): """通过会话发送消息""" + if self.client.user is None: + logger.error( + "[Discord] 客户端未就绪 (self.client.user is None),无法发送消息" + ) + return + # 创建一个 message_obj 以便在 event 中使用 message_obj = AstrBotMessage() if "_" in session.session_id: @@ -89,7 +95,7 @@ async def send_by_session( user_id=str(self.client_self_id), nickname=self.client.user.display_name, ) - message_obj.self_id = self.client_self_id + message_obj.self_id = cast(str, self.client_self_id) message_obj.session_id = session.session_id message_obj.message = message_chain.chain @@ -110,7 +116,7 @@ def meta(self) -> PlatformMetadata: return PlatformMetadata( "discord", "Discord 适配器", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), default_config_tmpl=self.config, support_streaming_message=False, ) @@ -160,7 +166,7 @@ async def callback(): def _get_message_type( self, - channel: Messageable, + channel: Messageable | GuildChannel | PrivateChannel, guild_id: int | None = None, ) -> MessageType: """根据 channel 对象和 guild_id 判断消息类型""" @@ -170,13 +176,15 @@ def _get_message_type( return MessageType.FRIEND_MESSAGE return MessageType.GROUP_MESSAGE - def _get_channel_id(self, channel: Messageable) -> str: + def _get_channel_id( + self, channel: Messageable | GuildChannel | PrivateChannel + ) -> str: """根据 channel 对象获取ID""" return str(getattr(channel, "id", None)) def _convert_message_to_abm(self, data: dict) -> AstrBotMessage: """将普通消息转换为 AstrBotMessage""" - message: discord.Message = data["message"] + message = data["message"] content = message.content @@ -233,7 +241,7 @@ def _convert_message_to_abm(self, data: dict) -> AstrBotMessage: ) abm.message = message_chain abm.raw_message = message - abm.self_id = self.client_self_id + abm.self_id = cast(str, self.client_self_id) abm.session_id = str(message.channel.id) abm.message_id = str(message.id) return abm @@ -254,32 +262,52 @@ async def handle_msg(self, message: AstrBotMessage, followup_webhook=None): interaction_followup_webhook=followup_webhook, ) + if self.client.user is None: + logger.error( + "[Discord] 客户端未就绪 (self.client.user is None),无法处理消息" + ) + return + # 检查是否为斜杠指令 is_slash_command = message_event.interaction_followup_webhook is not None + # 1. 优先处理斜杠指令 + if is_slash_command: + message_event.is_wake = True + message_event.is_at_or_wake_command = True + self.commit_event(message_event) + return + + # 2. 处理普通消息(提及检测) + # 确保 raw_message 是 discord.Message 类型,以便静态检查通过 + raw_message = message.raw_message + if not isinstance(raw_message, discord.Message): + logger.warning( + f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。" + ) + return + # 检查是否被@(User Mention 或 Bot 拥有的 Role Mention) is_mention = False + # User Mention - if ( - self.client - and self.client.user - and hasattr(message.raw_message, "mentions") - ): - if self.client.user in message.raw_message.mentions: - is_mention = True + # 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性 + if self.client.user in raw_message.mentions: + is_mention = True + # Role Mention(Bot 拥有的角色被提及) - if not is_mention and hasattr(message.raw_message, "role_mentions"): + if not is_mention and raw_message.role_mentions: bot_member = None - if hasattr(message.raw_message, "guild") and message.raw_message.guild: + if raw_message.guild: try: - bot_member = message.raw_message.guild.get_member( + bot_member = raw_message.guild.get_member( self.client.user.id, ) except Exception: bot_member = None if bot_member and hasattr(bot_member, "roles"): bot_roles = set(bot_member.roles) - mentioned_roles = set(message.raw_message.role_mentions) + mentioned_roles = set(raw_message.role_mentions) if ( bot_roles and mentioned_roles @@ -287,8 +315,8 @@ async def handle_msg(self, message: AstrBotMessage, followup_webhook=None): ): is_mention = True - # 如果是斜杠指令或被@的消息,设置为唤醒状态 - if is_slash_command or is_mention: + # 如果是被@的消息,设置为唤醒状态 + if is_mention: message_event.is_wake = True message_event.is_at_or_wake_command = True @@ -424,7 +452,7 @@ async def dynamic_callback( ) abm.message = [Plain(text=message_str_for_filter)] abm.raw_message = ctx.interaction - abm.self_id = self.client_self_id + abm.self_id = cast(str, self.client_self_id) abm.session_id = str(ctx.channel_id) abm.message_id = str(ctx.interaction.id) @@ -437,7 +465,7 @@ async def dynamic_callback( def _extract_command_info( event_filter: Any, handler_metadata: StarHandlerMetadata, - ) -> tuple[str, str, CommandFilter] | None: + ) -> tuple[str, str, CommandFilter | None] | None: """从事件过滤器中提取指令信息""" cmd_name = None # is_group = False diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 82eb9f144..053018225 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -4,8 +4,10 @@ from collections.abc import AsyncGenerator from io import BytesIO from pathlib import Path +from typing import cast import discord +from discord.types.interactions import ComponentInteractionData from astrbot import logger from astrbot.api.event import AstrMessageEvent, MessageChain @@ -85,6 +87,9 @@ async def send(self, message: MessageChain): channel = await self._get_channel() if not channel: return + if not isinstance(channel, discord.abc.Messageable): + logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型") + return await channel.send(**kwargs) except Exception as e: @@ -107,7 +112,9 @@ async def send_streaming( await self.send(buffer) return await super().send_streaming(generator, use_fallback) - async def _get_channel(self) -> discord.abc.Messageable | None: + async def _get_channel( + self, + ) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None: """获取当前事件对应的频道对象""" try: channel_id = int(self.session_id) @@ -121,7 +128,13 @@ async def _get_channel(self) -> discord.abc.Messageable | None: async def _parse_to_discord( self, message: MessageChain, - ) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]: + ) -> tuple[ + str, + list[discord.File], + discord.ui.View | None, + list[discord.Embed], + str | int | None, + ]: """将 MessageChain 解析为 Discord 发送所需的内容""" content_parts = [] files = [] @@ -261,7 +274,9 @@ async def react(self, emoji: str): self.message_obj.raw_message, "add_reaction", ): - await self.message_obj.raw_message.add_reaction(emoji) + await cast(discord.Message, self.message_obj.raw_message).add_reaction( + emoji + ) except Exception as e: logger.error(f"[Discord] 添加反应失败: {e}") @@ -270,7 +285,7 @@ def is_slash_command(self) -> bool: return ( hasattr(self.message_obj, "raw_message") and hasattr(self.message_obj.raw_message, "type") - and self.message_obj.raw_message.type + and cast(discord.Interaction, self.message_obj.raw_message).type == discord.InteractionType.application_command ) @@ -279,14 +294,18 @@ def is_button_interaction(self) -> bool: return ( hasattr(self.message_obj, "raw_message") and hasattr(self.message_obj.raw_message, "type") - and self.message_obj.raw_message.type == discord.InteractionType.component + and cast(discord.Interaction, self.message_obj.raw_message).type + == discord.InteractionType.component ) def get_interaction_custom_id(self) -> str: """获取交互组件的custom_id""" if self.is_button_interaction(): try: - return self.message_obj.raw_message.data.get("custom_id", "") + return cast( + ComponentInteractionData, + cast(discord.Interaction, self.message_obj.raw_message).data, + ).get("custom_id", "") except Exception: pass return "" @@ -299,7 +318,9 @@ def is_mentioned(self) -> bool: ): return any( mention.id == int(self.message_obj.self_id) - for mention in self.message_obj.raw_message.mentions + for mention in cast( + discord.Message, self.message_obj.raw_message + ).mentions ) return False @@ -309,5 +330,5 @@ def get_mention_clean_content(self) -> str: self.message_obj.raw_message, "clean_content", ): - return self.message_obj.raw_message.clean_content + return cast(discord.Message, self.message_obj.raw_message).clean_content return self.message_str diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index 59626f78d..473be096f 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -3,9 +3,14 @@ import json import re import uuid +from typing import cast import lark_oapi as lark -from lark_oapi.api.im.v1 import * +from lark_oapi.api.im.v1 import ( + CreateMessageRequest, + CreateMessageRequestBody, + GetMessageResourceRequest, +) import astrbot.api.message_components as Comp from astrbot import logger @@ -74,6 +79,10 @@ async def send_by_session( session: MessageSesion, message_chain: MessageChain, ): + if self.lark_api.im is None: + logger.error("[Lark] API Client im 模块未初始化,无法发送消息") + return + res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api) wrapped = { "zh_cn": { @@ -114,14 +123,21 @@ def meta(self) -> PlatformMetadata: return PlatformMetadata( name="lark", description="飞书机器人官方 API 适配器", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), support_streaming_message=False, ) async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): + if event.event is None: + logger.debug("[Lark] 收到空事件(event.event is None)") + return message = event.event.message + if message is None: + logger.debug("[Lark] 事件中没有消息体(message is None)") + return + abm = AstrBotMessage() - abm.timestamp = int(message.create_time) / 1000 + abm.timestamp = cast(int, message.create_time) // 1000 abm.message = [] abm.type = ( MessageType.GROUP_MESSAGE @@ -136,14 +152,28 @@ async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): at_list = {} if message.mentions: for m in message.mentions: - at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name) + if m.id is None: + continue + # 飞书 open_id 可能是 None,这里做个防护 + open_id = m.id.open_id if m.id.open_id else "" + at_list[m.key] = Comp.At(qq=open_id, name=m.name) + if m.name == self.bot_name: - abm.self_id = m.id.open_id + if m.id.open_id is not None: + abm.self_id = m.id.open_id + + if message.content is None: + logger.warning("[Lark] 消息内容为空") + return - content_json_b = json.loads(message.content) + try: + content_json_b = json.loads(message.content) + except json.JSONDecodeError: + logger.error(f"[Lark] 解析消息内容失败: {message.content}") + return if message.message_type == "text": - message_str_raw = content_json_b["text"] # 带有 @ 的消息 + message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息 at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则 # at_users = re.findall(at_pattern, message_str_raw) # 拆分文本,去掉AT符号部分 @@ -168,27 +198,47 @@ async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): content_json_b = _ls elif message.message_type == "image": content_json_b = [ - {"tag": "img", "image_key": content_json_b["image_key"], "style": []}, + { + "tag": "img", + "image_key": content_json_b.get("image_key"), + "style": [], + }, ] if message.message_type in ("post", "image"): for comp in content_json_b: - if comp["tag"] == "at": - abm.message.append(at_list[comp["user_id"]]) - elif comp["tag"] == "text" and comp["text"].strip(): + if comp.get("tag") == "at": + user_id = comp.get("user_id") + if user_id in at_list: + abm.message.append(at_list[user_id]) + elif comp.get("tag") == "text" and comp.get("text", "").strip(): abm.message.append(Comp.Plain(comp["text"].strip())) - elif comp["tag"] == "img": - image_key = comp["image_key"] + elif comp.get("tag") == "img": + image_key = comp.get("image_key") + if not image_key: + continue + request = ( GetMessageResourceRequest.builder() - .message_id(message.message_id) + .message_id(cast(str, message.message_id)) .file_key(image_key) .type("image") .build() ) + + if self.lark_api.im is None: + logger.error("[Lark] API Client im 模块未初始化") + continue + response = await self.lark_api.im.v1.message_resource.aget(request) if not response.success(): logger.error(f"无法下载飞书图片: {image_key}") + continue + + if response.file is None: + logger.error(f"飞书图片响应中不包含文件流: {image_key}") + continue + image_bytes = response.file.read() image_base64 = base64.b64encode(image_bytes).decode() abm.message.append(Comp.Image.fromBase64(image_base64)) @@ -196,6 +246,19 @@ async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): for comp in abm.message: if isinstance(comp, Comp.Plain): abm.message_str += comp.text + + if message.message_id is None: + logger.error("[Lark] 消息缺少 message_id") + return + + if ( + event.event.sender is None + or event.event.sender.sender_id is None + or event.event.sender.sender_id.open_id is None + ): + logger.error("[Lark] 消息发送者信息不完整") + return + abm.message_id = message.message_id abm.raw_message = message abm.sender = MessageMember( @@ -235,5 +298,5 @@ async def terminate(self): await self.client._disconnect() logger.info("飞书(Lark) 适配器已被优雅地关闭") - def get_client(self) -> lark.Client: + def get_client(self) -> lark.ws.Client: return self.client diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 04204d35e..7b7d20b38 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -5,7 +5,15 @@ from io import BytesIO import lark_oapi as lark -from lark_oapi.api.im.v1 import * +from lark_oapi.api.im.v1 import ( + CreateImageRequest, + CreateImageRequestBody, + CreateMessageReactionRequest, + CreateMessageReactionRequestBody, + Emoji, + ReplyMessageRequest, + ReplyMessageRequestBody, +) from astrbot import logger from astrbot.api.event import AstrMessageEvent, MessageChain @@ -44,7 +52,7 @@ async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> l file_path = comp.file.replace("file:///", "") elif comp.file and comp.file.startswith("http"): image_file_path = await download_image_by_url(comp.file) - file_path = image_file_path + file_path = image_file_path if image_file_path else "" elif comp.file and comp.file.startswith("base64://"): base64_str = comp.file.removeprefix("base64://") image_data = base64.b64decode(base64_str) @@ -54,10 +62,17 @@ async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> l with open(file_path, "wb") as f: f.write(BytesIO(image_data).getvalue()) else: - file_path = comp.file + file_path = comp.file if comp.file else "" if image_file is None: - image_file = open(file_path, "rb") + if not file_path: + logger.error("[Lark] 图片路径为空,无法上传") + continue + try: + image_file = open(file_path, "rb") + except Exception as e: + logger.error(f"[Lark] 无法打开图片文件: {e}") + continue request = ( CreateImageRequest.builder() @@ -69,9 +84,20 @@ async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> l ) .build() ) + + if lark_client.im is None: + logger.error("[Lark] API Client im 模块未初始化,无法上传图片") + continue + response = await lark_client.im.v1.image.acreate(request) if not response.success(): logger.error(f"无法上传飞书图片({response.code}): {response.msg}") + continue + + if response.data is None: + logger.error("[Lark] 上传图片成功但未返回数据(data is None)") + continue + image_key = response.data.image_key logger.debug(image_key) ret.append(_stage) @@ -107,6 +133,10 @@ async def send(self, message: MessageChain): .build() ) + if self.bot.im is None: + logger.error("[Lark] API Client im 模块未初始化,无法回复消息") + return + response = await self.bot.im.v1.message.areply(request) if not response.success(): @@ -115,6 +145,10 @@ async def send(self, message: MessageChain): await super().send(message) async def react(self, emoji: str): + if self.bot.im is None: + logger.error("[Lark] API Client im 模块未初始化,无法发送表情") + return + request = ( CreateMessageReactionRequest.builder() .message_id(self.message_obj.message_id) @@ -125,6 +159,7 @@ async def react(self, emoji: str): ) .build() ) + response = await self.bot.im.v1.message_reaction.acreate(request) if not response.success(): logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}") diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index 528ef8122..7f3db3062 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -1,7 +1,6 @@ import asyncio import os import random -from collections.abc import Awaitable from typing import Any import astrbot.api.message_components as Comp @@ -203,7 +202,7 @@ def _process_poll_data( if not isinstance(message.raw_message, dict): message.raw_message = {} message.raw_message["poll"] = poll - message.poll = poll + message.__setattr__("poll", poll) except Exception: pass @@ -372,7 +371,7 @@ async def send_by_session( self, session: MessageSession, message_chain: MessageChain, - ) -> Awaitable[Any]: + ) -> None: if not self.api: logger.error("[Misskey] API 客户端未初始化") return await super().send_by_session(session, message_chain) diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index 4bc474c13..d693c4206 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -3,6 +3,7 @@ import os import random import uuid +from typing import cast import aiofiles import botpy @@ -60,7 +61,10 @@ async def send_streaming(self, generator, use_fallback: bool = False): time_since_last_edit = current_time - last_edit_time if time_since_last_edit >= throttle_interval: - ret = await self._post_send(stream=stream_payload) + ret = cast( + message.Message, + await self._post_send(stream=stream_payload), + ) stream_payload["index"] += 1 stream_payload["id"] = ret["id"] last_edit_time = asyncio.get_event_loop().time() @@ -83,7 +87,8 @@ async def _post_send(self, stream: dict | None = None): return None source = self.message_obj.raw_message - assert isinstance( + + if not isinstance( source, ( botpy.message.Message, @@ -91,7 +96,9 @@ async def _post_send(self, stream: dict | None = None): botpy.message.DirectMessage, botpy.message.C2CMessage, ), - ) + ): + logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}") + return None ( plain_text, @@ -108,7 +115,7 @@ async def _post_send(self, stream: dict | None = None): ): return None - payload = { + payload: dict = { "content": plain_text, "msg_id": self.message_obj.message_id, } @@ -118,8 +125,12 @@ async def _post_send(self, stream: dict | None = None): ret = None - match type(source): - case botpy.message.GroupMessage: + match source: + case botpy.message.GroupMessage(): + if not source.group_openid: + logger.error("[QQOfficial] GroupMessage 缺少 group_openid") + return None + if image_base64: media = await self.upload_group_and_c2c_image( image_base64, @@ -140,7 +151,8 @@ async def _post_send(self, stream: dict | None = None): group_openid=source.group_openid, **payload, ) - case botpy.message.C2CMessage: + + case botpy.message.C2CMessage(): if image_base64: media = await self.upload_group_and_c2c_image( image_base64, @@ -169,18 +181,23 @@ async def _post_send(self, stream: dict | None = None): **payload, ) logger.debug(f"Message sent to C2C: {ret}") - case botpy.message.Message: + + case botpy.message.Message(): if image_path: payload["file_image"] = image_path ret = await self.bot.api.post_message( channel_id=source.channel_id, **payload, ) - case botpy.message.DirectMessage: + + case botpy.message.DirectMessage(): if image_path: payload["file_image"] = image_path ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload) + case _: + pass + await super().send(self.send_buffer) self.send_buffer = None @@ -198,18 +215,33 @@ async def upload_group_and_c2c_image( "file_type": file_type, "srv_send_msg": False, } + + result = None if "openid" in kwargs: payload["openid"] = kwargs["openid"] route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"]) - return await self.bot.api._http.request(route, json=payload) - if "group_openid" in kwargs: + result = await self.bot.api._http.request(route, json=payload) + elif "group_openid" in kwargs: payload["group_openid"] = kwargs["group_openid"] route = Route( "POST", "/v2/groups/{group_openid}/files", group_openid=kwargs["group_openid"], ) - return await self.bot.api._http.request(route, json=payload) + result = await self.bot.api._http.request(route, json=payload) + else: + raise ValueError("Invalid upload parameters") + + if not isinstance(result, dict): + raise RuntimeError( + f"Failed to upload image, response is not dict: {result}" + ) + + return Media( + file_uuid=result["file_uuid"], + file_info=result["file_info"], + ttl=result.get("ttl", 0), + ) async def upload_group_and_c2c_record( self, @@ -252,11 +284,14 @@ async def upload_group_and_c2c_record( result = await self.bot.api._http.request(route, json=payload) if result: + if not isinstance(result, dict): + logger.error(f"上传文件响应格式错误: {result}") + return None + return Media( - file_uuid=result.get("file_uuid"), - file_info=result.get("file_info"), + file_uuid=result["file_uuid"], + file_info=result["file_info"], ttl=result.get("ttl", 0), - file_id=result.get("id", ""), ) except Exception as e: logger.error(f"上传请求错误: {e}") @@ -273,7 +308,7 @@ async def post_c2c_message( message_reference: message.Reference | None = None, media: message.Media | None = None, msg_id: str | None = None, - msg_seq: str = 1, + msg_seq: int | None = 1, event_id: str | None = None, markdown: message.MarkdownPayload | None = None, keyboard: message.Keyboard | None = None, @@ -282,7 +317,14 @@ async def post_c2c_message( payload = locals() payload.pop("self", None) route = Route("POST", "/v2/users/{openid}/messages", openid=openid) - return await self.bot.api._http.request(route, json=payload) + result = await self.bot.api._http.request(route, json=payload) + + if not isinstance(result, dict): + raise RuntimeError( + f"Failed to post c2c message, response is not dict: {result}" + ) + + return message.Message(**result) @staticmethod async def _parse_to_qqofficial(message: MessageChain): @@ -302,8 +344,10 @@ async def _parse_to_qqofficial(message: MessageChain): image_base64 = file_to_base64(image_file_path) elif i.file and i.file.startswith("base64://"): image_base64 = i.file - else: + elif i.file: image_base64 = file_to_base64(i.file) + else: + raise ValueError("Unsupported image file format") image_base64 = image_base64.removeprefix("base64://") elif isinstance(i, Record): if i.file: diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 9b1637b22..2a1bcda47 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -4,6 +4,7 @@ import logging import os import time +from typing import cast import botpy import botpy.message @@ -44,7 +45,9 @@ async def on_group_at_message_create(self, message: botpy.message.GroupMessage): MessageType.GROUP_MESSAGE, ) abm.session_id = ( - abm.sender.user_id if self.platform.unique_session else message.group_openid + abm.sender.user_id + if self.platform.unique_session + else cast(str, message.group_openid) ) self._commit(abm) @@ -101,7 +104,7 @@ def __init__( self.appid = platform_config["appid"] self.secret = platform_config["secret"] - self.unique_session = platform_settings["unique_session"] + self.unique_session: bool = platform_settings["unique_session"] qq_group = platform_config["enable_group_c2c"] guild_dm = platform_config["enable_guild_direct_message"] @@ -137,12 +140,15 @@ def meta(self) -> PlatformMetadata: return PlatformMetadata( name="qq_official", description="QQ 机器人官方 API 适配器", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), ) @staticmethod def _parse_from_qqofficial( - message: botpy.message.Message | botpy.message.GroupMessage, + message: botpy.message.Message + | botpy.message.GroupMessage + | botpy.message.DirectMessage + | botpy.message.C2CMessage, message_type: MessageType, ): abm = AstrBotMessage() @@ -150,7 +156,7 @@ def _parse_from_qqofficial( abm.timestamp = int(time.time()) abm.raw_message = message abm.message_id = message.id - abm.tag = "qq_official" + # abm.tag = "qq_official" msg: list[BaseMessageComponent] = [] if isinstance(message, botpy.message.GroupMessage) or isinstance( @@ -180,9 +186,9 @@ def _parse_from_qqofficial( message, botpy.message.DirectMessage, ): - try: + if isinstance(message, botpy.message.Message): abm.self_id = str(message.mentions[0].id) - except BaseException as _: + else: abm.self_id = "" plain_content = message.content.replace( diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index fcb41ca2f..63b6726fe 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -1,6 +1,6 @@ import asyncio import logging -from typing import Any +from typing import Any, cast import botpy import botpy.message @@ -36,7 +36,9 @@ async def on_group_at_message_create(self, message: botpy.message.GroupMessage): MessageType.GROUP_MESSAGE, ) abm.session_id = ( - abm.sender.user_id if self.platform.unique_session else message.group_openid + abm.sender.user_id + if self.platform.unique_session + else cast(str, message.group_openid) ) self._commit(abm) @@ -120,7 +122,7 @@ def meta(self) -> PlatformMetadata: return PlatformMetadata( name="qq_official_webhook", description="QQ 机器人官方 API 适配器", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), ) async def run(self): diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index bce45e892..2eda11a6c 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -1,5 +1,6 @@ import asyncio import logging +from typing import cast import quart from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token @@ -99,7 +100,7 @@ async def handle_callback(self, request) -> dict: if opcode == 13: # validation - signed = await self.webhook_validation(data) + signed = await self.webhook_validation(cast(dict, data)) print(signed) return signed diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index 574fe40bc..fbdc71759 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -4,9 +4,11 @@ import json import logging from collections.abc import Callable +from typing import cast from quart import Quart, Response, request from slack_sdk.socket_mode.aiohttp import SocketModeClient +from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.response import SocketModeResponse from slack_sdk.web.async_client import AsyncWebClient @@ -66,7 +68,7 @@ async def handle_callback(self, req): """ try: # 获取请求体和头部 - body = await req.get_data() + body = cast(bytes, await req.get_data()) event_data = json.loads(body.decode("utf-8")) # Verify Slack request signature @@ -139,9 +141,14 @@ def __init__( self.event_handler = event_handler self.socket_client = None - async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest): + async def _handle_events( + self, _: AsyncBaseSocketModeClient, req: SocketModeRequest + ): """处理 Socket Mode 事件""" try: + if self.socket_client is None: + raise RuntimeError("Socket client is not initialized") + # 确认收到事件 response = SocketModeResponse(envelope_id=req.envelope_id) await self.socket_client.send_socket_mode_response(response) diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index 81936f903..4621f8494 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -3,8 +3,7 @@ import re import time import uuid -from collections.abc import Awaitable -from typing import Any +from typing import Any, cast import aiohttp from slack_sdk.socket_mode.request import SocketModeRequest @@ -68,7 +67,7 @@ def __init__( self.metadata = PlatformMetadata( name="slack", description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), support_streaming_message=False, ) @@ -118,13 +117,13 @@ async def convert_message(self, event: dict) -> AstrBotMessage: logger.debug(f"[slack] RawMessage {event}") abm = AstrBotMessage() - abm.self_id = self.bot_self_id + abm.self_id = cast(str, self.bot_self_id) # 获取用户信息 user_id = event.get("user", "") try: user_info = await self.web_client.users_info(user=user_id) - user_data = user_info["user"] + user_data = cast(dict, user_info["user"]) user_name = user_data.get("real_name") or user_data.get("name", user_id) except Exception: user_name = user_id @@ -135,7 +134,7 @@ async def convert_message(self, event: dict) -> AstrBotMessage: channel_id = event.get("channel", "") try: channel_info = await self.web_client.conversations_info(channel=channel_id) - is_im = channel_info["channel"]["is_im"] + is_im = cast(dict, channel_info["channel"])["is_im"] if is_im: abm.type = MessageType.FRIEND_MESSAGE @@ -178,7 +177,7 @@ async def convert_message(self, event: dict) -> AstrBotMessage: for mention in mentions: try: mentioned_user = await self.web_client.users_info(user=mention) - user_data = mentioned_user["user"] + user_data = cast(dict, mentioned_user["user"]) user_name = user_data.get("real_name") or user_data.get( "name", mention, @@ -329,7 +328,7 @@ async def get_file_base64(self, url: str) -> str: ) raise Exception(f"下载文件失败: {resp.status}") - async def run(self) -> Awaitable[Any]: + async def run(self) -> None: self.bot_self_id = await self.get_bot_user_id() logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}") diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py index d3e768800..822e6fdeb 100644 --- a/astrbot/core/platform/sources/slack/slack_event.py +++ b/astrbot/core/platform/sources/slack/slack_event.py @@ -1,6 +1,7 @@ import asyncio import re -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Iterable +from typing import cast from slack_sdk.web.async_client import AsyncWebClient @@ -38,7 +39,7 @@ async def _from_segment_to_slack_block( if isinstance(segment, Image): # upload file url = segment.url or segment.file - if url.startswith("http"): + if url and url.startswith("http"): return { "type": "image", "image_url": url, @@ -55,7 +56,7 @@ async def _from_segment_to_slack_block( "type": "section", "text": {"type": "mrkdwn", "text": "图片上传失败"}, } - image_url = response["files"][0]["url_private"] + image_url = cast(list, response["files"])[0]["url_private"] logger.debug(f"Slack file upload response: {response}") return { "type": "image", @@ -77,7 +78,7 @@ async def _from_segment_to_slack_block( "type": "section", "text": {"type": "mrkdwn", "text": "文件上传失败"}, } - file_url = response["files"][0]["permalink"] + file_url = cast(list, response["files"])[0]["permalink"] return { "type": "section", "text": { @@ -225,10 +226,10 @@ async def get_group(self, group_id=None, **kwargs): ) members = [] - for member_id in members_response["members"]: + for member_id in cast(Iterable, members_response["members"]): try: user_info = await self.web_client.users_info(user=member_id) - user_data = user_info["user"] + user_data = cast(dict, user_info["user"]) members.append( MessageMember( user_id=member_id, @@ -240,7 +241,7 @@ async def get_group(self, group_id=None, **kwargs): # 如果获取用户信息失败,使用默认信息 members.append(MessageMember(user_id=member_id, nickname=member_id)) - channel_data = channel_info["channel"] + channel_data = cast(dict, channel_info["channel"]) return Group( group_id=channel_id, group_name=channel_data.get("name", ""), diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 34fd86ad9..37f60e65a 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -1,6 +1,7 @@ import asyncio import os import re +from typing import Any, cast import telegramify_markdown from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji @@ -17,8 +18,6 @@ Reply, ) from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata -from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core.utils.io import download_file class TelegramPlatformEvent(AstrMessageEvent): @@ -97,7 +96,7 @@ async def send_with_client( "chat_id": user_name, } if has_reply: - payload["reply_to_message_id"] = reply_message_id + payload["reply_to_message_id"] = str(reply_message_id) if message_thread_id: payload["message_thread_id"] = message_thread_id @@ -110,33 +109,30 @@ async def send_with_client( try: md_text = telegramify_markdown.markdownify( chunk, - max_line_length=None, normalize_whitespace=False, ) await client.send_message( text=md_text, parse_mode="MarkdownV2", - **payload, + **cast(Any, payload), ) except Exception as e: logger.warning( f"MarkdownV2 send failed: {e}. Using plain text instead.", ) - await client.send_message(text=chunk, **payload) + await client.send_message(text=chunk, **cast(Any, payload)) elif isinstance(i, Image): image_path = await i.convert_to_file_path() - await client.send_photo(photo=image_path, **payload) + await client.send_photo(photo=image_path, **cast(Any, payload)) elif isinstance(i, File): - if i.file.startswith("https://"): - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - path = os.path.join(temp_dir, i.name) - await download_file(i.file, path) - i.file = path - - await client.send_document(document=i.file, filename=i.name, **payload) + path = await i.get_file() + name = i.name or os.path.basename(path) + await client.send_document( + document=path, filename=name, **cast(Any, payload) + ) elif isinstance(i, Record): path = await i.convert_to_file_path() - await client.send_voice(voice=path, **payload) + await client.send_voice(voice=path, **cast(Any, payload)) async def send(self, message: MessageChain): if self.get_message_type() == MessageType.GROUP_MESSAGE: @@ -214,24 +210,23 @@ async def send_streaming(self, generator, use_fallback: bool = False): delta += i.text elif isinstance(i, Image): image_path = await i.convert_to_file_path() - await self.client.send_photo(photo=image_path, **payload) + await self.client.send_photo( + photo=image_path, **cast(Any, payload) + ) continue elif isinstance(i, File): - if i.file.startswith("https://"): - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - path = os.path.join(temp_dir, i.name) - await download_file(i.file, path) - i.file = path + path = await i.get_file() + name = i.name or os.path.basename(path) await self.client.send_document( - document=i.file, - filename=i.name, - **payload, + document=path, + filename=name, + **cast(Any, payload), ) continue elif isinstance(i, Record): path = await i.convert_to_file_path() - await self.client.send_voice(voice=path, **payload) + await self.client.send_voice(voice=path, **cast(Any, payload)) continue else: logger.warning(f"不支持的消息类型: {type(i)}") @@ -260,7 +255,9 @@ async def send_streaming(self, generator, use_fallback: bool = False): else: # delta 长度一般不会大于 4096,因此这里直接发送 try: - msg = await self.client.send_message(text=delta, **payload) + msg = await self.client.send_message( + text=delta, **cast(Any, payload) + ) current_content = delta except Exception as e: logger.warning(f"发送消息失败(streaming): {e!s}") @@ -274,7 +271,6 @@ async def send_streaming(self, generator, use_fallback: bool = False): try: markdown_text = telegramify_markdown.markdownify( delta, - max_line_length=None, normalize_whitespace=False, ) await self.client.edit_message_text( diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 42f79b80d..084d7860d 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -2,7 +2,7 @@ import os import time import uuid -from collections.abc import Awaitable, Callable +from collections.abc import Callable, Coroutine from typing import Any from astrbot import logger @@ -207,7 +207,7 @@ async def convert_message(self, data: tuple) -> AstrBotMessage: abm.raw_message = data return abm - def run(self) -> Awaitable[Any]: + def run(self) -> Coroutine[Any, Any, None]: async def callback(data: tuple): abm = await self.convert_message(data) await self.handle_msg(abm) diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 70c834e65..9f1a6d059 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -101,9 +101,9 @@ async def _send( return data - async def send(self, message: MessageChain): + async def send(self, message: MessageChain | None): await WebChatMessageEvent._send(message, session_id=self.session_id) - await super().send(message) + await super().send(MessageChain([])) async def send_streaming(self, generator, use_fallback: bool = False): final_data = "" diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py index 8186dd1ca..4c9a9d36b 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py @@ -4,6 +4,7 @@ import os import time import traceback +from typing import cast import aiohttp import anyio @@ -69,7 +70,7 @@ def __init__( ) self.base_url = f"http://{self.host}:{self.port}" self.auth_key = None # 用于保存生成的授权码 - self.wxid = None # 用于保存登录成功后的 wxid + self.wxid: str | None = None # 用于保存登录成功后的 wxid self.credentials_file = os.path.join( get_astrbot_data_path(), "wechatpadpro_credentials.json", @@ -398,7 +399,7 @@ async def connect_websocket(self): ) await asyncio.sleep(5) - async def handle_websocket_message(self, message: str): + async def handle_websocket_message(self, message: str | bytes): """处理从 WebSocket 接收到的消息。""" logger.debug(f"收到 WebSocket 消息: {message}") try: @@ -430,10 +431,13 @@ async def handle_websocket_message(self, message: str): async def convert_message(self, raw_message: dict) -> AstrBotMessage | None: """将 WeChatPadPro 原始消息转换为 AstrBotMessage。""" + if self.wxid is None: + logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。") + return None abm = AstrBotMessage() abm.raw_message = raw_message abm.message_id = str(raw_message.get("msg_id")) - abm.timestamp = raw_message.get("create_time") + abm.timestamp = cast(int, raw_message.get("create_time")) abm.self_id = self.wxid if int(time.time()) - abm.timestamp > 180: @@ -446,7 +450,7 @@ async def convert_message(self, raw_message: dict) -> AstrBotMessage | None: to_user_name = raw_message.get("to_user_name", {}).get("str", "") content = raw_message.get("content", {}).get("str", "") push_content = raw_message.get("push_content", "") - msg_type = raw_message.get("msg_type") + msg_type = cast(int, raw_message.get("msg_type")) abm.message_str = "" abm.message = [] @@ -574,7 +578,7 @@ async def _download_raw_image( from_user_name: str, to_user_name: str, msg_id: int, - ): + ) -> dict | None: """下载原始图片。""" url = f"{self.base_url}/message/GetMsgBigImg" params = {"key": self.auth_key} @@ -725,12 +729,15 @@ async def _process_message_content( # 图片消息 from_user_name = raw_message.get("from_user_name", {}).get("str", "") to_user_name = raw_message.get("to_user_name", {}).get("str", "") - msg_id = raw_message.get("msg_id") + msg_id = cast(int, raw_message.get("msg_id")) image_resp = await self._download_raw_image( from_user_name, to_user_name, msg_id, ) + if image_resp is None: + logger.error(f"下载图片失败: msg_id={msg_id}") + return image_bs64_data = ( image_resp.get("Data", {}).get("Data", {}).get("Buffer", None) ) @@ -771,6 +778,9 @@ async def _process_message_content( bufid = 0 to_user_name = raw_message.get("to_user_name", {}).get("str", "") new_msg_id = raw_message.get("new_msg_id") + if new_msg_id is None: + logger.error("语音消息缺少 new_msg_id") + return data_parser = GeweDataParser( content=content, is_private_chat=(abm.type != MessageType.GROUP_MESSAGE), @@ -778,6 +788,9 @@ async def _process_message_content( ) voicemsg = data_parser._format_to_xml().find("voicemsg") + if voicemsg is None: + logger.error("无法从 XML 解析 voicemsg 节点") + return bufid = voicemsg.get("bufid") or "0" length = int(voicemsg.get("length") or 0) voice_resp = await self.download_voice( @@ -786,6 +799,9 @@ async def _process_message_content( bufid=bufid, length=length, ) + if voice_resp is None: + logger.error(f"下载语音失败: new_msg_id={new_msg_id}") + return voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None) if voice_bs64_data: voice_bs64_data = base64.b64decode(voice_bs64_data) @@ -827,7 +843,8 @@ async def terminate(self): try: if self.ws_handle_task: self.ws_handle_task.cancel() - self._shutdown_event.set() + if self._shutdown_event is not None: + self._shutdown_event.set() except Exception: pass @@ -894,8 +911,8 @@ async def get_contact_list(self): async def get_contact_details_list( self, - room_wx_id_list: list[str] = None, - user_names: list[str] = None, + room_wx_id_list: list[str] | None = None, + user_names: list[str] | None = None, ) -> dict | None: """获取联系人详情列表。""" if room_wx_id_list is None: diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 9bbed276b..8f3d091a4 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -2,7 +2,8 @@ import os import sys import uuid -from typing import Any +from collections.abc import Awaitable, Callable +from typing import Any, cast import quart from requests import Response @@ -40,7 +41,7 @@ class WecomServer: def __init__(self, event_queue: asyncio.Queue, config: dict): self.server = quart.Quart(__name__) - self.port = int(config.get("port")) + self.port = int(cast(str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.server.add_url_rule( "/callback/command", @@ -60,7 +61,7 @@ def __init__(self, event_queue: asyncio.Queue, config: dict): config["corpid"].strip(), ) - self.callback = None + self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None self.shutdown_event = asyncio.Event() async def verify(self): @@ -114,7 +115,7 @@ async def handle_callback(self, request) -> str: logger.error("解密失败,签名异常,请检查配置。") raise else: - msg = parse_message(xml) + msg = cast(BaseMessage, parse_message(xml)) logger.info(f"解析成功: {msg}") if self.callback: @@ -176,10 +177,10 @@ def __init__( # inject self.wechat_kf_api = WeChatKF(client=self.client) self.wechat_kf_message_api = WeChatKFMessage(self.client) - self.client.kf = self.wechat_kf_api - self.client.kf_message = self.wechat_kf_message_api + self.client.__setattr__("kf", self.wechat_kf_api) + self.client.__setattr__("kf_message", self.wechat_kf_message_api) - self.client.API_BASE_URL = self.api_base_url + self.client.__setattr__("API_BASE_URL", self.api_base_url) async def callback(msg: BaseMessage): if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": @@ -278,37 +279,33 @@ async def webhook_callback(self, request: Any) -> Any: async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None: abm = AstrBotMessage() - if msg.type == "text": - assert isinstance(msg, TextMessage) + if isinstance(msg, TextMessage): abm.message_str = msg.content abm.self_id = str(msg.agent) abm.message = [Plain(msg.content)] abm.type = MessageType.FRIEND_MESSAGE abm.sender = MessageMember( - msg.source, - msg.source, + cast(str, msg.source), + cast(str, msg.source), ) - abm.message_id = msg.id - abm.timestamp = msg.time + abm.message_id = str(msg.id) + abm.timestamp = int(cast(int | str, msg.time)) abm.session_id = abm.sender.user_id abm.raw_message = msg - elif msg.type == "image": - assert isinstance(msg, ImageMessage) + elif isinstance(msg, ImageMessage): abm.message_str = "[图片]" abm.self_id = str(msg.agent) abm.message = [Image(file=msg.image, url=msg.image)] abm.type = MessageType.FRIEND_MESSAGE abm.sender = MessageMember( - msg.source, - msg.source, + cast(str, msg.source), + cast(str, msg.source), ) - abm.message_id = msg.id - abm.timestamp = msg.time + abm.message_id = str(msg.id) + abm.timestamp = int(cast(int | str, msg.time)) abm.session_id = abm.sender.user_id abm.raw_message = msg - elif msg.type == "voice": - assert isinstance(msg, VoiceMessage) - + elif isinstance(msg, VoiceMessage): resp: Response = await asyncio.get_event_loop().run_in_executor( None, self.client.media.download, @@ -335,11 +332,11 @@ async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None: abm.message = [Record(file=path_wav, url=path_wav)] abm.type = MessageType.FRIEND_MESSAGE abm.sender = MessageMember( - msg.source, - msg.source, + cast(str, msg.source), + cast(str, msg.source), ) - abm.message_id = msg.id - abm.timestamp = msg.time + abm.message_id = str(msg.id) + abm.timestamp = int(cast(int | str, msg.time)) abm.session_id = abm.sender.user_id abm.raw_message = msg else: @@ -351,7 +348,7 @@ async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None: async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: msgtype = msg.get("msgtype") - external_userid = msg.get("external_userid") + external_userid = cast(str, msg.get("external_userid")) abm = AstrBotMessage() abm.raw_message = msg abm.raw_message["_wechat_kf_flag"] = None # 方便处理 diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index c566d3f0e..0b5dae272 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -93,10 +93,10 @@ async def send(self, message: MessageChain): if is_wechat_kf: # 微信客服 kf_message_api = getattr(self.client, "kf_message", None) - if not kf_message_api: + if not isinstance(kf_message_api, WeChatKFMessage): logger.warning("未找到微信客服发送消息方法。") return - assert isinstance(kf_message_api, WeChatKFMessage) + user_id = self.get_sender_id() for comp in message.chain: if isinstance(comp, Plain): diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index 0091783a4..fd11d7ceb 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -39,7 +39,7 @@ def __init__( @staticmethod async def _send( - message_chain: MessageChain, + message_chain: MessageChain | None, stream_id: str, queue_mgr: WecomAIQueueMgr, streaming: bool = False, @@ -90,7 +90,7 @@ async def _send( return data - async def send(self, message: MessageChain): + async def send(self, message: MessageChain | None): """发送消息""" raw = self.message_obj.raw_message assert isinstance(raw, dict), ( @@ -98,7 +98,7 @@ async def send(self, message: MessageChain): ) stream_id = raw.get("stream_id", self.session_id) await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr) - await super().send(message) + await super().send(MessageChain([])) async def send_streaming(self, generator, use_fallback=False): """流式发送消息,参考webchat的send_streaming设计""" diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index c84e2865b..d0304a48e 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -1,7 +1,8 @@ import asyncio import sys import uuid -from typing import Any +from collections.abc import Awaitable, Callable +from typing import Any, cast import quart from requests import Response @@ -36,7 +37,7 @@ class WeixinOfficialAccountServer: def __init__(self, event_queue: asyncio.Queue, config: dict): self.server = quart.Quart(__name__) - self.port = int(config.get("port")) + self.port = int(cast(int | str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.token = config.get("token") self.encoding_aes_key = config.get("encoding_aes_key") @@ -55,7 +56,7 @@ def __init__(self, event_queue: asyncio.Queue, config: dict): self.event_queue = event_queue - self.callback = None + self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None self.shutdown_event = asyncio.Event() async def verify(self): @@ -114,6 +115,9 @@ async def handle_callback(self, request) -> str: raise else: msg = parse_message(xml) + if not msg: + logger.error("解析失败。msg为None。") + raise logger.info(f"解析成功: {msg}") if self.callback: @@ -176,7 +180,7 @@ def __init__( self.config["secret"].strip(), ) - self.client.API_BASE_URL = self.api_base_url + self.client.__setattr__("API_BASE_URL", self.api_base_url) # 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重 # msgid -> Future @@ -188,11 +192,11 @@ async def callback(msg: BaseMessage): await self.convert_message(msg, None) else: if msg.id in self.wexin_event_workers: - future = self.wexin_event_workers[msg.id] + future = self.wexin_event_workers[str(cast(str | int, msg.id))] logger.debug(f"duplicate message id checked: {msg.id}") else: future = asyncio.get_event_loop().create_future() - self.wexin_event_workers[msg.id] = future + self.wexin_event_workers[str(cast(str | int, msg.id))] = future await self.convert_message(msg, future) # I love shield so much! result = await asyncio.wait_for( @@ -200,7 +204,7 @@ async def callback(msg: BaseMessage): 60, ) # wait for 60s logger.debug(f"Got future result: {result}") - self.wexin_event_workers.pop(msg.id, None) + self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None) return result # xml. see weixin_offacc_event.py except asyncio.TimeoutError: pass @@ -248,33 +252,33 @@ async def webhook_callback(self, request: Any) -> Any: async def convert_message( self, msg, - future: asyncio.Future = None, + future: asyncio.Future | None = None, ) -> AstrBotMessage | None: abm = AstrBotMessage() if isinstance(msg, TextMessage): - abm.message_str = msg.content + abm.message_str = cast(str, msg.content) abm.self_id = str(msg.target) - abm.message = [Plain(msg.content)] + abm.message = [Plain(cast(str, msg.content))] abm.type = MessageType.FRIEND_MESSAGE abm.sender = MessageMember( - msg.source, - msg.source, + cast(str, msg.source), + cast(str, msg.source), ) - abm.message_id = msg.id - abm.timestamp = msg.time + abm.message_id = str(cast(str | int, msg.id)) + abm.timestamp = cast(int, msg.time) abm.session_id = abm.sender.user_id elif msg.type == "image": assert isinstance(msg, ImageMessage) abm.message_str = "[图片]" abm.self_id = str(msg.target) - abm.message = [Image(file=msg.image, url=msg.image)] + abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))] abm.type = MessageType.FRIEND_MESSAGE abm.sender = MessageMember( - msg.source, - msg.source, + cast(str, msg.source), + cast(str, msg.source), ) - abm.message_id = msg.id - abm.timestamp = msg.time + abm.message_id = str(cast(str | int, msg.id)) + abm.timestamp = cast(int, msg.time) abm.session_id = abm.sender.user_id elif msg.type == "voice": assert isinstance(msg, VoiceMessage) @@ -306,15 +310,16 @@ async def convert_message( abm.message = [Record(file=path_wav, url=path_wav)] abm.type = MessageType.FRIEND_MESSAGE abm.sender = MessageMember( - msg.source, - msg.source, + cast(str, msg.source), + cast(str, msg.source), ) - abm.message_id = msg.id - abm.timestamp = msg.time + abm.message_id = str(cast(str | int, msg.id)) + abm.timestamp = cast(int, msg.time) abm.session_id = abm.sender.user_id else: logger.warning(f"暂未实现的事件: {msg.type}") - future.set_result(None) + if future: + future.set_result(None) return # 很不优雅 :( abm.raw_message = { diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index 1974c91a9..c1f137a41 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -1,5 +1,6 @@ import asyncio import uuid +from typing import cast from wechatpy import WeChatClient from wechatpy.replies import ImageReply, TextReply, VoiceReply @@ -85,7 +86,9 @@ async def split_plain(self, plain: str) -> list[str]: async def send(self, message: MessageChain): message_obj = self.message_obj - active_send_mode = message_obj.raw_message.get("active_send_mode", False) + active_send_mode = cast(dict, message_obj.raw_message).get( + "active_send_mode", False + ) for comp in message.chain: if isinstance(comp, Plain): # Split long text messages if needed @@ -96,10 +99,10 @@ async def send(self, message: MessageChain): else: reply = TextReply( content=chunk, - message=self.message_obj.raw_message["message"], + message=cast(dict, self.message_obj.raw_message)["message"], ) xml = reply.render() - future = self.message_obj.raw_message["future"] + future = cast(dict, self.message_obj.raw_message)["future"] assert isinstance(future, asyncio.Future) future.set_result(xml) await asyncio.sleep(0.5) # Avoid sending too fast @@ -125,10 +128,10 @@ async def send(self, message: MessageChain): else: reply = ImageReply( media_id=response["media_id"], - message=self.message_obj.raw_message["message"], + message=cast(dict, self.message_obj.raw_message)["message"], ) xml = reply.render() - future = self.message_obj.raw_message["future"] + future = cast(dict, self.message_obj.raw_message)["future"] assert isinstance(future, asyncio.Future) future.set_result(xml) @@ -160,10 +163,10 @@ async def send(self, message: MessageChain): else: reply = VoiceReply( media_id=response["media_id"], - message=self.message_obj.raw_message["message"], + message=cast(dict, self.message_obj.raw_message)["message"], ) xml = reply.render() - future = self.message_obj.raw_message["future"] + future = cast(dict, self.message_obj.raw_message)["future"] assert isinstance(future, asyncio.Future) future.set_result(xml) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 8e04423ed..7aad86bdd 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -4,7 +4,7 @@ import copy import json import os -from collections.abc import Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any import aiohttp @@ -118,7 +118,7 @@ def spec_to_func( name: str, func_args: list[dict], desc: str, - handler: Callable[..., Awaitable[Any]], + handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]], ) -> FuncTool: params = { "type": "object", # hard-coded here @@ -140,7 +140,7 @@ def add_func( name: str, func_args: list, desc: str, - handler: Callable[..., Awaitable[Any]], + handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]], ) -> None: """添加函数调用工具 diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 3e477255a..be8edc282 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,5 +1,6 @@ import asyncio import traceback +from typing import Protocol, runtime_checkable from astrbot.core import astrbot_config, logger, sp from astrbot.core.astrbot_config_mgr import AstrBotConfigManager @@ -10,6 +11,7 @@ from .provider import ( EmbeddingProvider, Provider, + Providers, RerankProvider, STTProvider, TTSProvider, @@ -17,6 +19,11 @@ from .register import llm_tools, provider_cls_map +@runtime_checkable +class HasInitialize(Protocol): + async def initialize(self) -> None: ... + + class ProviderManager: def __init__( self, @@ -48,7 +55,7 @@ def __init__( """加载的 Rerank Provider 的实例""" self.inst_map: dict[ str, - Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider, + Providers, ] = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools @@ -123,15 +130,13 @@ async def set_provider( self.curr_provider_inst = prov sp.put("curr_provider", provider_id, scope="global", scope_id="global") - async def get_provider_by_id(self, provider_id: str) -> Provider | None: + async def get_provider_by_id(self, provider_id: str) -> Providers | None: """根据提供商 ID 获取提供商实例""" return self.inst_map.get(provider_id) def get_using_provider( - self, - provider_type: ProviderType, - umo=None, - ) -> Provider | STTProvider | TTSProvider | None: + self, provider_type: ProviderType, umo=None + ) -> Providers | None: """获取正在使用的提供商实例。 Args: @@ -191,7 +196,6 @@ async def initialize(self): logger.error(traceback.format_exc()) logger.error(e) - # 设置默认提供商 selected_provider_id = sp.get( "curr_provider", self.provider_settings.get("default_provider_id"), @@ -210,15 +214,37 @@ async def initialize(self): scope="global", scope_id="global", ) - self.curr_provider_inst = self.inst_map.get(selected_provider_id) + + temp_provider = ( + self.inst_map.get(selected_provider_id) + if isinstance(selected_provider_id, str) + else None + ) + self.curr_provider_inst = ( + temp_provider if isinstance(temp_provider, Provider) else None + ) if not self.curr_provider_inst and self.provider_insts: self.curr_provider_inst = self.provider_insts[0] - self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id) + temp_stt = ( + self.inst_map.get(selected_stt_provider_id) + if isinstance(selected_stt_provider_id, str) + else None + ) + self.curr_stt_provider_inst = ( + temp_stt if isinstance(temp_stt, STTProvider) else None + ) if not self.curr_stt_provider_inst and self.stt_provider_insts: self.curr_stt_provider_inst = self.stt_provider_insts[0] - self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id) + temp_tts = ( + self.inst_map.get(selected_tts_provider_id) + if isinstance(selected_tts_provider_id, str) + else None + ) + self.curr_tts_provider_inst = ( + temp_tts if isinstance(temp_tts, TTSProvider) else None + ) if not self.curr_tts_provider_inst and self.tts_provider_insts: self.curr_tts_provider_inst = self.tts_provider_insts[0] @@ -358,73 +384,103 @@ async def load_provider(self, provider_config: dict): provider_metadata.id = provider_config["id"] - if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: - # STT 任务 - inst = cls_type(provider_config, self.provider_settings) - - if getattr(inst, "initialize", None): - await inst.initialize() - - self.stt_provider_insts.append(inst) - if ( - self.provider_stt_settings.get("provider_id") - == provider_config["id"] - ): - self.curr_stt_provider_inst = inst - logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。", - ) - if not self.curr_stt_provider_inst: - self.curr_stt_provider_inst = inst - - elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: - # TTS 任务 - inst = cls_type(provider_config, self.provider_settings) - - if getattr(inst, "initialize", None): - await inst.initialize() - - self.tts_provider_insts.append(inst) - if self.provider_settings.get("provider_id") == provider_config["id"]: - self.curr_tts_provider_inst = inst - logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。", + match provider_metadata.provider_type: + case ProviderType.SPEECH_TO_TEXT: + # STT 任务 + if not issubclass(cls_type, STTProvider): + raise TypeError( + f"Provider class {cls_type} is not a subclass of STTProvider" + ) + inst = cls_type(provider_config, self.provider_settings) + + if isinstance(inst, HasInitialize): + await inst.initialize() + + self.stt_provider_insts.append(inst) + if ( + self.provider_stt_settings.get("provider_id") + == provider_config["id"] + ): + self.curr_stt_provider_inst = inst + logger.info( + f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。", + ) + if not self.curr_stt_provider_inst: + self.curr_stt_provider_inst = inst + + case ProviderType.TEXT_TO_SPEECH: + # TTS 任务 + if not issubclass(cls_type, TTSProvider): + raise TypeError( + f"Provider class {cls_type} is not a subclass of TTSProvider" + ) + inst = cls_type(provider_config, self.provider_settings) + + if isinstance(inst, HasInitialize): + await inst.initialize() + + self.tts_provider_insts.append(inst) + if ( + self.provider_settings.get("provider_id") + == provider_config["id"] + ): + self.curr_tts_provider_inst = inst + logger.info( + f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。", + ) + if not self.curr_tts_provider_inst: + self.curr_tts_provider_inst = inst + + case ProviderType.CHAT_COMPLETION: + # 文本生成任务 + if not issubclass(cls_type, Provider): + raise TypeError( + f"Provider class {cls_type} is not a subclass of Provider" + ) + inst = cls_type( + provider_config, + self.provider_settings, ) - if not self.curr_tts_provider_inst: - self.curr_tts_provider_inst = inst - - elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: - # 文本生成任务 - inst = cls_type( - provider_config, - self.provider_settings, - ) - if getattr(inst, "initialize", None): - await inst.initialize() - - self.provider_insts.append(inst) - if ( - self.provider_settings.get("default_provider_id") - == provider_config["id"] - ): - self.curr_provider_inst = inst - logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。", + if isinstance(inst, HasInitialize): + await inst.initialize() + + self.provider_insts.append(inst) + if ( + self.provider_settings.get("default_provider_id") + == provider_config["id"] + ): + self.curr_provider_inst = inst + logger.info( + f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。", + ) + if not self.curr_provider_inst: + self.curr_provider_inst = inst + + case ProviderType.EMBEDDING: + if not issubclass(cls_type, EmbeddingProvider): + raise TypeError( + f"Provider class {cls_type} is not a subclass of EmbeddingProvider" + ) + inst = cls_type(provider_config, self.provider_settings) + if isinstance(inst, HasInitialize): + await inst.initialize() + self.embedding_provider_insts.append(inst) + case ProviderType.RERANK: + if not issubclass(cls_type, RerankProvider): + raise TypeError( + f"Provider class {cls_type} is not a subclass of RerankProvider" + ) + inst = cls_type(provider_config, self.provider_settings) + if isinstance(inst, HasInitialize): + await inst.initialize() + self.rerank_provider_insts.append(inst) + case _: + # 未知供应商抛出异常,确保inst初始化 + # Should be unreachable + raise Exception( + f"未知的提供商类型:{provider_metadata.provider_type}" ) - if not self.curr_provider_inst: - self.curr_provider_inst = inst - - elif provider_metadata.provider_type == ProviderType.EMBEDDING: - inst = cls_type(provider_config, self.provider_settings) - if getattr(inst, "initialize", None): - await inst.initialize() - self.embedding_provider_insts.append(inst) - elif provider_metadata.provider_type == ProviderType.RERANK: - inst = cls_type(provider_config, self.provider_settings) - if getattr(inst, "initialize", None): - await inst.initialize() - self.rerank_provider_insts.append(inst) self.inst_map[provider_config["id"]] = inst except Exception as e: diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 2b5057e85..7f21a2ee1 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -2,6 +2,7 @@ import asyncio import os from collections.abc import AsyncGenerator +from typing import TypeAlias, Union from astrbot.core.agent.message import Message from astrbot.core.agent.tool import ToolSet @@ -14,6 +15,14 @@ from astrbot.core.provider.register import provider_cls_map from astrbot.core.utils.astrbot_path import get_astrbot_path +Providers: TypeAlias = Union[ + "Provider", + "STTProvider", + "TTSProvider", + "EmbeddingProvider", + "RerankProvider", +] + class AbstractProvider(abc.ABC): """Provider Abstract Class""" @@ -142,7 +151,9 @@ async def text_chat_stream( - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 """ - ... + if False: # pragma: no cover - make this an async generator for typing + yield None # type: ignore + raise NotImplementedError() async def pop_record(self, context: list): """弹出 context 第一条非系统提示词对话记录""" diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py index e85d91793..2ccf146ca 100644 --- a/astrbot/core/provider/sources/azure_tts_source.py +++ b/astrbot/core/provider/sources/azure_tts_source.py @@ -29,15 +29,24 @@ def __init__(self, config: dict): self.last_sync_time = 0 self.timeout = Timeout(10.0) self.retry_count = 3 - self.client = None + self._client: AsyncClient | None = None + + @property + def client(self) -> AsyncClient: + if self._client is None: + raise RuntimeError( + "Client not initialized. Please use 'async with' context." + ) + return self._client async def __aenter__(self): - self.client = AsyncClient(timeout=self.timeout) + self._client = AsyncClient(timeout=self.timeout) return self async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.client: - await self.client.aclose() + if self._client: + await self._client.aclose() + self._client = None async def _sync_time(self): try: @@ -90,6 +99,7 @@ async def get_audio(self, text: str, voice_params: dict) -> str: if attempt == self.retry_count - 1: raise RuntimeError(f"OTTS请求失败: {e!s}") from e await asyncio.sleep(0.5 * (attempt + 1)) + raise RuntimeError("OTTS未返回音频文件") class AzureNativeProvider(TTSProvider): @@ -105,7 +115,7 @@ def __init__(self, provider_config: dict, provider_settings: dict): self.endpoint = ( f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1" ) - self.client = None + self._client: AsyncClient | None = None self.token = None self.token_expire = 0 self.voice_params = { @@ -116,8 +126,16 @@ def __init__(self, provider_config: dict, provider_settings: dict): "volume": provider_config.get("azure_tts_volume", "100"), } + @property + def client(self) -> AsyncClient: + if self._client is None: + raise RuntimeError( + "Client not initialized. Please use 'async with' context." + ) + return self._client + async def __aenter__(self): - self.client = AsyncClient( + self._client = AsyncClient( headers={ "User-Agent": f"AstrBot/{VERSION}", "Content-Type": "application/ssml+xml", @@ -127,8 +145,9 @@ async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.client: - await self.client.aclose() + if self._client: + await self._client.aclose() + self._client = None async def _refresh_token(self): token_url = ( @@ -181,8 +200,11 @@ def __init__(self, provider_config: dict, provider_settings: dict): key_value = provider_config.get("azure_tts_subscription_key", "") self.provider = self._parse_provider(key_value, provider_config) - def _parse_provider(self, key_value: str, config: dict) -> TTSProvider: + def _parse_provider( + self, key_value: str, config: dict + ) -> OTTSProvider | AzureNativeProvider: if key_value.lower().startswith("other["): + json_str = "" try: match = re.match(r"other\[(.*)\]", key_value, re.DOTALL) if not match: diff --git a/astrbot/core/provider/sources/bailian_rerank_source.py b/astrbot/core/provider/sources/bailian_rerank_source.py index e6f6f1a4d..9e079d4a9 100644 --- a/astrbot/core/provider/sources/bailian_rerank_source.py +++ b/astrbot/core/provider/sources/bailian_rerank_source.py @@ -177,6 +177,10 @@ async def rerank( Returns: 重排序结果列表 """ + if not self.client: + logger.error("百炼 Rerank 客户端会话已关闭,返回空结果") + return [] + if not documents: logger.warning("文档列表为空,返回空结果") return [] diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py index 44e9965cc..50bc421fd 100644 --- a/astrbot/core/provider/sources/dashscope_tts.py +++ b/astrbot/core/provider/sources/dashscope_tts.py @@ -36,7 +36,7 @@ def __init__( super().__init__(provider_config, provider_settings) self.chosen_api_key: str = provider_config.get("api_key", "") self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella") - self.set_model(provider_config.get("model")) + self.set_model(provider_config["model"]) self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000 dashscope.api_key = self.chosen_api_key @@ -71,9 +71,10 @@ def _call_qwen_tts(self, model: str, text: str): kwargs = { "model": model, - "text": text, + "messages": None, "api_key": self.chosen_api_key, "voice": self.voice or "Cherry", + "text": text, } if not self.voice: logging.warning( diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index 8bbf62325..71a5a82d6 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -67,7 +67,7 @@ async def get_audio(self, text: str) -> str: from pyffmpeg import FFmpeg ff = FFmpeg() - ff.convert(input=mp3_path, output=wav_path) + ff.convert(input_file=mp3_path, output_file=wav_path) except Exception as e: logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") # use ffmpeg command line diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py index ca571c3ee..8362ce1b4 100644 --- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py +++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py @@ -59,9 +59,9 @@ def __init__( self.headers = { "Authorization": f"Bearer {self.chosen_api_key}", } - self.set_model(provider_config.get("model")) + self.set_model(provider_config["model"]) - async def _get_reference_id_by_character(self, character: str) -> str: + async def _get_reference_id_by_character(self, character: str) -> str | None: """获取角色的reference_id Args: @@ -109,7 +109,7 @@ def _validate_reference_id(self, reference_id: str) -> bool: pattern = r"^[a-fA-F0-9]{32}$" return bool(re.match(pattern, reference_id.strip())) - async def _generate_request(self, text: str) -> dict: + async def _generate_request(self, text: str) -> ServeTTSRequest: # 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询 if self.reference_id and self.reference_id.strip(): # 验证reference_id格式 @@ -146,5 +146,6 @@ async def get_audio(self, text: str) -> str: async for chunk in response.aiter_bytes(): f.write(chunk) return path - text = await response.aread() + body = await response.aread() + text = body.decode("utf-8", errors="replace") raise Exception(f"Fish Audio API请求失败: {text}") diff --git a/astrbot/core/provider/sources/gemini_embedding_source.py b/astrbot/core/provider/sources/gemini_embedding_source.py index 8d11cce5f..146b50a4e 100644 --- a/astrbot/core/provider/sources/gemini_embedding_source.py +++ b/astrbot/core/provider/sources/gemini_embedding_source.py @@ -1,3 +1,5 @@ +from typing import cast + from google import genai from google.genai import types from google.genai.errors import APIError @@ -18,8 +20,8 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: self.provider_config = provider_config self.provider_settings = provider_settings - api_key: str = provider_config.get("embedding_api_key") - api_base: str = provider_config.get("embedding_api_base") + api_key: str = provider_config["embedding_api_key"] + api_base: str = provider_config["embedding_api_base"] timeout: int = int(provider_config.get("timeout", 20)) http_options = types.HttpOptions(timeout=timeout * 1000) @@ -41,18 +43,26 @@ async def get_embedding(self, text: str) -> list[float]: model=self.model, contents=text, ) + assert result.embeddings is not None + assert result.embeddings[0].values is not None return result.embeddings[0].values except APIError as e: raise Exception(f"Gemini Embedding API请求失败: {e.message}") - async def get_embeddings(self, texts: list[str]) -> list[list[float]]: + async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的嵌入""" try: result = await self.client.models.embed_content( model=self.model, - contents=texts, + contents=cast(types.ContentListUnion, text), ) - return [embedding.values for embedding in result.embeddings] + assert result.embeddings is not None + + embeddings: list[list[float]] = [] + for embedding in result.embeddings: + assert embedding.values is not None + embeddings.append(embedding.values) + return embeddings except APIError as e: raise Exception(f"Gemini Embedding API批量请求失败: {e.message}") diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 3bc6c67cc..2d171cfd8 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -4,6 +4,7 @@ import logging import random from collections.abc import AsyncGenerator +from typing import cast from google import genai from google.genai import types @@ -136,7 +137,7 @@ async def _prepare_query_config( logger.warning("流式输出不支持图片模态,已自动降级为文本模态") modalities = ["Text"] - tool_list = [] + tool_list: list[types.Tool] | None = [] model_name = self.get_model() native_coderunner = self.provider_config.get("gm_native_coderunner", False) native_search = self.provider_config.get("gm_native_search", False) @@ -213,7 +214,7 @@ async def _prepare_query_config( logprobs=payloads.get("logprobs"), seed=payloads.get("seed"), response_modalities=modalities, - tools=tool_list, + tools=cast(types.ToolListUnion | None, tool_list), safety_settings=self.safety_settings if self.safety_settings else None, thinking_config=( types.ThinkingConfig( @@ -257,6 +258,7 @@ def append_or_extend( content_cls: type[types.Content], ) -> None: if contents and isinstance(contents[-1], content_cls): + assert contents[-1].parts is not None contents[-1].parts.extend(part) else: contents.append(content_cls(parts=part)) @@ -448,7 +450,7 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: ) result = await self.client.models.generate_content( model=self.get_model(), - contents=conversation, + contents=cast(types.ContentListUnion, conversation), config=config, ) logger.debug(f"genai result: {result}") @@ -524,7 +526,7 @@ async def _query_stream( ) result = await self.client.models.generate_content_stream( model=self.get_model(), - contents=conversation, + contents=cast(types.ContentListUnion, conversation), config=config, ) break diff --git a/astrbot/core/provider/sources/minimax_tts_api_source.py b/astrbot/core/provider/sources/minimax_tts_api_source.py index 5ffc7cc63..9e2d665c7 100644 --- a/astrbot/core/provider/sources/minimax_tts_api_source.py +++ b/astrbot/core/provider/sources/minimax_tts_api_source.py @@ -87,7 +87,7 @@ def _build_tts_stream_body(self, text: str): return json.dumps(dict_body) - async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]: + async def _call_tts_stream(self, text: str) -> AsyncIterator[str]: """进行流式请求""" try: async with ( @@ -117,7 +117,9 @@ async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]: data = json.loads(message[6:]) if "extra_info" in data: continue - audio = data.get("data", {}).get("audio") + audio: str | None = data.get("data", {}).get( + "audio" + ) if audio is not None: yield audio except json.JSONDecodeError: diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index 368e610ec..c9e03d7af 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -30,9 +30,9 @@ async def get_embedding(self, text: str) -> list[float]: embedding = await self.client.embeddings.create(input=text, model=self.model) return embedding.data[0].embedding - async def get_embeddings(self, texts: list[str]) -> list[list[float]]: + async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的嵌入""" - embeddings = await self.client.embeddings.create(input=texts, model=self.model) + embeddings = await self.client.embeddings.create(input=text, model=self.model) return [item.embedding for item in embeddings.data] def get_dim(self) -> int: diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index cce3f01c9..788b649a9 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -284,6 +284,10 @@ async def _parse_openai_completion( if isinstance(tool_call, str): # workaround for #1359 tool_call = json.loads(tool_call) + if tools is None: + # 工具集未提供 + # Should be unreachable + raise Exception("工具集未提供") for tool in tools.func_list: if ( tool_call.type == "function" diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index 67947c685..a41bd72fd 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -7,6 +7,7 @@ import os import re from datetime import datetime +from typing import cast from funasr_onnx import SenseVoiceSmall from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess @@ -32,7 +33,7 @@ def __init__( provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) - self.set_model(provider_config.get("stt_model")) + self.set_model(provider_config["stt_model"]) self.model = None self.is_emotion = provider_config.get("is_emotion", False) @@ -86,7 +87,9 @@ async def get_text(self, audio_url: str) -> str: loop = asyncio.get_event_loop() res = await loop.run_in_executor( None, # 使用默认的线程池 - lambda: self.model(audio_url, language="auto", use_itn=True), + lambda: cast(SenseVoiceSmall, self.model)( + audio_url, language="auto", use_itn=True + ), ) # res = self.model(audio_url, language="auto", use_itn=True) diff --git a/astrbot/core/provider/sources/vllm_rerank_source.py b/astrbot/core/provider/sources/vllm_rerank_source.py index 3e6f3d33c..edd8a5491 100644 --- a/astrbot/core/provider/sources/vllm_rerank_source.py +++ b/astrbot/core/provider/sources/vllm_rerank_source.py @@ -44,6 +44,7 @@ async def rerank( } if top_n is not None: payload["top_n"] = top_n + assert self.client is not None async with self.client.post( f"{self.base_url}/v1/rerank", json=payload, diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index 55a232498..fa69206ef 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -36,7 +36,7 @@ def __init__( timeout=provider_config.get("timeout", NOT_GIVEN), ) - self.set_model(provider_config.get("model")) + self.set_model(provider_config["model"]) async def _get_audio_format(self, file_path): # 定义要检测的头部字节 diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index fbdc7d626..a14f93f14 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -1,6 +1,7 @@ import asyncio import os import uuid +from typing import cast import whisper @@ -26,7 +27,7 @@ def __init__( provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) - self.set_model(provider_config.get("model")) + self.set_model(provider_config["model"]) self.model = None async def initialize(self): @@ -75,5 +76,8 @@ async def get_text(self, audio_url: str) -> str: await tencent_silk_to_wav(audio_url, output_path) audio_url = output_path + if not self.model: + raise RuntimeError("Whisper 模型未初始化") + result = await loop.run_in_executor(None, self.model.transcribe, audio_url) - return result["text"] + return cast(str, result["text"]) diff --git a/astrbot/core/provider/sources/xinference_rerank_source.py b/astrbot/core/provider/sources/xinference_rerank_source.py index 29f3ab095..960408550 100644 --- a/astrbot/core/provider/sources/xinference_rerank_source.py +++ b/astrbot/core/provider/sources/xinference_rerank_source.py @@ -1,6 +1,11 @@ +from typing import cast + from xinference_client.client.restful.async_restful_client import ( AsyncClient as Client, ) +from xinference_client.client.restful.async_restful_client import ( + AsyncRESTfulRerankModelHandle, +) from astrbot import logger @@ -29,7 +34,7 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: False, ) self.client = None - self.model = None + self.model: AsyncRESTfulRerankModelHandle | None = None self.model_uid = None async def initialize(self): @@ -65,7 +70,10 @@ async def initialize(self): return if self.model_uid: - self.model = await self.client.get_model(self.model_uid) + self.model = cast( + AsyncRESTfulRerankModelHandle, + await self.client.get_model(self.model_uid), + ) except Exception as e: logger.error(f"Failed to initialize Xinference model: {e}") diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 21c1ad8fd..9a52ec8bc 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -285,7 +285,7 @@ def get_all_embedding_providers(self) -> list[EmbeddingProvider]: """获取所有用于 Embedding 任务的 Provider。""" return self.provider_manager.embedding_provider_insts - def get_using_provider(self, umo: str | None = None) -> Provider | None: + def get_using_provider(self, umo: str | None = None) -> Provider: """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 Args: @@ -296,7 +296,7 @@ def get_using_provider(self, umo: str | None = None) -> Provider | None: provider_type=ProviderType.CHAT_COMPLETION, umo=umo, ) - if prov and not isinstance(prov, Provider): + if not isinstance(prov, Provider): raise ValueError("返回的 Provider 不是 Provider 类型") return prov diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index ee3c09680..daf36a8f6 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from collections.abc import Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any import docstring_parser @@ -12,6 +12,7 @@ from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.tool import FunctionTool from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.message.message_event_result import MessageEventResult from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES from astrbot.core.provider.register import llm_tools @@ -28,13 +29,19 @@ from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry -def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str: +def get_handler_full_name( + awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]], +) -> str: """获取 Handler 的全名""" return f"{awaitable.__module__}_{awaitable.__name__}" def get_handler_or_create( - handler: Callable[..., Awaitable[Any]], + handler: Callable[ + ..., + Awaitable[MessageEventResult | str | None] + | AsyncGenerator[MessageEventResult | str | None], + ], event_type: EventType, dont_add=False, **kwargs, @@ -169,6 +176,8 @@ def decorator(awaitable): for ( sub_handle ) in parent_register_commandable.parent_group.sub_command_filters: + if isinstance(sub_handle, CommandGroupFilter): + continue # 所有符合fullname一致的子指令handle添加自定义过滤器。 # 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器? sub_handle_md = sub_handle.get_handler_md() @@ -180,6 +189,8 @@ def decorator(awaitable): else: # 裸指令 + # 确保运行时是可调用的 handler,针对类型检查器添加忽略 + assert isinstance(awaitable, Callable) handler_md = get_handler_or_create( awaitable, EventType.AdapterMessageEvent, @@ -237,7 +248,7 @@ class RegisteringCommandable: group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group command: Callable[..., Callable[..., None]] = register_command - custom_filter: Callable[..., Callable[..., None]] = register_custom_filter + custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter def __init__(self, parent_group: CommandGroupFilter): self.parent_group = parent_group @@ -412,7 +423,13 @@ async def get_weather(event: AstrMessageEvent, location: str): if kwargs.get("registering_agent"): registering_agent = kwargs["registering_agent"] - def decorator(awaitable: Callable[..., Awaitable[Any]]): + def decorator( + awaitable: Callable[ + ..., + AsyncGenerator[MessageEventResult | str | None] + | Awaitable[MessageEventResult | str | None], + ], + ): llm_tool_name = name_ if name_ else awaitable.__name__ func_doc = awaitable.__doc__ or "" docstring = docstring_parser.parse(func_doc) diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 141f9180a..da59cd291 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -1,9 +1,9 @@ from __future__ import annotations import enum -from collections.abc import Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass, field -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Literal, TypeVar, overload from .filter import HandlerFilter from .star import star_map @@ -29,6 +29,84 @@ def _print_handlers(self): for handler in self._handlers: print(handler.handler_full_name) + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnAstrBotLoadedEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnPlatformLoadedEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.AdapterMessageEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[ + StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]] + ]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnLLMRequestEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnLLMResponseEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnDecoratingResultEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnCallingFuncToolEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[ + StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]] + ]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnAfterMessageSentEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: EventType, + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[ + StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]] + ]: ... + def get_handlers_by_event_type( self, event_type: EventType, @@ -111,8 +189,11 @@ class EventType(enum.Enum): OnAfterMessageSentEvent = enum.auto() # 发送消息后 +H = TypeVar("H", bound=Callable[..., Any]) + + @dataclass -class StarHandlerMetadata: +class StarHandlerMetadata(Generic[H]): """描述一个 Star 所注册的某一个 Handler。""" event_type: EventType @@ -127,7 +208,7 @@ class StarHandlerMetadata: handler_module_path: str """Handler 所在的模块路径。""" - handler: Callable[..., Awaitable[Any]] + handler: H """Handler 的函数对象,应当是一个异步函数""" event_filters: list[HandlerFilter] diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index d13bab687..0a7116a0d 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -71,10 +71,10 @@ def _reboot(self, delay: int = 3): async def check_update( self, - url: str, - current_version: str, + url: str | None, + current_version: str | None, consider_prerelease: bool = True, - ) -> ReleaseInfo: + ) -> ReleaseInfo | None: """检查更新""" return await super().check_update( self.ASTRBOT_RELEASE_API, diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 073c04938..fcf5bb3c7 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -49,7 +49,7 @@ def port_checker(port: int, host: str = "localhost"): return False -def save_temp_img(img: Image.Image | str) -> str: +def save_temp_img(img: Image.Image | bytes) -> str: temp_dir = os.path.join(get_astrbot_data_path(), "temp") # 获得文件创建时间,清除超过 12 小时的 try: diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index 33b7cb17a..e1f2fbef7 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -20,16 +20,16 @@ class SessionController: def __init__(self): self.future = asyncio.Future() - self.current_event: asyncio.Event = None + self.current_event: asyncio.Event | None = None """当前正在等待的所用的异步事件""" - self.ts: float = None + self.ts: float | None = None """上次保持(keep)开始时的时间""" - self.timeout: float | int = None + self.timeout: float | int | None = None """上次保持(keep)开始时的超时时间""" self.history_chains: list[list[Comp.BaseMessageComponent]] = [] - def stop(self, error: Exception = None): + def stop(self, error: Exception | None = None): """立即结束这个会话""" if not self.future.done(): if error: @@ -53,6 +53,8 @@ def keep(self, timeout: float = 0, reset_timeout=False): self.stop() return else: + assert self.timeout is not None + assert self.ts is not None left_timeout = self.timeout - (new_ts - self.ts) timeout = left_timeout + timeout if timeout <= 0: @@ -69,7 +71,7 @@ def keep(self, timeout: float = 0, reset_timeout=False): asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep - async def _holding(self, event: asyncio.Event, timeout: int): + async def _holding(self, event: asyncio.Event, timeout: float): """等待事件结束或超时""" try: await asyncio.wait_for(event.wait(), timeout) @@ -108,7 +110,9 @@ def __init__( ): self.session_id = session_id self.session_filter = session_filter - self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数 + self.handler: ( + Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None + ) = None # 处理函数 self.session_controller = SessionController() self.record_history_chains = record_history_chains @@ -119,7 +123,7 @@ def __init__( async def register_wait( self, - handler: Callable[[str], Awaitable[Any]], + handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]], timeout: int = 30, ) -> Any: """等待外部输入并处理""" @@ -137,7 +141,7 @@ async def register_wait( finally: self._cleanup() - def _cleanup(self, error: Exception = None): + def _cleanup(self, error: Exception | None = None): """清理会话""" USER_SESSIONS.pop(self.session_id, None) try: @@ -161,6 +165,7 @@ async def trigger(cls, session_id: str, event: AstrMessageEvent): ) try: # TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行 + assert session.handler is not None await session.handler(session.session_controller, event) except Exception as e: session.session_controller.stop(e) @@ -173,11 +178,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False): :param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。 """ - def decorator(func: Callable[[str], Awaitable[Any]]): + def decorator( + func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]], + ): @functools.wraps(func) async def wrapper( event: AstrMessageEvent, - session_filter: SessionFilter = None, + session_filter: SessionFilter | None = None, *args, **kwargs, ): diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index 6b1f52a69..ccd394ee4 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -53,6 +53,38 @@ async def range_get_async( ret = await self.db_helper.get_preferences(scope, scope_id, key) return ret + @overload + async def session_get( + self, + umo: str, + key: str, + default: _VT = None, + ) -> _VT: ... + + @overload + async def session_get( + self, + umo: None, + key: str, + default: Any = None, + ) -> list[Preference]: ... + + @overload + async def session_get( + self, + umo: str, + key: None, + default: Any = None, + ) -> list[Preference]: ... + + @overload + async def session_get( + self, + umo: None, + key: None, + default: Any = None, + ) -> list[Preference]: ... + async def session_get( self, umo: str | None, diff --git a/astrbot/core/utils/t2i/__init__.py b/astrbot/core/utils/t2i/__init__.py index 5038a46f7..e4112c354 100644 --- a/astrbot/core/utils/t2i/__init__.py +++ b/astrbot/core/utils/t2i/__init__.py @@ -3,11 +3,11 @@ class RenderStrategy(ABC): @abstractmethod - def render(self, text: str, return_url: bool) -> str: + async def render(self, text: str, return_url: bool) -> str: pass @abstractmethod - def render_custom_template( + async def render_custom_template( self, tmpl_str: str, tmpl_data: dict, diff --git a/astrbot/core/utils/t2i/local_strategy.py b/astrbot/core/utils/t2i/local_strategy.py index 19eab2efe..2fa235129 100644 --- a/astrbot/core/utils/t2i/local_strategy.py +++ b/astrbot/core/utils/t2i/local_strategy.py @@ -20,7 +20,7 @@ class FontManager: _font_cache = {} @classmethod - def get_font(cls, size: int) -> ImageFont.FreeTypeFont: + def get_font(cls, size: int) -> ImageFont.FreeTypeFont|ImageFont.ImageFont: """获取指定大小的字体,优先从缓存获取""" if size in cls._font_cache: return cls._font_cache[size] @@ -66,23 +66,17 @@ class TextMeasurer: """测量文本尺寸的工具类""" @staticmethod - def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]: + def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) -> tuple[int, int]: """获取文本的尺寸""" - try: - # PIL 9.0.0 以上版本 - return ( - font.getbbox(text)[2:] - if hasattr(font, "getbbox") - else font.getsize(text) - ) - except Exception: - # 兼容旧版本 - return font.getsize(text) + + # 依赖库Pillow>=11.2.1,不再需要考虑<9.0.0 + left, top, right, bottom = font.getbbox("Hello world") + return int(right - left), int(bottom - top) @staticmethod def split_text_to_fit_width( - text: str, font: ImageFont.FreeTypeFont, max_width: int - ) -> List[str]: + text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont, max_width: int + ) -> list[str]: """将文本拆分为多行,确保每行不超过指定宽度""" lines = [] if not text: @@ -126,7 +120,7 @@ def calculate_height(self, image_width: int, font_size: int) -> int: def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -152,7 +146,7 @@ def calculate_height(self, image_width: int, font_size: int) -> int: def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -186,7 +180,7 @@ def calculate_height(self, image_width: int, font_size: int) -> int: def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -251,7 +245,7 @@ def calculate_height(self, image_width: int, font_size: int) -> int: def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -299,7 +293,7 @@ def render( # 倾斜变换,使用仿射变换实现斜体效果 # 变换矩阵: [1, 0.2, 0, 0, 1, 0] italic_img = text_img.transform( - text_img.size, Image.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.BICUBIC + text_img.size, Image.Transform.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.Resampling.BICUBIC ) # 粘贴到原图像 @@ -331,7 +325,7 @@ def calculate_height(self, image_width: int, font_size: int) -> int: def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -371,7 +365,7 @@ def calculate_height(self, image_width: int, font_size: int) -> int: def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -422,7 +416,7 @@ def calculate_height(self, image_width: int, font_size: int) -> int: def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -458,7 +452,7 @@ def calculate_height(self, image_width: int, font_size: int) -> int: def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -502,7 +496,7 @@ def calculate_height(self, image_width: int, font_size: int) -> int: def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -532,7 +526,7 @@ def render( class CodeBlockElement(MarkdownElement): """代码块元素""" - def __init__(self, content: List[str]): + def __init__(self, content: list[str]): super().__init__("\n".join(content)) def calculate_height(self, image_width: int, font_size: int) -> int: @@ -552,7 +546,7 @@ def calculate_height(self, image_width: int, font_size: int) -> int: def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -595,7 +589,7 @@ def calculate_height(self, image_width: int, font_size: int) -> int: def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -667,7 +661,7 @@ def calculate_height(self, image_width: int, font_size: int) -> int: def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -686,7 +680,7 @@ def render( if pasted_image.width > max_width: ratio = max_width / pasted_image.width new_size = (int(max_width), int(pasted_image.height * ratio)) - pasted_image = pasted_image.resize(new_size, Image.LANCZOS) + pasted_image = pasted_image.resize(new_size, Image.Resampling.LANCZOS) # 计算居中位置 paste_x = x + (image_width - pasted_image.width) // 2 - 10 @@ -705,7 +699,7 @@ class MarkdownParser: """Markdown解析器,将文本解析为元素""" @staticmethod - async def parse(text: str) -> List[MarkdownElement]: + async def parse(text: str) -> list[MarkdownElement]: elements = [] lines = text.split("\n") @@ -847,7 +841,7 @@ def __init__( self, font_size: int = 26, width: int = 800, - bg_color: Tuple[int, int, int] = (255, 255, 255), + bg_color: tuple[int, int, int] = (255, 255, 255), ): self.font_size = font_size self.width = width diff --git a/astrbot/core/utils/tencent_record_helper.py b/astrbot/core/utils/tencent_record_helper.py index 74c164586..b58643bd3 100644 --- a/astrbot/core/utils/tencent_record_helper.py +++ b/astrbot/core/utils/tencent_record_helper.py @@ -68,7 +68,7 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str: from pyffmpeg import FFmpeg ff = FFmpeg() - ff.convert(input=input_path, output=output_path) + ff.convert(input_file=input_path, output_file=output_path) except Exception as e: logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") diff --git a/astrbot/core/utils/version_comparator.py b/astrbot/core/utils/version_comparator.py index e3bf74951..4ad2da10e 100644 --- a/astrbot/core/utils/version_comparator.py +++ b/astrbot/core/utils/version_comparator.py @@ -60,9 +60,12 @@ def split_version(version): return -1 if isinstance(p1, str) and isinstance(p2, int): return 1 - if (isinstance(p1, int) and isinstance(p2, int)) or ( - isinstance(p1, str) and isinstance(p2, str) - ): + if isinstance(p1, int) and isinstance(p2, int): + if p1 > p2: + return 1 + if p1 < p2: + return -1 + if isinstance(p1, str) and isinstance(p2, str): if p1 > p2: return 1 if p1 < p2: diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 56f98bfbb..cfb750803 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -4,7 +4,9 @@ import os import uuid from contextlib import asynccontextmanager +from typing import cast +from quart import Response as QuartResponse from quart import g, make_response, request, send_file from astrbot.core import logger @@ -424,16 +426,19 @@ async def stream(): sender_name=username, ) - response = await make_response( - stream(), - { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Transfer-Encoding": "chunked", - "Connection": "keep-alive", - }, + response = cast( + QuartResponse, + await make_response( + stream(), + { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Transfer-Encoding": "chunked", + "Connection": "keep-alive", + }, + ), ) - response.timeout = None # fix SSE auto disconnect issue # pyright: ignore[reportAttributeAccessIssue] + response.timeout = None # fix SSE auto disconnect issue return response async def delete_webchat_session(self): diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index c22f0f3ee..e8f17cc99 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -3,6 +3,7 @@ import os import traceback import uuid +from typing import Any from quart import request @@ -26,7 +27,7 @@ from .route import Response, Route, RouteContext -def try_cast(value: str, type_: str): +def try_cast(value: Any, type_: str): if type_ == "int": try: return int(value) @@ -505,9 +506,9 @@ async def get_embedding_dim(self): if not isinstance(inst, EmbeddingProvider): return Response().error("提供商不是 EmbeddingProvider 类型").__dict__ - # 初始化 - if getattr(inst, "initialize", None): - await inst.initialize() + init_fn = getattr(inst, "initialize", None) + if inspect.iscoroutinefunction(init_fn): + await init_fn() # 获取嵌入向量维度 vec = await inst.get_embedding("echo") @@ -777,7 +778,7 @@ async def _get_astrbot_config(self): return {"metadata": CONFIG_METADATA_2, "config": config} async def _get_plugin_config(self, plugin_name: str): - ret = {"metadata": None, "config": None} + ret: dict = {"metadata": None, "config": None} for plugin_md in star_registry: if plugin_md.name == plugin_name: diff --git a/astrbot/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py index eb02fdf40..86cc8c6ca 100644 --- a/astrbot/dashboard/routes/log.py +++ b/astrbot/dashboard/routes/log.py @@ -1,6 +1,8 @@ import asyncio import json +from typing import cast +from quart import Response as QuartResponse from quart import make_response from astrbot.core import LogBroker, logger @@ -39,14 +41,17 @@ async def stream(): if queue: self.log_broker.unregister(queue) - response = await make_response( - stream(), - { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Transfer-Encoding": "chunked", - }, + response = cast( + QuartResponse, + await make_response( + stream(), + { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Transfer-Encoding": "chunked", + }, + ), ) response.timeout = None return response diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index f2a35dfe1..09edafef6 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -545,6 +545,10 @@ async def get_plugin_readme(self): logger.warning(f"插件 {plugin_name} 不存在") return Response().error(f"插件 {plugin_name} 不存在").__dict__ + if not plugin_obj.root_dir_name: + logger.warning(f"插件 {plugin_name} 目录不存在") + return Response().error(f"插件 {plugin_name} 目录不存在").__dict__ + plugin_dir = os.path.join( self.plugin_manager.plugin_store_path, plugin_obj.root_dir_name, diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 1105b69a7..01ab292d4 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -12,6 +12,8 @@ class RouteContext: class Route: + routes: list | dict + def __init__(self, context: RouteContext): self.app = context.app self.config = context.config diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 22eb2474c..09ec76b52 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -2,9 +2,12 @@ import logging import os import socket +from typing import cast import jwt import psutil +from flask.json.provider import DefaultJSONProvider +from psutil._common import addr as psutil_addr from quart import Quart, g, jsonify, request from quart.logging import default_handler @@ -21,7 +24,7 @@ from .routes.session_management import SessionManagementRoute from .routes.t2i import T2iRoute -APP: Quart = None +APP: Quart class AstrBotDashboard: @@ -48,7 +51,7 @@ def __init__( self.app.config["MAX_CONTENT_LENGTH"] = ( 128 * 1024 * 1024 ) # 将 Flask 允许的最大上传文件体大小设置为 128 MB - self.app.json.sort_keys = False + cast(DefaultJSONProvider, self.app.json).sort_keys = False self.app.before_request(self.auth_middleware) # token 用于验证请求 logging.getLogger(self.app.name).removeHandler(default_handler) @@ -147,7 +150,7 @@ def get_process_using_port(self, port: int) -> str: """获取占用端口的进程详细信息""" try: for conn in psutil.net_connections(kind="inet"): - if conn.laddr.port == port: + if cast(psutil_addr, conn.laddr).port == port: try: process = psutil.Process(conn.pid) # 获取详细信息 diff --git a/packages/astrbot/commands/conversation.py b/packages/astrbot/commands/conversation.py index cdffd3597..9849a62d4 100644 --- a/packages/astrbot/commands/conversation.py +++ b/packages/astrbot/commands/conversation.py @@ -30,6 +30,8 @@ async def _get_current_persona_id(self, session_id): session_id, curr, ) + if not conv: + return None return conv.persona_id def ltm_enabled(self, event: AstrMessageEvent): diff --git a/packages/astrbot/process_llm_request.py b/packages/astrbot/process_llm_request.py index 4a680a040..28c41df9f 100644 --- a/packages/astrbot/process_llm_request.py +++ b/packages/astrbot/process_llm_request.py @@ -139,6 +139,11 @@ async def process_llm_request(self, event: AstrMessageEvent, req: ProviderReques # group name identifier if cfg.get("group_name_display") and event.message_obj.group_id: + if not event.message_obj.group: + logger.error( + f"Group name display enabled but group object is None. Group ID: {event.message_obj.group_id}" + ) + return group_name = event.message_obj.group.group_name if group_name: req.system_prompt += f"\nGroup name: {group_name}\n" diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py index 35a2f2698..98496157a 100644 --- a/packages/python_interpreter/main.py +++ b/packages/python_interpreter/main.py @@ -14,6 +14,7 @@ from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter from astrbot.api.message_components import File, Image from astrbot.api.provider import ProviderRequest +from astrbot.core.message.components import BaseMessageComponent from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.io import download_file, download_image_by_url @@ -224,6 +225,8 @@ async def on_message(self, event: AstrMessageEvent): del self.user_waiting[uid] elif isinstance(comp, Image): image_url = comp.url if comp.url else comp.file + if image_url is None: + raise ValueError("Image URL is None") if image_url.startswith("http"): image_path = await download_image_by_url(image_url) elif image_url.startswith("file:///"): @@ -240,6 +243,8 @@ async def on_message(self, event: AstrMessageEvent): async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest): if event.get_session_id() in self.user_file_msg_buffer: files = self.user_file_msg_buffer[event.get_session_id()] + if not request.prompt: + request.prompt = "" request.prompt += f"\nUser provided files: {files}" @filter.command_group("pi") @@ -477,7 +482,9 @@ async def python_interpreter(self, event: AstrMessageEvent): # file_s3_url = await self.file_upload(file_path) # logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}") file_name = os.path.basename(file_path) - chain = [File(name=file_name, file=file_path)] + chain: list[BaseMessageComponent] = [ + File(name=file_name, file=file_path) + ] yield event.set_result(MessageEventResult(chain=chain)) elif "Traceback (most recent call last)" in log or "[Error]: " in log: diff --git a/packages/reminder/main.py b/packages/reminder/main.py index eaeec8d73..8f61e02fe 100644 --- a/packages/reminder/main.py +++ b/packages/reminder/main.py @@ -5,6 +5,7 @@ import zoneinfo from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.cron import CronTrigger from astrbot.api import llm_tool, logger, star from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter @@ -62,13 +63,13 @@ def _init_scheduler(self): misfire_grace_time=60, ) elif "cron" in reminder: + trigger = CronTrigger(**self._parse_cron_expr(reminder["cron"])) self.scheduler.add_job( self._reminder_callback, - trigger="cron", + trigger=trigger, id=id_, args=[group, reminder], misfire_grace_time=60, - **self._parse_cron_expr(reminder["cron"]), ) def check_is_outdated(self, reminder: dict): @@ -101,10 +102,10 @@ def _parse_cron_expr(self, cron_expr: str): async def reminder_tool( self, event: AstrMessageEvent, - text: str = None, - datetime_str: str = None, - cron_expression: str = None, - human_readable_cron: str = None, + text: str | None = None, + datetime_str: str | None = None, + cron_expression: str | None = None, + human_readable_cron: str | None = None, ): """Call this function when user is asking for setting a reminder. @@ -139,17 +140,19 @@ async def reminder_tool( "id": str(uuid.uuid4()), } self.reminder_data[event.unified_msg_origin].append(d) + trigger = CronTrigger(**self._parse_cron_expr(cron_expression)) self.scheduler.add_job( self._reminder_callback, - "cron", + trigger, id=d["id"], misfire_grace_time=60, - **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d], ) if human_readable_cron: reminder_time = f"{human_readable_cron}(Cron: {cron_expression})" else: + if datetime_str is None: + raise ValueError("datetime_str cannot be None.") d = {"text": text, "datetime": datetime_str, "id": str(uuid.uuid4())} self.reminder_data[event.unified_msg_origin].append(d) datetime_scheduled = datetime.datetime.strptime( diff --git a/packages/web_searcher/engines/__init__.py b/packages/web_searcher/engines/__init__.py index 706cfa87b..699438602 100644 --- a/packages/web_searcher/engines/__init__.py +++ b/packages/web_searcher/engines/__init__.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from aiohttp import ClientSession -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup, Tag HEADERS = { "User-Agent": "Mozilla/5.0 (Windows NT 6.1; rv:84.0) Gecko/20100101 Firefox/84.0", @@ -45,13 +45,13 @@ def __init__(self) -> None: self.page = 1 self.headers = HEADERS - def _set_selector(self, selector: str) -> None: + def _set_selector(self, selector: str) -> str: raise NotImplementedError - def _get_next_page(self): + def _get_next_page(self, query: str): raise NotImplementedError - async def _get_html(self, url: str, data: dict = None) -> str: + async def _get_html(self, url: str, data: dict | None = None) -> str: headers = self.headers headers["Referer"] = url headers["User-Agent"] = random.choice(USER_AGENTS) @@ -83,6 +83,9 @@ def tidy_text(self, text: str) -> str: """清理文本,去除空格、换行符等""" return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ") + def _get_url(self, tag: Tag) -> str: + return self.tidy_text(tag.get_text()) + async def search(self, query: str, num_results: int) -> list[SearchResult]: query = urllib.parse.quote(query) @@ -92,12 +95,16 @@ async def search(self, query: str, num_results: int) -> list[SearchResult]: links = soup.select(self._set_selector("links")) results = [] for link in links: - title = self.tidy_text( - link.select_one(self._set_selector("title")).text, - ) - url = link.select_one(self._set_selector("url")) + # Safely get the title text (select_one may return None) + title_elem = link.select_one(self._set_selector("title")) + title = "" + if title_elem is not None: + title = self.tidy_text(title_elem.get_text()) + + url_tag = link.select_one(self._set_selector("url")) snippet = "" - if title and url: + if title and url_tag: + url = self._get_url(url_tag) results.append(SearchResult(title=title, url=url, snippet=snippet)) return results[:num_results] if len(results) > num_results else results except Exception as e: diff --git a/packages/web_searcher/engines/bing.py b/packages/web_searcher/engines/bing.py index 4c2ec319d..7565e5df3 100644 --- a/packages/web_searcher/engines/bing.py +++ b/packages/web_searcher/engines/bing.py @@ -1,4 +1,4 @@ -from . import USER_AGENT_BING, SearchEngine, SearchResult +from . import USER_AGENT_BING, SearchEngine class Bing(SearchEngine): @@ -28,11 +28,3 @@ async def _get_next_page(self, query) -> str: self.base_url = base_url continue raise Exception("Bing search failed") - - async def search(self, query: str, num_results: int) -> list[SearchResult]: - results = await super().search(query, num_results) - for result in results: - if not isinstance(result.url, str): - result.url = result.url.text - - return results diff --git a/packages/web_searcher/engines/sogo.py b/packages/web_searcher/engines/sogo.py index 382e7c937..f490f1106 100644 --- a/packages/web_searcher/engines/sogo.py +++ b/packages/web_searcher/engines/sogo.py @@ -1,7 +1,8 @@ import random import re +from typing import cast -from bs4 import BeautifulSoup +from bs4 import BeautifulSoup, Tag from . import USER_AGENTS, SearchEngine, SearchResult @@ -26,10 +27,12 @@ async def _get_next_page(self, query) -> str: url = f"{self.base_url}/web?query={query}" return await self._get_html(url, None) + def _get_url(self, tag: Tag) -> str: + return cast(str, tag.get("href")) + async def search(self, query: str, num_results: int) -> list[SearchResult]: results = await super().search(query, num_results) for result in results: - result.url = result.url.get("href") if result.url.startswith("/link?"): result.url = self.base_url + result.url result.url = await self._parse_url(result.url) @@ -40,7 +43,10 @@ async def _parse_url(self, url) -> str: soup = BeautifulSoup(html, "html.parser") script = soup.find("script") if script: - url = re.search(r'window.location.replace\("(.+?)"\)', script.string).group( - 1, + script_text = ( + script.string if script.string is not None else script.get_text() ) + match = re.search(r'window.location.replace\("(.+?)"\)', script_text) + if match: + url = match.group(1) return url diff --git a/typings/faiss/__init__.pyi b/typings/faiss/__init__.pyi new file mode 100644 index 000000000..6f2bace36 --- /dev/null +++ b/typings/faiss/__init__.pyi @@ -0,0 +1,90 @@ +"""Minimal type stubs for faiss used in this project. + +This file only exposes a small subset of the faiss API that the +project uses, including the runtime-monkeypatched signatures such as +`Index.add_with_ids` so Pyright/Pylance stops reporting false positives. +""" + +from typing import Any, overload + +import numpy as np + +class Index: + d: int + ntotal: int + code_size: int + nprobe: int + + def add(self, x: np.ndarray) -> None: ... + def add_with_ids(self, x: np.ndarray, ids: np.ndarray) -> None: ... + def search( + self, + x: np.ndarray, + k: int, + *, + params: Any = ..., + D: np.ndarray | None = ..., + I: np.ndarray | None = ..., + ) -> tuple[np.ndarray, np.ndarray]: ... + def remove_ids(self, x: np.ndarray) -> int: ... + @overload + def reconstruct(self, key: int) -> np.ndarray: ... + @overload + def reconstruct(self, key: int, x: np.ndarray) -> None: ... + def reconstruct( + self, key: int, x: np.ndarray | None = ... + ) -> np.ndarray | None: ... + @overload + def reconstruct_n(self, n0: int, ni: int) -> np.ndarray: ... + @overload + def reconstruct_n(self, n0: int, ni: int, x: np.ndarray) -> None: ... + def reconstruct_n( + self, n0: int = ..., ni: int = ..., x: np.ndarray | None = ... + ) -> np.ndarray | None: ... + def range_search( + self, x: np.ndarray, thresh: float, *, params: Any = ... + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: ... + def add_sa_codes(self, codes: np.ndarray, ids: np.ndarray | None = ...) -> None: ... + def sa_encode(self, x: np.ndarray) -> np.ndarray: ... + def sa_decode(self, codes: np.ndarray) -> np.ndarray: ... + +class IndexFlatL2(Index): + def __init__(self, d: int) -> None: ... + +class IndexIDMap(Index): + index: Index + + def __init__(self, index: Index) -> None: ... + +def read_index(path: str) -> Index: ... +def write_index(index: Index, path: str | None = ...) -> None: ... +def normalize_L2(x: np.ndarray) -> None: ... + +# Additional concrete-ish classes exposed by some faiss builds (SWIG helpers +# expose `downcast_*` helpers to convert generic objects to these concrete +# types). We keep these minimal — only the names are important for typing. +class IndexBinary(Index): + def __init__(self, d: int) -> None: ... + +class InvertedLists: + def __len__(self) -> int: ... + +class AdditiveQuantizer: + pass + +class Quantizer: + pass + +class VectorTransform: + pass + +# SWIG-provided downcast helpers (present in some faiss Python builds). +def downcast_IndexBinary(obj: Any) -> IndexBinary: ... +def downcast_InvertedLists(obj: Any) -> InvertedLists: ... +def downcast_AdditiveQuantizer(obj: Any) -> AdditiveQuantizer: ... +def downcast_Quantizer(obj: Any) -> Quantizer: ... +def downcast_VectorTransform(obj: Any) -> VectorTransform: ... +def downcast_index(obj: Any) -> Index: ... + +# version exposed by runtime +__version__: str