Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions alicebot/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,17 +236,18 @@ async def _run(self) -> None:

await self._should_exit.wait()
finally:
for _adapter in self.adapters:
for adapter_shutdown_hook_func in self._adapter_shutdown_hooks:
await adapter_shutdown_hook_func(_adapter)
await _adapter.shutdown()
with anyio.CancelScope(shield=True):
for _adapter in self.adapters:
for adapter_shutdown_hook_func in self._adapter_shutdown_hooks:
await adapter_shutdown_hook_func(_adapter)
await _adapter.shutdown()

for bot_exit_hook_func in self._bot_exit_hooks:
await bot_exit_hook_func(self)
for bot_exit_hook_func in self._bot_exit_hooks:
await bot_exit_hook_func(self)

self.adapters.clear()
self.plugins_priority_dict.clear()
self._module_path_finder.path.clear()
self.adapters.clear()
self.plugins_priority_dict.clear()
self._module_path_finder.path.clear()

def _remove_plugin_by_path(
self, file: Path
Expand Down
22 changes: 22 additions & 0 deletions tests/test_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from typing import Any

from anyio.lowlevel import checkpoint
from fake_adapter import fake_adapter_class_factory, fake_message_event_factor
from pytest_mock import MockerFixture

from alicebot import Adapter, Bot, Event

Expand Down Expand Up @@ -51,3 +53,23 @@ async def event_postprocessor_hook(_event: Event[Any]) -> None:
"adapter_shutdown_hook",
"bot_exit_hook",
]


def test_multiple_bot_exit_hook(mocker: MockerFixture) -> None:
bot = Bot()
mock = mocker.AsyncMock()

@bot.bot_exit_hook
async def bot_exit_hook1(_bot: Bot) -> None:
await checkpoint()
await mock()

@bot.bot_exit_hook
async def bot_exit_hook2(_bot: Bot) -> None:
await checkpoint()
await mock()

bot.load_adapters(fake_adapter_class_factory(fake_message_event_factor))
bot.run()

assert mock.call_count == 2
Loading