Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@


class PluginInterface:
"""
Interface for a plugin. This class runs in the main bot process and
communicates with the plugin process over IPC.
"""

def __init__(self, name, bot, pid):

logging.info("PluginInterface.__init__ %s", "ipc://ipc_plugin_" + name)
Expand Down Expand Up @@ -109,7 +114,7 @@ async def run(self):
self._recieve(await self._socket_plugin.recv_json())

def __getattr__(self, name):
if name in ["started", "update"]:
if name in ["started", "update", "shutdown"]:

def call(*args, **kwarg):
self._call(name, *args)
Expand All @@ -118,6 +123,23 @@ def call(*args, **kwarg):
else:
raise AttributeError(self, name)

async def shutdown_plugin(self):
self.shutdown()
for _ in range(5):
await asyncio.sleep(0.1)
try:
pid, status = os.waitpid(self.pid, os.WNOHANG)
if pid != 0:
return
except ChildProcessError:
return

logging.error("Plugin %s did not shut down cleanly, sending SIGTERM", self.name)
print(
f"Error: Plugin {self.name} did not shut down cleanly, sending SIGTERM", file=sys.stderr
)
os.kill(self.pid, signal.SIGTERM)


class Bot:
def __init__(self, temp_folder):
Expand All @@ -135,6 +157,14 @@ def __init__(self, temp_folder):
self.servers = dict()
self.temp_folder = temp_folder

async def unload_plugin(self, name):
plugin = next((p for p in self.plugins if p.name == name), None)
if plugin:
await plugin.shutdown_plugin()
self.plugins.remove(plugin)
return True
return False

async def reconnect(self, connection):
while not connection.is_connected():
logging.error("Waiting 30 seconds to reconnect")
Expand Down Expand Up @@ -261,8 +291,11 @@ async def load_plugins():
logging.exception("Bot.run aborted")

# Request termination of plugins
for plugin in self.plugins:
os.kill(plugin.pid, signal.SIGTERM)
async def unload_all_plugins():
tasks = [self.unload_plugin(plugin.name) for plugin in list(self.plugins)]
await asyncio.gather(*tasks, return_exceptions=True)

self.loop.run_until_complete(unload_all_plugins())

self.loop.close()
command_line.wait_until_done()
Expand Down
17 changes: 10 additions & 7 deletions command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,17 @@ def do_reload_plugin(self, arg):
def do_unload_plugin(self, arg):
"""Unload a plugin. Usage: unload_plugin <plugin_name>"""
name = arg.strip()
plugin = next((p for p in self.bot.plugins if p.name == name), None)
if not plugin:
print(f"Plugin '{name}' not found.")
return

async def _unload():
if await self.bot.unload_plugin(name):
print(f"Plugin '{name}' unloaded.")
else:
print(f"Plugin '{name}' not found.")

# Schedule _unload on the bot's main asyncio loop
future = asyncio.run_coroutine_threadsafe(_unload(), self.bot.loop)
try:
os.kill(plugin.pid, signal.SIGTERM)
self.bot.plugins.remove(plugin)
print(f"Plugin '{name}' unloaded.")
future.result()
except Exception as e:
print(f"Error unloading plugin '{name}': {e}")

Expand Down
11 changes: 10 additions & 1 deletion plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@


class Plugin:
"""
Base class for plugins. This class runs in its own separate process,
spawned by the main bot process.
"""

def __init__(self, name):

locale.setlocale(locale.LC_ALL, "")
Expand Down Expand Up @@ -59,7 +64,7 @@ def __init__(self, name):

def _recieve(self, data):
func_name = data["function"]
if func_name.startswith("on_") or func_name in ["started", "update"]:
if func_name.startswith("on_") or func_name in ["started", "update", "shutdown"]:
try:
func = getattr(self, func_name)
except AttributeError as e:
Expand All @@ -70,6 +75,10 @@ def _recieve(self, data):
else:
logging.warning("Unsupported call to plugin function with name " + func_name)

def shutdown(self):
logging.info("Plugin.shutdown")
asyncio.get_event_loop().stop()

def _call(self, function, *args):
logging.info("Plugin.call %s", self.threading_data.__dict__)
socket = self.threading_data.call_socket
Expand Down
56 changes: 53 additions & 3 deletions test/test_bot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import irc.client_aio
import signal
import os

import unittest
import tempfile
Expand Down Expand Up @@ -94,14 +95,12 @@ def test_reconnect_loop(self, mock_settings, mock_sleep):
@patch("bot.PluginInterface")
@patch("irc.client_aio.AioReactor")
@patch("os.spawnvpe", return_value=1234)
@patch("os.kill") # temporary until we have graceful shutdowns of plugins
@patch("bot.irc.connection.AioFactory")
@patch("bot.CommandLine")
def test_run(
self,
mock_command_line,
mock_factory,
mock_kill,
mock_spawnvpe,
mock_reactor,
mock_plugin_interface,
Expand All @@ -116,6 +115,8 @@ def test_run(

plugin_interface_instance = MagicMock()
plugin_interface_instance.pid = 666
plugin_interface_instance.name = "test_plugin"
plugin_interface_instance.shutdown_plugin = AsyncMock()
mock_plugin_interface.return_value = plugin_interface_instance

bot = self.create_bot(mock_settings)
Expand All @@ -124,7 +125,7 @@ def test_run(
bot.run()

mock_plugin_interface.assert_called_once() # Make sure we loaded one pluign
mock_kill.assert_called_with(666, signal.SIGTERM) # Make sure the plugin was destroyed
plugin_interface_instance.shutdown_plugin.assert_called_once() # Make sure the plugin was gracefully shut down
reactor_instance.server.assert_called_once() # Make sure we create the server
server_mock.connect.assert_called_once() # Make sure we try to connect to the server
reactor_instance.process_forever.assert_called_once() # Make sure the reactor runs forever
Expand All @@ -144,3 +145,52 @@ def test_plugin_init(self, mock_settings, mock_context):
plugin = PluginInterface("testplugin", bot, 123)
mock_socket.bind.assert_called()
bot.plugin_started.assert_called_with(plugin)

@patch("bot.zmq.asyncio.Context")
@patch("bot.settings")
@patch("bot.os.waitpid")
@patch("bot.os.kill")
@patch("bot.asyncio.sleep", new_callable=AsyncMock)
async def test_shutdown_plugin_clean(
self, mock_sleep, mock_kill, mock_waitpid, mock_settings, mock_context
):
bot = MagicMock()
mock_socket = MagicMock()
mock_context.return_value.socket.return_value = mock_socket

plugin = PluginInterface("testplugin", bot, 123)
# Mocking the _call method to avoid sending real zmq messages
plugin._call = MagicMock()

# Simulate process exiting cleanly
mock_waitpid.return_value = (123, 0)

await plugin.shutdown_plugin()

plugin._call.assert_called_with("shutdown")
mock_waitpid.assert_called_with(123, os.WNOHANG)
mock_kill.assert_not_called()

@patch("bot.zmq.asyncio.Context")
@patch("bot.settings")
@patch("bot.os.waitpid")
@patch("bot.os.kill")
@patch("bot.asyncio.sleep", new_callable=AsyncMock)
async def test_shutdown_plugin_timeout(
self, mock_sleep, mock_kill, mock_waitpid, mock_settings, mock_context
):
bot = MagicMock()
mock_socket = MagicMock()
mock_context.return_value.socket.return_value = mock_socket

plugin = PluginInterface("testplugin", bot, 123)
plugin._call = MagicMock()

# Simulate process not exiting
mock_waitpid.return_value = (0, 0)

await plugin.shutdown_plugin()

plugin._call.assert_called_with("shutdown")
self.assertEqual(mock_waitpid.call_count, 5)
mock_kill.assert_called_with(123, signal.SIGTERM)
8 changes: 8 additions & 0 deletions test/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,11 @@ async def run_once():

asyncio.run(run_once())
self.plugin.on_message.assert_called_once_with("test_param")

def test_shutdown_stops_loop(self):
with patch("asyncio.get_event_loop") as mock_get_loop:
mock_loop = MagicMock()
mock_get_loop.return_value = mock_loop
self.plugin.shutdown = Plugin.shutdown.__get__(self.plugin, DummyPlugin)
self.plugin.shutdown()
mock_loop.stop.assert_called_once()
Loading