Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions alicebot/adapter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@

import structlog

from alicebot.event import Event
from alicebot.typing import ConfigT, EventT
from alicebot.typing import AnyEvent, ConfigT, EventT
from alicebot.utils import is_config_class

if TYPE_CHECKING:
Expand All @@ -33,7 +32,7 @@
__import__("pkg_resources").declare_namespace(__name__)


_EventT = TypeVar("_EventT", bound="Event[Any]", default="Event[Any]")
_EventT = TypeVar("_EventT", bound=AnyEvent, default=AnyEvent)


class Adapter(Generic[EventT, ConfigT], ABC):
Expand Down Expand Up @@ -133,7 +132,7 @@ async def get(
max_try_times: int | None = None,
timeout: float | None = None,
to_thread: bool = False,
) -> Event[Any]:
) -> AnyEvent:
"""获取满足指定条件的的事件,协程会等待直到适配器接收到满足条件的事件、超过最大事件数或超时。

类似 `Bot` 类的 `get()` 方法,但是隐含了判断产生事件的适配器是本适配器。
Expand Down
67 changes: 36 additions & 31 deletions alicebot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,16 @@
from alicebot.exceptions import LoadModuleError, SkipException, StopException
from alicebot.matcher import EventMatcher
from alicebot.plugin import Plugin, PluginLoadType
from alicebot.typing import AdapterHook, AdapterT, BotHook, EventHook, EventT
from alicebot.typing import (
AdapterHook,
AdapterT,
AnyAdapter,
AnyEvent,
AnyPlugin,
BotHook,
EventHook,
EventT,
)
from alicebot.utils import (
ModulePathFinder,
async_map,
Expand Down Expand Up @@ -61,8 +70,8 @@ class Bot:
"""

config: MainConfig
adapters: list[Adapter[Any, Any]]
plugins_priority_dict: dict[int, list[type[Plugin[Any, Any, Any]]]]
adapters: list[AnyAdapter]
plugins_priority_dict: dict[int, list[type[AnyPlugin]]]
plugin_state: dict[str, Any]
global_state: dict[Any, Any]

Expand All @@ -82,13 +91,13 @@ class Bot:
_handle_signals: bool # 处理信号

_extend_plugins: list[
type[Plugin[Any, Any, Any]] | str | Path
type[AnyPlugin] | str | Path
] # 使用 load_plugins() 方法程序化加载的插件列表
_extend_plugin_dirs: list[
Path
] # 使用 load_plugins_from_dirs() 方法程序化加载的插件路径列表
_extend_adapters: list[
type[Adapter[Any, Any]] | str
type[AnyAdapter] | str
] # 使用 load_adapter() 方法程序化加载的适配器列表
_bot_run_hooks: list[BotHook]
_bot_exit_hooks: list[BotHook]
Expand Down Expand Up @@ -146,7 +155,7 @@ def __init__(
sys.meta_path.insert(0, self._module_path_finder)

@property
def plugins(self) -> list[type[Plugin[Any, Any, Any]]]:
def plugins(self) -> list[type[AnyPlugin]]:
"""当前已经加载的插件的列表。"""
return list(chain(*self.plugins_priority_dict.values()))

Expand Down Expand Up @@ -241,9 +250,9 @@ async def _run(self) -> None:

def _remove_plugin_by_path(
self, file: Path
) -> list[type[Plugin[Any, Any, Any]]]: # pragma: no cover
) -> list[type[AnyPlugin]]: # pragma: no cover
"""根据路径删除已加载的插件。"""
removed_plugins: list[type[Plugin[Any, Any, Any]]] = []
removed_plugins: list[type[AnyPlugin]] = []
for plugins in self.plugins_priority_dict.values():
_removed_plugins = list(
filter(
Expand Down Expand Up @@ -341,7 +350,7 @@ def _update_config(self) -> None:
"""更新 config,合并入来自 Plugin 和 Adapter 的 Config。"""

def update_config(
source: list[type[Plugin[Any, Any, Any]]] | list[Adapter[Any, Any]],
source: list[type[AnyPlugin]] | list[AnyAdapter],
name: str,
base: type[ConfigModel],
) -> tuple[type[ConfigModel], ConfigModel]:
Expand Down Expand Up @@ -460,7 +469,7 @@ async def _handle_should_exit(self, cancel_scope: anyio.CancelScope) -> None:

async def handle_event(
self,
current_event: Event[Any],
current_event: AnyEvent,
*,
handle_get: bool = True,
show_log: bool = True,
Expand Down Expand Up @@ -493,7 +502,7 @@ async def _handle_event_receive(self) -> None:
async for current_event, handle_get in self._event_receive_stream:
tg.start_soon(self._handle_event, current_event, handle_get)

async def _handle_event(self, current_event: Event[Any], handle_get: bool) -> None:
async def _handle_event(self, current_event: AnyEvent, handle_get: bool) -> None:
async with anyio.create_task_group() as tg:
if handle_get:
event_handled = False
Expand Down Expand Up @@ -536,9 +545,7 @@ async def _handle_event(self, current_event: Event[Any], handle_get: bool) -> No

logger.info("Event Finished")

async def _run_plugin(
self, plugin_class: type[Plugin[Any, Any, Any]], event: Event[Any]
) -> bool:
async def _run_plugin(self, plugin_class: type[AnyPlugin], event: AnyEvent) -> bool:
try:
async with AsyncExitStack() as stack:
plugin_instance = await solve_dependencies(
Expand Down Expand Up @@ -573,14 +580,14 @@ async def _run_plugin(
@overload
async def get(
self,
func: Callable[[Event[Any]], bool | Awaitable[bool]] | None = None,
func: Callable[[AnyEvent], bool | Awaitable[bool]] | None = None,
*,
event_type: None = None,
adapter_type: None = None,
max_try_times: int | None = None,
timeout: float | None = None,
to_thread: bool = False,
) -> Event[Any]: ...
) -> AnyEvent: ...

@overload
async def get(
Expand All @@ -600,7 +607,7 @@ async def get(
func: Callable[[EventT], bool | Awaitable[bool]] | None = None,
*,
event_type: type[EventT],
adapter_type: type[Adapter[Any, Any]] | None = None,
adapter_type: type[AnyAdapter] | None = None,
max_try_times: int | None = None,
timeout: float | None = None,
to_thread: bool = False,
Expand All @@ -610,12 +617,12 @@ async def get(
self,
func: Callable[[Any], bool | Awaitable[bool]] | None = None,
*,
event_type: type[Event[Any]] | None = None,
adapter_type: type[Adapter[Any, Any]] | None = None,
event_type: type[AnyEvent] | None = None,
adapter_type: type[AnyAdapter] | None = None,
max_try_times: int | None = None,
timeout: float | None = None,
to_thread: bool = False,
) -> Event[Any]:
) -> AnyEvent:
"""获取满足指定条件的的事件,协程会等待直到适配器接收到满足条件的事件、超过最大事件数或超时。

Args:
Expand Down Expand Up @@ -648,7 +655,7 @@ async def get(

def _load_plugin_class(
self,
plugin_class: type[Plugin[Any, Any, Any]],
plugin_class: type[AnyPlugin],
plugin_load_type: PluginLoadType,
plugin_file_path: str | None,
) -> None:
Expand Down Expand Up @@ -698,7 +705,7 @@ def _load_plugins_from_module_name(

def _load_plugins(
self,
*plugins: type[Plugin[Any, Any, Any]] | str | Path,
*plugins: type[AnyPlugin] | str | Path,
plugin_load_type: PluginLoadType | None = None,
reload: bool = False,
) -> None:
Expand Down Expand Up @@ -772,7 +779,7 @@ def _load_plugins(
except Exception:
logger.exception("Load plugin failed:", plugin=plugin_)

def load_plugins(self, *plugins: type[Plugin[Any, Any, Any]] | str | Path) -> None:
def load_plugins(self, *plugins: type[AnyPlugin] | str | Path) -> None:
"""加载插件。

Args:
Expand Down Expand Up @@ -813,7 +820,7 @@ def load_plugins_from_dirs(self, *dirs: Path) -> None:
self._extend_plugin_dirs.extend(dirs)
self._load_plugins_from_dirs(*dirs)

def _load_adapters(self, *adapters: type[Adapter[Any, Any]] | str) -> None:
def _load_adapters(self, *adapters: type[AnyAdapter] | str) -> None:
"""加载适配器。

Args:
Expand All @@ -823,7 +830,7 @@ def _load_adapters(self, *adapters: type[Adapter[Any, Any]] | str) -> None:
例如:`path.of.adapter`。
"""
for adapter_ in adapters:
adapter_object: Adapter[Any, Any]
adapter_object: AnyAdapter
try:
if isinstance(adapter_, type) and issubclass(adapter_, Adapter):
adapter_object = adapter_(self)
Expand Down Expand Up @@ -857,7 +864,7 @@ def _load_adapters(self, *adapters: type[Adapter[Any, Any]] | str) -> None:
else:
self.adapters.append(adapter_object)

def load_adapters(self, *adapters: type[Adapter[Any, Any]] | str) -> None:
def load_adapters(self, *adapters: type[AnyAdapter] | str) -> None:
"""加载适配器。

Args:
Expand All @@ -870,14 +877,12 @@ def load_adapters(self, *adapters: type[Adapter[Any, Any]] | str) -> None:
self._load_adapters(*adapters)

@overload
def get_adapter(self, adapter: str) -> Adapter[Any, Any]: ...
def get_adapter(self, adapter: str) -> AnyAdapter: ...

@overload
def get_adapter(self, adapter: type[AdapterT]) -> AdapterT: ...

def get_adapter(
self, adapter: str | type[AdapterT]
) -> Adapter[Any, Any] | AdapterT:
def get_adapter(self, adapter: str | type[AdapterT]) -> AnyAdapter | AdapterT:
"""按照名称或适配器类获取已经加载的适配器。

Args:
Expand All @@ -897,7 +902,7 @@ def get_adapter(
return _adapter
raise LookupError(f'Can not find adapter named "{adapter}"')

def get_plugin(self, name: str) -> type[Plugin[Any, Any, Any]]:
def get_plugin(self, name: str) -> type[AnyPlugin]:
"""按照名称获取已经加载的插件类。

Args:
Expand Down
4 changes: 2 additions & 2 deletions alicebot/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pydantic import BaseModel, ConfigDict

from alicebot.typing import AdapterT
from alicebot.typing import AdapterT, AnyEvent

__all__ = ["Event", "EventHandleOption", "MessageEvent"]

Expand Down Expand Up @@ -47,7 +47,7 @@ class EventHandleOption(NamedTuple):
handle_get: 当前事件是否可以被 get 方法捕获。
"""

event: Event[Any]
event: AnyEvent
handle_get: bool


Expand Down
21 changes: 10 additions & 11 deletions alicebot/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
import anyio
import anyio.to_thread

from alicebot.adapter import Adapter
from alicebot.event import Event
from alicebot.exceptions import GetEventTimeout
from alicebot.typing import AnyAdapter, AnyEvent

if TYPE_CHECKING:
from alicebot.bot import Bot
Expand All @@ -26,25 +25,25 @@ class EventMatcher:

func: Callable[[Any], bool | Awaitable[bool]] | None
bot: "Bot"
event_type: type[Event[Any]] | None
adapter_type: type[Adapter[Any, Any]] | None
event_type: type[AnyEvent] | None
adapter_type: type[AnyAdapter] | None
max_try_times: int | None
timeout: int | float | None
to_thread: bool

event: anyio.Event
try_times: int
start_time: float
result: Event[Any] | None
result: "AnyEvent | None"
exception: BaseException | None

def __init__(
self,
func: Callable[[Any], bool | Awaitable[bool]] | None,
*,
bot: "Bot",
event_type: type[Event[Any]] | None,
adapter_type: type[Adapter[Any, Any]] | None,
event_type: type[AnyEvent] | None,
adapter_type: type[AnyAdapter] | None,
max_try_times: int | None,
timeout: float | None,
to_thread: bool,
Expand Down Expand Up @@ -74,7 +73,7 @@ def __init__(
self.result = None
self.exception = None

async def wait(self) -> Event[Any]:
async def wait(self) -> AnyEvent:
"""等待当前事件匹配器直到满足条件或者超时。

Raises:
Expand All @@ -97,7 +96,7 @@ async def wait(self) -> Event[Any]:

raise RuntimeError("Event has no result.") # pragma: no cover

async def run(self, event: Event[Any]) -> bool | None:
async def run(self, event: AnyEvent) -> bool | None:
"""运行 `get()` 函数,检查当前 `get()` 是否成功。

Args:
Expand All @@ -118,7 +117,7 @@ async def run(self, event: Event[Any]) -> bool | None:
return None
return False

async def match(self, event: Event[Any]) -> bool:
async def match(self, event: AnyEvent) -> bool:
"""检查当前事件是否被匹配。

Args:
Expand Down Expand Up @@ -147,7 +146,7 @@ async def match(self, event: Event[Any]) -> bool:
if self.func is None:
return True
if not inspect.iscoroutinefunction(self.func):
func = cast("Callable[[Event[Any]], bool]", self.func)
func = cast("Callable[[AnyEvent], bool]", self.func)
if self.to_thread:
return await anyio.to_thread.run_sync(func, event)
return func(event)
Expand Down
21 changes: 14 additions & 7 deletions alicebot/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypeAlias
from typing_extensions import TypeVar

from alicebot.message import BuildMessageType, MessageSegmentT, MessageT
Expand All @@ -21,6 +21,9 @@
__all__ = [
"AdapterHook",
"AdapterT",
"AnyAdapter",
"AnyEvent",
"AnyPlugin",
"BotHook",
"BuildMessageType",
"ConfigT",
Expand All @@ -32,12 +35,16 @@
"StateT",
]

EventT = TypeVar("EventT", bound="Event[Any]", default="Event[Any]")
AnyEvent: TypeAlias = "Event[Any]"
AnyPlugin: TypeAlias = "Plugin[Any, Any, Any]"
AnyAdapter: TypeAlias = "Adapter[Any, Any]"

EventT = TypeVar("EventT", bound=AnyEvent, default=AnyEvent)
StateT = TypeVar("StateT", default=None)
ConfigT = TypeVar("ConfigT", bound="ConfigModel | None", default=None)
PluginT = TypeVar("PluginT", bound="Plugin[Any, Any, Any]")
AdapterT = TypeVar("AdapterT", bound="Adapter[Any, Any]")
PluginT = TypeVar("PluginT", bound=AnyPlugin)
AdapterT = TypeVar("AdapterT", bound=AnyAdapter)

BotHook = Callable[["Bot"], Awaitable[None]]
AdapterHook = Callable[["Adapter[Any, Any]"], Awaitable[None]]
EventHook = Callable[["Event[Any]"], Awaitable[None]]
BotHook: TypeAlias = Callable[["Bot"], Awaitable[None]]
AdapterHook: TypeAlias = Callable[[AnyAdapter], Awaitable[None]]
EventHook: TypeAlias = Callable[[AnyEvent], Awaitable[None]]
Loading
Loading