diff --git a/alicebot/bot.py b/alicebot/bot.py index 8385e89..b5fbad3 100644 --- a/alicebot/bot.py +++ b/alicebot/bot.py @@ -236,17 +236,18 @@ async def _run(self) -> None: await self._should_exit.wait() finally: - for _adapter in self.adapters: - for adapter_shutdown_hook_func in self._adapter_shutdown_hooks: - await adapter_shutdown_hook_func(_adapter) - await _adapter.shutdown() + with anyio.CancelScope(shield=True): + for _adapter in self.adapters: + for adapter_shutdown_hook_func in self._adapter_shutdown_hooks: + await adapter_shutdown_hook_func(_adapter) + await _adapter.shutdown() - for bot_exit_hook_func in self._bot_exit_hooks: - await bot_exit_hook_func(self) + for bot_exit_hook_func in self._bot_exit_hooks: + await bot_exit_hook_func(self) - self.adapters.clear() - self.plugins_priority_dict.clear() - self._module_path_finder.path.clear() + self.adapters.clear() + self.plugins_priority_dict.clear() + self._module_path_finder.path.clear() def _remove_plugin_by_path( self, file: Path diff --git a/tests/test_hook.py b/tests/test_hook.py index 6cb177b..1448e1b 100644 --- a/tests/test_hook.py +++ b/tests/test_hook.py @@ -2,7 +2,9 @@ from typing import Any +from anyio.lowlevel import checkpoint from fake_adapter import fake_adapter_class_factory, fake_message_event_factor +from pytest_mock import MockerFixture from alicebot import Adapter, Bot, Event @@ -51,3 +53,23 @@ async def event_postprocessor_hook(_event: Event[Any]) -> None: "adapter_shutdown_hook", "bot_exit_hook", ] + + +def test_multiple_bot_exit_hook(mocker: MockerFixture) -> None: + bot = Bot() + mock = mocker.AsyncMock() + + @bot.bot_exit_hook + async def bot_exit_hook1(_bot: Bot) -> None: + await checkpoint() + await mock() + + @bot.bot_exit_hook + async def bot_exit_hook2(_bot: Bot) -> None: + await checkpoint() + await mock() + + bot.load_adapters(fake_adapter_class_factory(fake_message_event_factor)) + bot.run() + + assert mock.call_count == 2