diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 6402aeaed..7818efbe5 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -1,5 +1,6 @@ import abc import asyncio +import copy import hashlib import re import uuid @@ -29,6 +30,9 @@ class AstrMessageEvent(abc.ABC): + # extras 中可安全清理的瞬态字段清单;子类可按需扩展 + TRANSIENT_EXTRA_KEYS: set[str] = set() + def __init__( self, message_str: str, @@ -71,6 +75,8 @@ def __init__( # back_compability self.platform = platform_meta + # 可选的绕过标记,避免被 SessionWaiter 再次截获 + self._bypass_session_waiter = False def get_platform_name(self): """获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。 @@ -278,6 +284,22 @@ def clear_result(self): """清除消息事件的结果。""" self._result = None + def clone_for_llm(self) -> "AstrMessageEvent": + """浅拷贝并重置状态,以便重新走默认 LLM 流程。""" + new_event: AstrMessageEvent = copy.copy(self) + new_event.clear_result() + # 保留非瞬态 extras,避免跨管线上下文丢失 + new_event._extras = self._extras.copy() + for key in self.TRANSIENT_EXTRA_KEYS: + new_event._extras.pop(key, None) + new_event._has_send_oper = False + new_event.call_llm = False + new_event.is_wake = False + new_event.is_at_or_wake_command = False + new_event.plugins_name = None + new_event._bypass_session_waiter = False + return new_event + """消息链相关""" def make_result(self) -> MessageEventResult: diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index 33b7cb17a..c42a81ffb 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -9,6 +9,7 @@ from typing import Any import astrbot.core.message.components as Comp +from astrbot import logger from astrbot.core.platform import AstrMessageEvent USER_SESSIONS: dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例 @@ -29,6 +30,33 @@ def __init__(self): self.history_chains: list[list[Comp.BaseMessageComponent]] = [] + def fallback_to_llm( + self, + event_queue: asyncio.Queue, + event: AstrMessageEvent, + *, + stop_session: bool = True, + ) -> AstrMessageEvent: + """将当前事件重新入队,由默认 LLM 流程处理,适用于非预期输入的兜底。 + + Args: + event_queue: 事件队列 + event: 当前事件 + stop_session: 是否结束当前 SessionWaiter。False 时仅兜底当前输入,继续等待后续输入。 + """ + if not stop_session: + logger.warning( + "fallback_to_llm(stop_session=False) 会保留当前会话,默认会话拦截可能导致兜底无效," + "建议谨慎使用或在后续输入中自行终止会话。", + ) + new_event = event.clone_for_llm() + new_event._bypass_session_waiter = not stop_session + event_queue.put_nowait(new_event) + event.stop_event() + if stop_session: + self.stop() + return new_event + def stop(self, error: Exception = None): """立即结束这个会话""" if not self.future.done(): @@ -147,11 +175,15 @@ def _cleanup(self, error: Exception = None): self.session_controller.stop(error) @classmethod - async def trigger(cls, session_id: str, event: AstrMessageEvent): - """外部输入触发会话处理""" + async def trigger(cls, session_id: str, event: AstrMessageEvent) -> bool: + """外部输入触发会话处理 + + Returns: + bool: 是否成功触发处理。False 表示会话不存在或已结束。 + """ session = USER_SESSIONS.get(session_id) if not session or session.session_controller.future.done(): - return + return False async with session._lock: if not session.session_controller.future.done(): @@ -164,6 +196,8 @@ async def trigger(cls, session_id: str, event: AstrMessageEvent): await session.handler(session.session_controller, event) except Exception as e: session.session_controller.stop(e) + return True + return False def session_waiter(timeout: int = 30, record_history_chains: bool = False): diff --git a/packages/session_controller/main.py b/packages/session_controller/main.py index 4d4a42528..038118c2e 100644 --- a/packages/session_controller/main.py +++ b/packages/session_controller/main.py @@ -1,4 +1,3 @@ -import copy from sys import maxsize import astrbot.api.message_components as Comp @@ -7,7 +6,6 @@ from astrbot.api.star import Context, Star from astrbot.core.utils.session_waiter import ( FILTERS, - USER_SESSIONS, SessionController, SessionWaiter, session_waiter, @@ -23,10 +21,12 @@ def __init__(self, context: Context): @filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize) async def handle_session_control_agent(self, event: AstrMessageEvent): """会话控制代理""" + if getattr(event, "_bypass_session_waiter", False): + return for session_filter in FILTERS: session_id = session_filter.filter(event) - if session_id in USER_SESSIONS: - await SessionWaiter.trigger(session_id, event) + handled = await SessionWaiter.trigger(session_id, event) + if handled: event.stop_event() @filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize - 1) @@ -96,11 +96,10 @@ async def empty_mention_waiter( 0, Comp.At(qq=event.get_self_id(), name=event.get_self_id()), ) - new_event = copy.copy(event) - # 重新推入事件队列 - self.context.get_event_queue().put_nowait(new_event) - event.stop_event() - controller.stop() + controller.fallback_to_llm( + self.context.get_event_queue(), + event, + ) try: await empty_mention_waiter(event)