diff --git a/alicebot/adapter/__init__.py b/alicebot/adapter/__init__.py index 10d8cc2..2805e30 100644 --- a/alicebot/adapter/__init__.py +++ b/alicebot/adapter/__init__.py @@ -17,8 +17,7 @@ import structlog -from alicebot.event import Event -from alicebot.typing import ConfigT, EventT +from alicebot.typing import AnyEvent, ConfigT, EventT from alicebot.utils import is_config_class if TYPE_CHECKING: @@ -33,7 +32,7 @@ __import__("pkg_resources").declare_namespace(__name__) -_EventT = TypeVar("_EventT", bound="Event[Any]", default="Event[Any]") +_EventT = TypeVar("_EventT", bound=AnyEvent, default=AnyEvent) class Adapter(Generic[EventT, ConfigT], ABC): @@ -133,7 +132,7 @@ async def get( max_try_times: int | None = None, timeout: float | None = None, to_thread: bool = False, - ) -> Event[Any]: + ) -> AnyEvent: """获取满足指定条件的的事件,协程会等待直到适配器接收到满足条件的事件、超过最大事件数或超时。 类似 `Bot` 类的 `get()` 方法,但是隐含了判断产生事件的适配器是本适配器。 diff --git a/alicebot/bot.py b/alicebot/bot.py index 5b634f7..8385e89 100644 --- a/alicebot/bot.py +++ b/alicebot/bot.py @@ -28,7 +28,16 @@ from alicebot.exceptions import LoadModuleError, SkipException, StopException from alicebot.matcher import EventMatcher from alicebot.plugin import Plugin, PluginLoadType -from alicebot.typing import AdapterHook, AdapterT, BotHook, EventHook, EventT +from alicebot.typing import ( + AdapterHook, + AdapterT, + AnyAdapter, + AnyEvent, + AnyPlugin, + BotHook, + EventHook, + EventT, +) from alicebot.utils import ( ModulePathFinder, async_map, @@ -61,8 +70,8 @@ class Bot: """ config: MainConfig - adapters: list[Adapter[Any, Any]] - plugins_priority_dict: dict[int, list[type[Plugin[Any, Any, Any]]]] + adapters: list[AnyAdapter] + plugins_priority_dict: dict[int, list[type[AnyPlugin]]] plugin_state: dict[str, Any] global_state: dict[Any, Any] @@ -82,13 +91,13 @@ class Bot: _handle_signals: bool # 处理信号 _extend_plugins: list[ - type[Plugin[Any, Any, Any]] | str | Path + type[AnyPlugin] | str | Path ] # 使用 load_plugins() 方法程序化加载的插件列表 _extend_plugin_dirs: list[ Path ] # 使用 load_plugins_from_dirs() 方法程序化加载的插件路径列表 _extend_adapters: list[ - type[Adapter[Any, Any]] | str + type[AnyAdapter] | str ] # 使用 load_adapter() 方法程序化加载的适配器列表 _bot_run_hooks: list[BotHook] _bot_exit_hooks: list[BotHook] @@ -146,7 +155,7 @@ def __init__( sys.meta_path.insert(0, self._module_path_finder) @property - def plugins(self) -> list[type[Plugin[Any, Any, Any]]]: + def plugins(self) -> list[type[AnyPlugin]]: """当前已经加载的插件的列表。""" return list(chain(*self.plugins_priority_dict.values())) @@ -241,9 +250,9 @@ async def _run(self) -> None: def _remove_plugin_by_path( self, file: Path - ) -> list[type[Plugin[Any, Any, Any]]]: # pragma: no cover + ) -> list[type[AnyPlugin]]: # pragma: no cover """根据路径删除已加载的插件。""" - removed_plugins: list[type[Plugin[Any, Any, Any]]] = [] + removed_plugins: list[type[AnyPlugin]] = [] for plugins in self.plugins_priority_dict.values(): _removed_plugins = list( filter( @@ -341,7 +350,7 @@ def _update_config(self) -> None: """更新 config,合并入来自 Plugin 和 Adapter 的 Config。""" def update_config( - source: list[type[Plugin[Any, Any, Any]]] | list[Adapter[Any, Any]], + source: list[type[AnyPlugin]] | list[AnyAdapter], name: str, base: type[ConfigModel], ) -> tuple[type[ConfigModel], ConfigModel]: @@ -460,7 +469,7 @@ async def _handle_should_exit(self, cancel_scope: anyio.CancelScope) -> None: async def handle_event( self, - current_event: Event[Any], + current_event: AnyEvent, *, handle_get: bool = True, show_log: bool = True, @@ -493,7 +502,7 @@ async def _handle_event_receive(self) -> None: async for current_event, handle_get in self._event_receive_stream: tg.start_soon(self._handle_event, current_event, handle_get) - async def _handle_event(self, current_event: Event[Any], handle_get: bool) -> None: + async def _handle_event(self, current_event: AnyEvent, handle_get: bool) -> None: async with anyio.create_task_group() as tg: if handle_get: event_handled = False @@ -536,9 +545,7 @@ async def _handle_event(self, current_event: Event[Any], handle_get: bool) -> No logger.info("Event Finished") - async def _run_plugin( - self, plugin_class: type[Plugin[Any, Any, Any]], event: Event[Any] - ) -> bool: + async def _run_plugin(self, plugin_class: type[AnyPlugin], event: AnyEvent) -> bool: try: async with AsyncExitStack() as stack: plugin_instance = await solve_dependencies( @@ -573,14 +580,14 @@ async def _run_plugin( @overload async def get( self, - func: Callable[[Event[Any]], bool | Awaitable[bool]] | None = None, + func: Callable[[AnyEvent], bool | Awaitable[bool]] | None = None, *, event_type: None = None, adapter_type: None = None, max_try_times: int | None = None, timeout: float | None = None, to_thread: bool = False, - ) -> Event[Any]: ... + ) -> AnyEvent: ... @overload async def get( @@ -600,7 +607,7 @@ async def get( func: Callable[[EventT], bool | Awaitable[bool]] | None = None, *, event_type: type[EventT], - adapter_type: type[Adapter[Any, Any]] | None = None, + adapter_type: type[AnyAdapter] | None = None, max_try_times: int | None = None, timeout: float | None = None, to_thread: bool = False, @@ -610,12 +617,12 @@ async def get( self, func: Callable[[Any], bool | Awaitable[bool]] | None = None, *, - event_type: type[Event[Any]] | None = None, - adapter_type: type[Adapter[Any, Any]] | None = None, + event_type: type[AnyEvent] | None = None, + adapter_type: type[AnyAdapter] | None = None, max_try_times: int | None = None, timeout: float | None = None, to_thread: bool = False, - ) -> Event[Any]: + ) -> AnyEvent: """获取满足指定条件的的事件,协程会等待直到适配器接收到满足条件的事件、超过最大事件数或超时。 Args: @@ -648,7 +655,7 @@ async def get( def _load_plugin_class( self, - plugin_class: type[Plugin[Any, Any, Any]], + plugin_class: type[AnyPlugin], plugin_load_type: PluginLoadType, plugin_file_path: str | None, ) -> None: @@ -698,7 +705,7 @@ def _load_plugins_from_module_name( def _load_plugins( self, - *plugins: type[Plugin[Any, Any, Any]] | str | Path, + *plugins: type[AnyPlugin] | str | Path, plugin_load_type: PluginLoadType | None = None, reload: bool = False, ) -> None: @@ -772,7 +779,7 @@ def _load_plugins( except Exception: logger.exception("Load plugin failed:", plugin=plugin_) - def load_plugins(self, *plugins: type[Plugin[Any, Any, Any]] | str | Path) -> None: + def load_plugins(self, *plugins: type[AnyPlugin] | str | Path) -> None: """加载插件。 Args: @@ -813,7 +820,7 @@ def load_plugins_from_dirs(self, *dirs: Path) -> None: self._extend_plugin_dirs.extend(dirs) self._load_plugins_from_dirs(*dirs) - def _load_adapters(self, *adapters: type[Adapter[Any, Any]] | str) -> None: + def _load_adapters(self, *adapters: type[AnyAdapter] | str) -> None: """加载适配器。 Args: @@ -823,7 +830,7 @@ def _load_adapters(self, *adapters: type[Adapter[Any, Any]] | str) -> None: 例如:`path.of.adapter`。 """ for adapter_ in adapters: - adapter_object: Adapter[Any, Any] + adapter_object: AnyAdapter try: if isinstance(adapter_, type) and issubclass(adapter_, Adapter): adapter_object = adapter_(self) @@ -857,7 +864,7 @@ def _load_adapters(self, *adapters: type[Adapter[Any, Any]] | str) -> None: else: self.adapters.append(adapter_object) - def load_adapters(self, *adapters: type[Adapter[Any, Any]] | str) -> None: + def load_adapters(self, *adapters: type[AnyAdapter] | str) -> None: """加载适配器。 Args: @@ -870,14 +877,12 @@ def load_adapters(self, *adapters: type[Adapter[Any, Any]] | str) -> None: self._load_adapters(*adapters) @overload - def get_adapter(self, adapter: str) -> Adapter[Any, Any]: ... + def get_adapter(self, adapter: str) -> AnyAdapter: ... @overload def get_adapter(self, adapter: type[AdapterT]) -> AdapterT: ... - def get_adapter( - self, adapter: str | type[AdapterT] - ) -> Adapter[Any, Any] | AdapterT: + def get_adapter(self, adapter: str | type[AdapterT]) -> AnyAdapter | AdapterT: """按照名称或适配器类获取已经加载的适配器。 Args: @@ -897,7 +902,7 @@ def get_adapter( return _adapter raise LookupError(f'Can not find adapter named "{adapter}"') - def get_plugin(self, name: str) -> type[Plugin[Any, Any, Any]]: + def get_plugin(self, name: str) -> type[AnyPlugin]: """按照名称获取已经加载的插件类。 Args: diff --git a/alicebot/event.py b/alicebot/event.py index c996b32..e519e53 100644 --- a/alicebot/event.py +++ b/alicebot/event.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, ConfigDict -from alicebot.typing import AdapterT +from alicebot.typing import AdapterT, AnyEvent __all__ = ["Event", "EventHandleOption", "MessageEvent"] @@ -47,7 +47,7 @@ class EventHandleOption(NamedTuple): handle_get: 当前事件是否可以被 get 方法捕获。 """ - event: Event[Any] + event: AnyEvent handle_get: bool diff --git a/alicebot/matcher.py b/alicebot/matcher.py index be751d4..7a6680b 100644 --- a/alicebot/matcher.py +++ b/alicebot/matcher.py @@ -11,9 +11,8 @@ import anyio import anyio.to_thread -from alicebot.adapter import Adapter -from alicebot.event import Event from alicebot.exceptions import GetEventTimeout +from alicebot.typing import AnyAdapter, AnyEvent if TYPE_CHECKING: from alicebot.bot import Bot @@ -26,8 +25,8 @@ class EventMatcher: func: Callable[[Any], bool | Awaitable[bool]] | None bot: "Bot" - event_type: type[Event[Any]] | None - adapter_type: type[Adapter[Any, Any]] | None + event_type: type[AnyEvent] | None + adapter_type: type[AnyAdapter] | None max_try_times: int | None timeout: int | float | None to_thread: bool @@ -35,7 +34,7 @@ class EventMatcher: event: anyio.Event try_times: int start_time: float - result: Event[Any] | None + result: "AnyEvent | None" exception: BaseException | None def __init__( @@ -43,8 +42,8 @@ def __init__( func: Callable[[Any], bool | Awaitable[bool]] | None, *, bot: "Bot", - event_type: type[Event[Any]] | None, - adapter_type: type[Adapter[Any, Any]] | None, + event_type: type[AnyEvent] | None, + adapter_type: type[AnyAdapter] | None, max_try_times: int | None, timeout: float | None, to_thread: bool, @@ -74,7 +73,7 @@ def __init__( self.result = None self.exception = None - async def wait(self) -> Event[Any]: + async def wait(self) -> AnyEvent: """等待当前事件匹配器直到满足条件或者超时。 Raises: @@ -97,7 +96,7 @@ async def wait(self) -> Event[Any]: raise RuntimeError("Event has no result.") # pragma: no cover - async def run(self, event: Event[Any]) -> bool | None: + async def run(self, event: AnyEvent) -> bool | None: """运行 `get()` 函数,检查当前 `get()` 是否成功。 Args: @@ -118,7 +117,7 @@ async def run(self, event: Event[Any]) -> bool | None: return None return False - async def match(self, event: Event[Any]) -> bool: + async def match(self, event: AnyEvent) -> bool: """检查当前事件是否被匹配。 Args: @@ -147,7 +146,7 @@ async def match(self, event: Event[Any]) -> bool: if self.func is None: return True if not inspect.iscoroutinefunction(self.func): - func = cast("Callable[[Event[Any]], bool]", self.func) + func = cast("Callable[[AnyEvent], bool]", self.func) if self.to_thread: return await anyio.to_thread.run_sync(func, event) return func(event) diff --git a/alicebot/typing.py b/alicebot/typing.py index eaee2d5..1549e53 100644 --- a/alicebot/typing.py +++ b/alicebot/typing.py @@ -4,7 +4,7 @@ """ from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypeAlias from typing_extensions import TypeVar from alicebot.message import BuildMessageType, MessageSegmentT, MessageT @@ -21,6 +21,9 @@ __all__ = [ "AdapterHook", "AdapterT", + "AnyAdapter", + "AnyEvent", + "AnyPlugin", "BotHook", "BuildMessageType", "ConfigT", @@ -32,12 +35,16 @@ "StateT", ] -EventT = TypeVar("EventT", bound="Event[Any]", default="Event[Any]") +AnyEvent: TypeAlias = "Event[Any]" +AnyPlugin: TypeAlias = "Plugin[Any, Any, Any]" +AnyAdapter: TypeAlias = "Adapter[Any, Any]" + +EventT = TypeVar("EventT", bound=AnyEvent, default=AnyEvent) StateT = TypeVar("StateT", default=None) ConfigT = TypeVar("ConfigT", bound="ConfigModel | None", default=None) -PluginT = TypeVar("PluginT", bound="Plugin[Any, Any, Any]") -AdapterT = TypeVar("AdapterT", bound="Adapter[Any, Any]") +PluginT = TypeVar("PluginT", bound=AnyPlugin) +AdapterT = TypeVar("AdapterT", bound=AnyAdapter) -BotHook = Callable[["Bot"], Awaitable[None]] -AdapterHook = Callable[["Adapter[Any, Any]"], Awaitable[None]] -EventHook = Callable[["Event[Any]"], Awaitable[None]] +BotHook: TypeAlias = Callable[["Bot"], Awaitable[None]] +AdapterHook: TypeAlias = Callable[[AnyAdapter], Awaitable[None]] +EventHook: TypeAlias = Callable[[AnyEvent], Awaitable[None]] diff --git a/packages/alicebot-adapter-apscheduler/alicebot/adapter/apscheduler/__init__.py b/packages/alicebot-adapter-apscheduler/alicebot/adapter/apscheduler/__init__.py index b095dfd..52e8a4d 100644 --- a/packages/alicebot-adapter-apscheduler/alicebot/adapter/apscheduler/__init__.py +++ b/packages/alicebot-adapter-apscheduler/alicebot/adapter/apscheduler/__init__.py @@ -18,7 +18,7 @@ from alicebot.adapter import Adapter from alicebot.plugin import Plugin -from alicebot.typing import PluginT +from alicebot.typing import AnyPlugin, PluginT from .config import Config from .event import APSchedulerEvent @@ -38,7 +38,7 @@ class APSchedulerAdapter(Adapter[APSchedulerEvent, Config]): Config = Config scheduler: AsyncIOScheduler - plugin_class_to_job: dict[type[Plugin[Any, Any, Any]], Job] + plugin_class_to_job: dict[type[AnyPlugin], Job] @override async def startup(self) -> None: @@ -83,7 +83,7 @@ async def run(self) -> None: async def shutdown(self) -> None: self.scheduler.shutdown() - async def create_event(self, plugin_class: type[Plugin[Any, Any, Any]]) -> None: + async def create_event(self, plugin_class: type[AnyPlugin]) -> None: """创建 `APSchedulerEvent` 事件。 Args: diff --git a/packages/alicebot-adapter-apscheduler/alicebot/adapter/apscheduler/event.py b/packages/alicebot-adapter-apscheduler/alicebot/adapter/apscheduler/event.py index 7806b59..4a9f65c 100644 --- a/packages/alicebot-adapter-apscheduler/alicebot/adapter/apscheduler/event.py +++ b/packages/alicebot-adapter-apscheduler/alicebot/adapter/apscheduler/event.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: import builtins - from alicebot.plugin import Plugin + from alicebot.typing import AnyPlugin from . import APSchedulerAdapter @@ -24,7 +24,7 @@ class APSchedulerEvent(Event["APSchedulerAdapter"]): type: str | None = "apscheduler" if TYPE_CHECKING: - plugin_class: "builtins.type[Plugin[Any, Any, Any]]" + plugin_class: "builtins.type[AnyPlugin]" else: plugin_class: Any diff --git a/tests/fake_adapter.py b/tests/fake_adapter.py index a779613..51efe21 100644 --- a/tests/fake_adapter.py +++ b/tests/fake_adapter.py @@ -1,13 +1,13 @@ import inspect from collections.abc import Awaitable, Callable -from typing import Any, Generic +from typing import Generic from typing_extensions import override from anyio.lowlevel import checkpoint from alicebot import Adapter, Event, MessageEvent from alicebot.plugin import Plugin -from alicebot.typing import ConfigT, StateT +from alicebot.typing import AnyEvent, ConfigT, StateT EventFactory = Callable[ ["FakeAdapter"], @@ -17,11 +17,11 @@ async def allow_schedule_other_tasks() -> None: """让出当前任务,允许其他任务执行。""" - for _ in range(100): + for _ in range(1000): await checkpoint() -class FakeAdapter(Adapter[Event[Any], None]): +class FakeAdapter(Adapter[AnyEvent, None]): """用于测试的适配器。""" name: str = "fake_adapter"