diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 42ed168ff..877241606 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -21,11 +21,13 @@ import inspect import os import uuid +import asyncio from pathlib import Path -from typing import Union, Awaitable, List, Optional, ClassVar +from typing import Union, Awaitable, List, Optional, ClassVar, Dict, Any, Callable +from astrbot.core import logger from astrbot.core.message.components import BaseMessageComponent from astrbot.core.message.message_event_result import MessageChain -from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType +from astrbot.core.platform import MessageMember, AstrBotMessage, MessageType from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.star.context import Context from astrbot.core.star.star import star_map @@ -40,6 +42,10 @@ class StarTools: """ _context: ClassVar[Optional[Context]] = None + _shared_data: ClassVar[Dict[str, Any]] = {} + _data_listeners: ClassVar[Dict[str, List[Callable[[str, Any], Awaitable[None]]]]] = {} + _data_lock: ClassVar[asyncio.Lock] = asyncio.Lock() + @classmethod def initialize(cls, context: Context) -> None: @@ -257,33 +263,249 @@ def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path: - 无法获取模块的元数据信息 - 创建目录失败(权限不足或其他IO错误) """ - if not plugin_name: - frame = inspect.currentframe() - module = None + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + + if not resolved_plugin_name: + raise ValueError("无法获取插件名称") + + data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", resolved_plugin_name)) + + try: + data_dir.mkdir(parents=True, exist_ok=True) + except OSError as e: + if isinstance(e, PermissionError): + raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e + raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e + + return data_dir.resolve() + + @classmethod + def _get_caller_plugin_name(cls, plugin_name: Optional[str]) -> str: + """ + 通过调用栈获取插件名称 + + Returns: + str: 插件名称 + + Raises: + RuntimeError: 当无法获取调用者模块信息或元数据信息时抛出 + """ + if plugin_name is not None: + return plugin_name + + frame = inspect.currentframe() + try: if frame: frame = frame.f_back - module = inspect.getmodule(frame) + if frame: + frame = frame.f_back + + if not frame: + raise RuntimeError("无法获取调用者帧信息") + module = inspect.getmodule(frame) if not module: raise RuntimeError("无法获取调用者模块信息") metadata = star_map.get(module.__name__, None) - if not metadata: raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") - plugin_name = metadata.name + return metadata.name + finally: + del frame - if not plugin_name: - raise ValueError("无法获取插件名称") + @classmethod + async def set_shared_data(cls, key: str, value: Any, plugin_name: Optional[str] = None) -> None: + """ + 设置插件间共享数据 - data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)) + Args: + key (str): 数据键名 + value (Any): 要存储的数据,支持任意数据类型 + plugin_name (Optional[str]): 插件名称,如果为None则自动检测 + + Example: + # 设置工作状态 + StarTools.set_shared_data("worker_status", True) + + # 设置复杂数据 + StarTools.set_shared_data("task_progress", { + "current": 5, + "total": 10, + "status": "processing" + }) + """ + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + full_key = f"{resolved_plugin_name}:{key}" + + async with cls._data_lock: + cls._shared_data[full_key] = value + + await cls._notify_listeners(full_key, value) + + @classmethod + async def get_shared_data(cls, key: str, plugin_name: Optional[str] = None, default: Any = None) -> Any: + """ + 获取插件间共享数据 + + Args: + key (str): 数据键名 + plugin_name (Optional[str]): 插件名称,如果为None则自动检测 + default (Any): 当数据不存在时返回的默认值 + + Returns: + Any: 存储的数据,如果不存在则返回default + + Example: + # 获取其他插件的工作状态 + status = StarTools.get_shared_data("worker_status", "other_plugin") + + # 获取当前插件的数据 + my_data = StarTools.get_shared_data("my_key") + """ + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + full_key = f"{resolved_plugin_name}:{key}" + + async with cls._data_lock: + return cls._shared_data.get(full_key, default) + + @classmethod + async def remove_shared_data(cls, key: str, plugin_name: Optional[str] = None) -> bool: + """ + 删除插件间共享数据 + + Args: + key (str): 数据键名 + plugin_name (Optional[str]): 插件名称,如果为None则自动检测 + + Returns: + bool: 是否成功删除(True表示数据存在并被删除,False表示数据不存在) + """ + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + full_key = f"{resolved_plugin_name}:{key}" + + async with cls._data_lock: + if full_key in cls._shared_data: + del cls._shared_data[full_key] + return True + return False + + @classmethod + async def list_shared_data(cls, plugin_name: Optional[str] = None) -> Dict[str, Any]: + """ + 列出指定插件的所有共享数据 + + Args: + plugin_name (Optional[str]): 插件名称,如果为None则返回当前插件数据,如果为空字符串则返回所有数据 + + Returns: + Dict[str, Any]: 数据字典,键为原始键名(不包含插件前缀) + + Example: + # 获取当前插件的所有数据 + my_data = StarTools.list_shared_data() + + # 获取所有插件的数据 + all_data = StarTools.list_shared_data("") + """ + async with cls._data_lock: + if plugin_name == "": + return dict(cls._shared_data) + + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + prefix = f"{resolved_plugin_name}:" + result = {} + for full_key, value in cls._shared_data.items(): + if full_key.startswith(prefix): + original_key = full_key[len(prefix):] + result[original_key] = value + return result + + @classmethod + async def add_data_listener( + cls, + key: str, + callback: Callable[[str, Any], Awaitable[None]], + plugin_name: Optional[str] = None + ) -> None: + """ + 添加数据变化监听器 + + Args: + key (str): 要监听的数据键名 + callback (Callable): 回调函数,接受参数(key, new_value) + plugin_name (Optional[str]): 插件名称,如果为None则自动检测 + + Example: + async def on_worker_status_change(key: str, value: Any): + if value: + logger.info("哈哈我的工作完成啦!") + + StarTools.add_data_listener("worker_status", on_worker_status_change, "other_plugin") + """ + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + full_key = f"{resolved_plugin_name}:{key}" + + async with cls._data_lock: + if full_key not in cls._data_listeners: + cls._data_listeners[full_key] = [] + cls._data_listeners[full_key].append(callback) + + @classmethod + async def remove_data_listener( + cls, + key: str, + callback: Callable[[str, Any], Awaitable[None]], + plugin_name: Optional[str] = None + ) -> bool: + """ + 移除数据变化监听器 + + Args: + key (str): 数据键名 + callback (Callable): 要移除的回调函数 + plugin_name (Optional[str]): 插件名称,如果为None则自动检测 + + Returns: + bool: 是否成功移除 + """ + resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) + full_key = f"{resolved_plugin_name}:{key}" + + async with cls._data_lock: + if full_key in cls._data_listeners and callback in cls._data_listeners[full_key]: + cls._data_listeners[full_key].remove(callback) + if not cls._data_listeners[full_key]: + del cls._data_listeners[full_key] + return True + return False + + @classmethod + async def _notify_listeners(cls, full_key: str, value: Any) -> None: + """ + 通知所有监听指定数据的回调函数 + + Args: + full_key (str): 完整的数据键名(包含插件前缀) + value (Any): 新的数据值 + """ + listeners = [] + async with cls._data_lock: + if full_key in cls._data_listeners: + listeners = cls._data_listeners[full_key].copy() + + if listeners: + tasks = [] + for callback in listeners: + try: + task = callback(full_key, value) + if asyncio.iscoroutine(task): + tasks.append(task) + except Exception as e: + logger.error(f"数据监听器错误:{full_key}: {e}") + + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) - try: - data_dir.mkdir(parents=True, exist_ok=True) - except OSError as e: - if isinstance(e, PermissionError): - raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e - raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e - return data_dir.resolve()