Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
258 changes: 240 additions & 18 deletions astrbot/core/star/star_tools.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from typing import Union, Awaitable, List, Optional, ClassVar, Dict, Any, Callable
部分类型已过时,比如List,Dict,在3.9已经过时。应该改成内置的list,dict.Optional可以改成丨 None, Callable移动到了新库,以及类似的问题。
为了风格统一,便于维护,尽量完善下类型标注相关的工作。
except OSError as e:
if isinstance(e, PermissionError):
raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e
raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e

没必要用isinstance判断错误类型。
这样写:
except PermissionError as e:
raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e
except OSError as e:
raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e
更加美观

Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
import inspect
import os
import uuid
import asyncio
from pathlib import Path
from typing import Union, Awaitable, List, Optional, ClassVar
from typing import Union, Awaitable, List, Optional, ClassVar, Dict, Any, Callable
from astrbot.core import logger
from astrbot.core.message.components import BaseMessageComponent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType
from astrbot.core.platform import MessageMember, AstrBotMessage, MessageType
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.star.context import Context
from astrbot.core.star.star import star_map
Expand All @@ -40,6 +42,10 @@ class StarTools:
"""

_context: ClassVar[Optional[Context]] = None
_shared_data: ClassVar[Dict[str, Any]] = {}
_data_listeners: ClassVar[Dict[str, List[Callable[[str, Any], Awaitable[None]]]]] = {}
_data_lock: ClassVar[asyncio.Lock] = asyncio.Lock()


@classmethod
def initialize(cls, context: Context) -> None:
Expand Down Expand Up @@ -257,33 +263,249 @@ def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path:
- 无法获取模块的元数据信息
- 创建目录失败(权限不足或其他IO错误)
"""
if not plugin_name:
frame = inspect.currentframe()
module = None
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name)

if not resolved_plugin_name:
raise ValueError("无法获取插件名称")

data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", resolved_plugin_name))

try:
data_dir.mkdir(parents=True, exist_ok=True)
except OSError as e:
if isinstance(e, PermissionError):
raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e
raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e

return data_dir.resolve()

@classmethod
def _get_caller_plugin_name(cls, plugin_name: Optional[str]) -> str:
"""
通过调用栈获取插件名称

Returns:
str: 插件名称

Raises:
RuntimeError: 当无法获取调用者模块信息或元数据信息时抛出
"""
if plugin_name is not None:
return plugin_name

frame = inspect.currentframe()
try:
if frame:
frame = frame.f_back
module = inspect.getmodule(frame)
if frame:
frame = frame.f_back

if not frame:
raise RuntimeError("无法获取调用者帧信息")

module = inspect.getmodule(frame)
if not module:
raise RuntimeError("无法获取调用者模块信息")

metadata = star_map.get(module.__name__, None)

if not metadata:
raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息")

plugin_name = metadata.name
return metadata.name
finally:
del frame

if not plugin_name:
raise ValueError("无法获取插件名称")
@classmethod
async def set_shared_data(cls, key: str, value: Any, plugin_name: Optional[str] = None) -> None:
"""
设置插件间共享数据

data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
Args:
key (str): 数据键名
value (Any): 要存储的数据,支持任意数据类型
plugin_name (Optional[str]): 插件名称,如果为None则自动检测

Example:
# 设置工作状态
StarTools.set_shared_data("worker_status", True)

# 设置复杂数据
StarTools.set_shared_data("task_progress", {
"current": 5,
"total": 10,
"status": "processing"
})
"""
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name)
full_key = f"{resolved_plugin_name}:{key}"

async with cls._data_lock:
cls._shared_data[full_key] = value

await cls._notify_listeners(full_key, value)

@classmethod
async def get_shared_data(cls, key: str, plugin_name: Optional[str] = None, default: Any = None) -> Any:
"""
获取插件间共享数据

Args:
key (str): 数据键名
plugin_name (Optional[str]): 插件名称,如果为None则自动检测
default (Any): 当数据不存在时返回的默认值

Returns:
Any: 存储的数据,如果不存在则返回default

Example:
# 获取其他插件的工作状态
status = StarTools.get_shared_data("worker_status", "other_plugin")

# 获取当前插件的数据
my_data = StarTools.get_shared_data("my_key")
"""
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name)
full_key = f"{resolved_plugin_name}:{key}"

async with cls._data_lock:
return cls._shared_data.get(full_key, default)

@classmethod
async def remove_shared_data(cls, key: str, plugin_name: Optional[str] = None) -> bool:
"""
删除插件间共享数据

Args:
key (str): 数据键名
plugin_name (Optional[str]): 插件名称,如果为None则自动检测

Returns:
bool: 是否成功删除(True表示数据存在并被删除,False表示数据不存在)
"""
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name)
full_key = f"{resolved_plugin_name}:{key}"

async with cls._data_lock:
if full_key in cls._shared_data:
del cls._shared_data[full_key]
return True
return False

@classmethod
async def list_shared_data(cls, plugin_name: Optional[str] = None) -> Dict[str, Any]:
"""
列出指定插件的所有共享数据

Args:
plugin_name (Optional[str]): 插件名称,如果为None则返回当前插件数据,如果为空字符串则返回所有数据

Returns:
Dict[str, Any]: 数据字典,键为原始键名(不包含插件前缀)

Example:
# 获取当前插件的所有数据
my_data = StarTools.list_shared_data()

# 获取所有插件的数据
all_data = StarTools.list_shared_data("")
"""
async with cls._data_lock:
if plugin_name == "":
return dict(cls._shared_data)

resolved_plugin_name = cls._get_caller_plugin_name(plugin_name)
prefix = f"{resolved_plugin_name}:"
result = {}
for full_key, value in cls._shared_data.items():
if full_key.startswith(prefix):
original_key = full_key[len(prefix):]
result[original_key] = value
return result

@classmethod
async def add_data_listener(
cls,
key: str,
callback: Callable[[str, Any], Awaitable[None]],
plugin_name: Optional[str] = None
) -> None:
"""
添加数据变化监听器

Args:
key (str): 要监听的数据键名
callback (Callable): 回调函数,接受参数(key, new_value)
plugin_name (Optional[str]): 插件名称,如果为None则自动检测

Example:
async def on_worker_status_change(key: str, value: Any):
if value:
logger.info("哈哈我的工作完成啦!")

StarTools.add_data_listener("worker_status", on_worker_status_change, "other_plugin")
"""
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name)
full_key = f"{resolved_plugin_name}:{key}"

async with cls._data_lock:
if full_key not in cls._data_listeners:
cls._data_listeners[full_key] = []
cls._data_listeners[full_key].append(callback)

@classmethod
async def remove_data_listener(
cls,
key: str,
callback: Callable[[str, Any], Awaitable[None]],
plugin_name: Optional[str] = None
) -> bool:
"""
移除数据变化监听器

Args:
key (str): 数据键名
callback (Callable): 要移除的回调函数
plugin_name (Optional[str]): 插件名称,如果为None则自动检测

Returns:
bool: 是否成功移除
"""
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name)
full_key = f"{resolved_plugin_name}:{key}"

async with cls._data_lock:
if full_key in cls._data_listeners and callback in cls._data_listeners[full_key]:
cls._data_listeners[full_key].remove(callback)
if not cls._data_listeners[full_key]:
del cls._data_listeners[full_key]
return True
return False

@classmethod
async def _notify_listeners(cls, full_key: str, value: Any) -> None:
"""
通知所有监听指定数据的回调函数

Args:
full_key (str): 完整的数据键名(包含插件前缀)
value (Any): 新的数据值
"""
listeners = []
async with cls._data_lock:
if full_key in cls._data_listeners:
listeners = cls._data_listeners[full_key].copy()

if listeners:
tasks = []
for callback in listeners:
try:
task = callback(full_key, value)
if asyncio.iscoroutine(task):
tasks.append(task)
except Exception as e:
logger.error(f"数据监听器错误:{full_key}: {e}")

if tasks:
await asyncio.gather(*tasks, return_exceptions=True)

try:
data_dir.mkdir(parents=True, exist_ok=True)
except OSError as e:
if isinstance(e, PermissionError):
raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e
raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e

return data_dir.resolve()
Loading