Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
// Pytest
"python.testing.pytestArgs": ["tests"],
"python.testing.pytestArgs": ["-s"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
5 changes: 5 additions & 0 deletions alicebot/adapter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ async def get(
event_type: None = None,
max_try_times: Optional[int] = None,
timeout: Optional[Union[int, float]] = None,
to_thread: bool = False,
) -> EventT: ...

@overload
Expand All @@ -123,6 +124,7 @@ async def get(
event_type: type[_EventT],
max_try_times: Optional[int] = None,
timeout: Optional[Union[int, float]] = None,
to_thread: bool = False,
) -> _EventT: ...

@final
Expand All @@ -133,6 +135,7 @@ async def get(
event_type: Any = None,
max_try_times: Optional[int] = None,
timeout: Optional[Union[int, float]] = None,
to_thread: bool = False,
) -> Event[Any]:
"""获取满足指定条件的的事件,协程会等待直到适配器接收到满足条件的事件、超过最大事件数或超时。

Expand All @@ -147,6 +150,7 @@ async def get(
event_type: 当指定时,只接受指定类型的事件,先于 func 条件生效。默认为 `None`。
max_try_times: 最大事件数。
timeout: 超时时间。
to_thread: 是否在独立的线程中运行同步函数。仅当 func 为同步函数时生效。默认为 `False`。

Returns:
返回满足 func 条件的事件。
Expand All @@ -160,4 +164,5 @@ async def get(
adapter_type=type(self),
max_try_times=max_try_times,
timeout=timeout,
to_thread=to_thread,
)
201 changes: 94 additions & 107 deletions alicebot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import signal
import sys
import threading
import time
from collections import defaultdict
from collections.abc import Awaitable
from contextlib import AsyncExitStack
Expand All @@ -18,28 +17,23 @@

import anyio
import structlog
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError, create_model

from alicebot.adapter import Adapter
from alicebot.config import AdapterConfig, ConfigModel, MainConfig, PluginConfig
from alicebot.dependencies import solve_dependencies
from alicebot.event import Event, EventHandleOption
from alicebot.exceptions import (
GetEventTimeout,
LoadModuleError,
SkipException,
StopException,
)
from alicebot.exceptions import LoadModuleError, SkipException, StopException
from alicebot.matcher import EventMatcher
from alicebot.plugin import Plugin, PluginLoadType
from alicebot.typing import AdapterHook, AdapterT, BotHook, EventHook, EventT
from alicebot.utils import (
ModulePathFinder,
async_map,
get_classes_from_module_name,
is_config_class,
samefile,
wrap_get_func,
)

if sys.version_info >= (3, 11): # pragma: no cover
Expand Down Expand Up @@ -79,8 +73,7 @@ class Bot:

_event_send_stream: MemoryObjectSendStream[EventHandleOption] # pyright: ignore[reportUninitializedInstanceVariable]
_event_receive_stream: MemoryObjectReceiveStream[EventHandleOption] # pyright: ignore[reportUninitializedInstanceVariable]
_condition: anyio.Condition # 用于处理 get 的 Condition # pyright: ignore[reportUninitializedInstanceVariable]
_current_event: Optional[Event[Any]] # 当前待处理的 Event
_event_matchers: list[EventMatcher] # pyright: ignore[reportUninitializedInstanceVariable]

_should_exit: anyio.Event # 机器人是否应该进入准备退出状态 # pyright: ignore[reportUninitializedInstanceVariable]
_restart_flag: bool # 重新启动标志
Expand Down Expand Up @@ -135,7 +128,6 @@ def __init__(
self.global_state = {}

self.adapters = []
self._current_event = None
self._restart_flag = False
self._module_path_finder = ModulePathFinder()
self._raw_config_dict = {}
Expand Down Expand Up @@ -200,7 +192,7 @@ async def run_async(self) -> None:
async def _init(self) -> None:
"""初始化 AliceBot。"""
self._should_exit = anyio.Event()
self._condition = anyio.Condition()
self._event_matchers = []
self._event_send_stream, self._event_receive_stream = (
anyio.create_memory_object_stream(
max_buffer_size=self.config.bot.event_queue_size
Expand Down Expand Up @@ -504,74 +496,84 @@ async def handle_event(
async def _handle_event_receive(self) -> None:
async with anyio.create_task_group() as tg, self._event_receive_stream:
async for current_event, handle_get in self._event_receive_stream:
if handle_get:
await tg.start(self._handle_event_wait_condition)
async with self._condition:
self._current_event = current_event
self._condition.notify_all()
else:
tg.start_soon(self._handle_event, current_event)
tg.start_soon(self._handle_event, current_event, handle_get)

async def _handle_event(self, current_event: Event[Any], handle_get: bool) -> None:
async with anyio.create_task_group() as tg:
if handle_get:
event_handled = False
new_event_matchers: list[EventMatcher] = []
async for event_matcher, result in async_map(
tg,
lambda x: x.run(current_event),
self._event_matchers.copy(),
):
if result is None:
# 当前 event_matcher 已经失效,什么都不做
pass
elif result is True:
# 当前 event_matcher 成功匹配事件,设置 event_handled 为 True
event_handled = True
elif result is False:
# 当前 event_matcher 未成功匹配事件,将其放回队列中,等待下次处理
new_event_matchers.append(event_matcher)
self._event_matchers = new_event_matchers
if event_handled:
return

for event_preprocessor_hook_func in self._event_preprocessor_hooks:
await event_preprocessor_hook_func(current_event)

for plugin_priority in sorted(self.plugins_priority_dict.keys()):
logger.debug("Checking for matching plugins", priority=plugin_priority)
stop = False
async for _plugin_class, should_stop in async_map(
tg,
lambda x: self._run_plugin(x, current_event),
self.plugins_priority_dict[plugin_priority],
):
stop = stop or should_stop
if stop:
break

async def _handle_event_wait_condition(
self, *, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORED
) -> None:
async with self._condition:
task_status.started()
await self._condition.wait()
assert self._current_event is not None
current_event = self._current_event
await self._handle_event(current_event)

async def _handle_event(self, current_event: Event[Any]) -> None:
if current_event.__handled__:
return
for event_postprocessor_hook_func in self._event_postprocessor_hooks:
await event_postprocessor_hook_func(current_event)

for _hook_func in self._event_preprocessor_hooks:
await _hook_func(current_event)
logger.info("Event Finished")

for plugin_priority in sorted(self.plugins_priority_dict.keys()):
logger.debug("Checking for matching plugins", priority=plugin_priority)
stop = False
for plugin in self.plugins_priority_dict[plugin_priority]:
async def _run_plugin(
self, plugin_class: type[Plugin[Any, Any, Any]], event: Event[Any]
) -> bool:
try:
async with AsyncExitStack() as stack:
plugin_instance = await solve_dependencies(
plugin_class,
use_cache=True,
stack=stack,
dependency_cache={Bot: self, Event: event},
)
if plugin_instance.name not in self.plugin_state:
plugin_state = plugin_instance.__init_state__()
if plugin_state is not None:
self.plugin_state[plugin_instance.name] = plugin_state
try:
async with AsyncExitStack() as stack:
_plugin = await solve_dependencies(
plugin,
use_cache=True,
stack=stack,
dependency_cache={
Bot: self,
Event: current_event,
},
if await plugin_instance.rule():
logger.info(
"Event will be handled by plugin", plugin=plugin_instance
)
if _plugin.name not in self.plugin_state:
plugin_state = _plugin.__init_state__()
if plugin_state is not None:
self.plugin_state[_plugin.name] = plugin_state
if await _plugin.rule():
logger.info(
"Event will be handled by plugin", plugin=_plugin
)
try:
await _plugin.handle()
finally:
if _plugin.block:
stop = True
except SkipException:
# 插件要求跳过自身继续当前事件传播
continue
except StopException:
# 插件要求停止当前事件传播
stop = True
except Exception:
logger.exception("Exception in plugin", plugin=plugin)
if stop:
break

for _hook_func in self._event_postprocessor_hooks:
await _hook_func(current_event)

logger.info("Event Finished")
await plugin_instance.handle()
finally:
if plugin_instance.block:
raise StopException
except SkipException:
# 插件要求跳过自身继续当前事件传播
pass
except StopException:
# 插件要求停止当前事件传播
return True
except Exception:
logger.exception("Exception in plugin", plugin=plugin_class)
return False

@overload
async def get(
Expand All @@ -582,6 +584,7 @@ async def get(
adapter_type: None = None,
max_try_times: Optional[int] = None,
timeout: Optional[Union[int, float]] = None,
to_thread: bool = False,
) -> Event[Any]: ...

@overload
Expand All @@ -593,6 +596,7 @@ async def get(
adapter_type: type[Adapter[EventT, Any]],
max_try_times: Optional[int] = None,
timeout: Optional[Union[int, float]] = None,
to_thread: bool = False,
) -> EventT: ...

@overload
Expand All @@ -604,6 +608,7 @@ async def get(
adapter_type: Optional[type[Adapter[Any, Any]]] = None,
max_try_times: Optional[int] = None,
timeout: Optional[Union[int, float]] = None,
to_thread: bool = False,
) -> EventT: ...

async def get(
Expand All @@ -614,6 +619,7 @@ async def get(
adapter_type: Optional[type[Adapter[Any, Any]]] = None,
max_try_times: Optional[int] = None,
timeout: Optional[Union[int, float]] = None,
to_thread: bool = False,
) -> Event[Any]:
"""获取满足指定条件的的事件,协程会等待直到适配器接收到满足条件的事件、超过最大事件数或超时。

Expand All @@ -625,44 +631,25 @@ async def get(
adapter_type: 当指定时,只接受指定适配器产生的事件,先于 func 条件生效。默认为 `None`。
max_try_times: 最大事件数。
timeout: 超时时间。
to_thread: 是否在独立的线程中运行同步函数。仅当 func 为同步函数时生效。默认为 `False`。

Returns:
返回满足 `func` 条件的事件。

Raises:
GetEventTimeout: 超过最大事件数或超时。
"""
_func = wrap_get_func(func, event_type=event_type, adapter_type=adapter_type)

try_times = 0
start_time = time.time()
while not self._should_exit.is_set():
if max_try_times is not None and try_times > max_try_times:
break
if timeout is not None and time.time() - start_time > timeout:
break

async with self._condition:
if timeout is None:
await self._condition.wait()
else:
try:
with anyio.fail_after(start_time + timeout - time.time()):
await self._condition.wait()
except TimeoutError:
break

if (
self._current_event is not None
and not self._current_event.__handled__
and await _func(self._current_event)
):
self._current_event.__handled__ = True
return self._current_event

try_times += 1

raise GetEventTimeout
event_matcher = EventMatcher(
func,
bot=self,
event_type=event_type,
adapter_type=adapter_type,
max_try_times=max_try_times,
timeout=timeout,
to_thread=to_thread,
)
self._event_matchers.append(event_matcher)
return await event_matcher.wait()

def _load_plugin_class(
self,
Expand Down
2 changes: 0 additions & 2 deletions alicebot/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ class Event(ABC, BaseModel, Generic[AdapterT]):
Attributes:
adapter: 产生当前事件的适配器对象。
type: 事件类型。
__handled__: 表示事件是否被处理过了,用于适配器处理。警告:请勿手动更改此属性的值。
"""

model_config = ConfigDict(extra="allow")
Expand All @@ -30,7 +29,6 @@ class Event(ABC, BaseModel, Generic[AdapterT]):
else:
adapter: Any
type: Optional[str]
__handled__: bool = False

@override
def __str__(self) -> str:
Expand Down
Loading
Loading