-
-
Notifications
You must be signed in to change notification settings - Fork 928
feat: 增加供插件使用的数据存取方法及监听器 #2718
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
anka-afk
wants to merge
3
commits into
AstrBotDevs:master
Choose a base branch
from
anka-afk:plugin-data
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
feat: 增加供插件使用的数据存取方法及监听器 #2718
Changes from 2 commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,11 +21,13 @@ | |
import inspect | ||
import os | ||
import uuid | ||
import asyncio | ||
from pathlib import Path | ||
from typing import Union, Awaitable, List, Optional, ClassVar | ||
from typing import Union, Awaitable, List, Optional, ClassVar, Dict, Any, Callable | ||
from astrbot.core import logger | ||
from astrbot.core.message.components import BaseMessageComponent | ||
from astrbot.core.message.message_event_result import MessageChain | ||
from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType | ||
from astrbot.core.platform import MessageMember, AstrBotMessage, MessageType | ||
from astrbot.core.platform.astr_message_event import MessageSesion | ||
from astrbot.core.star.context import Context | ||
from astrbot.core.star.star import star_map | ||
|
@@ -40,6 +42,10 @@ class StarTools: | |
""" | ||
|
||
_context: ClassVar[Optional[Context]] = None | ||
_shared_data: ClassVar[Dict[str, Any]] = {} | ||
_data_listeners: ClassVar[Dict[str, List[Callable[[str, Any], Awaitable[None]]]]] = {} | ||
_data_lock: ClassVar[asyncio.Lock] = asyncio.Lock() | ||
|
||
|
||
@classmethod | ||
def initialize(cls, context: Context) -> None: | ||
|
@@ -257,33 +263,251 @@ def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path: | |
- 无法获取模块的元数据信息 | ||
- 创建目录失败(权限不足或其他IO错误) | ||
""" | ||
if not plugin_name: | ||
frame = inspect.currentframe() | ||
module = None | ||
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) | ||
|
||
if not resolved_plugin_name: | ||
raise ValueError("无法获取插件名称") | ||
|
||
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", resolved_plugin_name)) | ||
|
||
try: | ||
data_dir.mkdir(parents=True, exist_ok=True) | ||
except OSError as e: | ||
if isinstance(e, PermissionError): | ||
raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e | ||
raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e | ||
|
||
return data_dir.resolve() | ||
|
||
return data_dir.resolve() | ||
|
||
|
||
@classmethod | ||
def _get_caller_plugin_name(cls, plugin_name: Optional[str]) -> str: | ||
""" | ||
通过调用栈获取插件名称 | ||
|
||
Returns: | ||
str: 插件名称 | ||
|
||
Raises: | ||
RuntimeError: 当无法获取调用者模块信息或元数据信息时抛出 | ||
""" | ||
if plugin_name is not None: | ||
return plugin_name | ||
|
||
frame = inspect.currentframe() | ||
try: | ||
if frame: | ||
frame = frame.f_back | ||
module = inspect.getmodule(frame) | ||
if frame: | ||
frame = frame.f_back | ||
|
||
if not frame: | ||
raise RuntimeError("无法获取调用者帧信息") | ||
|
||
module = inspect.getmodule(frame) | ||
if not module: | ||
raise RuntimeError("无法获取调用者模块信息") | ||
|
||
metadata = star_map.get(module.__name__, None) | ||
|
||
if not metadata: | ||
raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息") | ||
|
||
plugin_name = metadata.name | ||
return metadata.name | ||
finally: | ||
del frame | ||
|
||
if not plugin_name: | ||
raise ValueError("无法获取插件名称") | ||
@classmethod | ||
async def set_shared_data(cls, key: str, value: Any, plugin_name: Optional[str] = None) -> None: | ||
""" | ||
设置插件间共享数据 | ||
|
||
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)) | ||
Args: | ||
key (str): 数据键名 | ||
value (Any): 要存储的数据,支持任意数据类型 | ||
plugin_name (Optional[str]): 插件名称,如果为None则自动检测 | ||
|
||
Example: | ||
# 设置工作状态 | ||
StarTools.set_shared_data("worker_status", True) | ||
|
||
# 设置复杂数据 | ||
StarTools.set_shared_data("task_progress", { | ||
"current": 5, | ||
"total": 10, | ||
"status": "processing" | ||
}) | ||
""" | ||
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) | ||
full_key = f"{resolved_plugin_name}:{key}" | ||
|
||
async with cls._data_lock: | ||
cls._shared_data[full_key] = value | ||
|
||
await cls._notify_listeners(full_key, value) | ||
|
||
@classmethod | ||
async def get_shared_data(cls, key: str, plugin_name: Optional[str] = None, default: Any = None) -> Any: | ||
""" | ||
获取插件间共享数据 | ||
|
||
Args: | ||
key (str): 数据键名 | ||
plugin_name (Optional[str]): 插件名称,如果为None则自动检测 | ||
default (Any): 当数据不存在时返回的默认值 | ||
|
||
Returns: | ||
Any: 存储的数据,如果不存在则返回default | ||
|
||
Example: | ||
# 获取其他插件的工作状态 | ||
status = StarTools.get_shared_data("worker_status", "other_plugin") | ||
|
||
# 获取当前插件的数据 | ||
my_data = StarTools.get_shared_data("my_key") | ||
""" | ||
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) | ||
full_key = f"{resolved_plugin_name}:{key}" | ||
|
||
async with cls._data_lock: | ||
return cls._shared_data.get(full_key, default) | ||
|
||
@classmethod | ||
async def remove_shared_data(cls, key: str, plugin_name: Optional[str] = None) -> bool: | ||
""" | ||
删除插件间共享数据 | ||
|
||
Args: | ||
key (str): 数据键名 | ||
plugin_name (Optional[str]): 插件名称,如果为None则自动检测 | ||
|
||
Returns: | ||
bool: 是否成功删除(True表示数据存在并被删除,False表示数据不存在) | ||
""" | ||
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) | ||
full_key = f"{resolved_plugin_name}:{key}" | ||
|
||
async with cls._data_lock: | ||
if full_key in cls._shared_data: | ||
del cls._shared_data[full_key] | ||
return True | ||
return False | ||
|
||
@classmethod | ||
async def list_shared_data(cls, plugin_name: Optional[str] = None) -> Dict[str, Any]: | ||
""" | ||
列出指定插件的所有共享数据 | ||
|
||
Args: | ||
plugin_name (Optional[str]): 插件名称,如果为None则返回当前插件数据,如果为空字符串则返回所有数据 | ||
|
||
Returns: | ||
Dict[str, Any]: 数据字典,键为原始键名(不包含插件前缀) | ||
|
||
Example: | ||
# 获取当前插件的所有数据 | ||
my_data = StarTools.list_shared_data() | ||
|
||
# 获取所有插件的数据 | ||
all_data = StarTools.list_shared_data("") | ||
""" | ||
async with cls._data_lock: | ||
if plugin_name == "": | ||
return dict(cls._shared_data) | ||
|
||
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) | ||
prefix = f"{resolved_plugin_name}:" | ||
result = {} | ||
for full_key, value in cls._shared_data.items(): | ||
if full_key.startswith(prefix): | ||
original_key = full_key[len(prefix):] | ||
result[original_key] = value | ||
return result | ||
|
||
@classmethod | ||
async def add_data_listener( | ||
cls, | ||
key: str, | ||
callback: Callable[[str, Any], Awaitable[None]], | ||
plugin_name: Optional[str] = None | ||
) -> None: | ||
""" | ||
添加数据变化监听器 | ||
|
||
Args: | ||
key (str): 要监听的数据键名 | ||
callback (Callable): 回调函数,接受参数(key, new_value) | ||
plugin_name (Optional[str]): 插件名称,如果为None则自动检测 | ||
|
||
Example: | ||
async def on_worker_status_change(key: str, value: Any): | ||
if value: | ||
logger.info("哈哈我的工作完成啦!") | ||
|
||
StarTools.add_data_listener("worker_status", on_worker_status_change, "other_plugin") | ||
""" | ||
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) | ||
full_key = f"{resolved_plugin_name}:{key}" | ||
|
||
async with cls._data_lock: | ||
if full_key not in cls._data_listeners: | ||
cls._data_listeners[full_key] = [] | ||
cls._data_listeners[full_key].append(callback) | ||
|
||
@classmethod | ||
async def remove_data_listener( | ||
cls, | ||
key: str, | ||
callback: Callable[[str, Any], Awaitable[None]], | ||
plugin_name: Optional[str] = None | ||
) -> bool: | ||
""" | ||
移除数据变化监听器 | ||
|
||
Args: | ||
key (str): 数据键名 | ||
callback (Callable): 要移除的回调函数 | ||
plugin_name (Optional[str]): 插件名称,如果为None则自动检测 | ||
|
||
Returns: | ||
bool: 是否成功移除 | ||
""" | ||
resolved_plugin_name = cls._get_caller_plugin_name(plugin_name) | ||
full_key = f"{resolved_plugin_name}:{key}" | ||
|
||
async with cls._data_lock: | ||
if full_key in cls._data_listeners and callback in cls._data_listeners[full_key]: | ||
cls._data_listeners[full_key].remove(callback) | ||
if not cls._data_listeners[full_key]: | ||
del cls._data_listeners[full_key] | ||
return True | ||
return False | ||
|
||
@classmethod | ||
async def _notify_listeners(cls, full_key: str, value: Any) -> None: | ||
sourcery-ai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
通知所有监听指定数据的回调函数 | ||
|
||
Args: | ||
full_key (str): 完整的数据键名(包含插件前缀) | ||
value (Any): 新的数据值 | ||
""" | ||
listeners = [] | ||
async with cls._data_lock: | ||
if full_key in cls._data_listeners: | ||
listeners = cls._data_listeners[full_key].copy() | ||
|
||
if listeners: | ||
tasks = [] | ||
for callback in listeners: | ||
try: | ||
task = callback(full_key, value) | ||
if asyncio.iscoroutine(task): | ||
tasks.append(task) | ||
except Exception as e: | ||
logger.error(f"数据监听器错误:{full_key}: {e}") | ||
|
||
if tasks: | ||
await asyncio.gather(*tasks, return_exceptions=True) | ||
|
||
try: | ||
data_dir.mkdir(parents=True, exist_ok=True) | ||
except OSError as e: | ||
if isinstance(e, PermissionError): | ||
raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e | ||
raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e | ||
|
||
return data_dir.resolve() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
from typing import Union, Awaitable, List, Optional, ClassVar, Dict, Any, Callable
部分类型已过时,比如List,Dict,在3.9已经过时。应该改成内置的list,dict.Optional可以改成丨 None, Callable移动到了新库,以及类似的问题。
为了风格统一,便于维护,尽量完善下类型标注相关的工作。
except OSError as e:
if isinstance(e, PermissionError):
raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e
raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e
没必要用isinstance判断错误类型。
这样写:
except PermissionError as e:
raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e
except OSError as e:
raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e
更加美观