Skip to content

Commit f1151e5

Browse files
committed
allow queueing of messages from any thread
1 parent c962e50 commit f1151e5

File tree

2 files changed

+100
-15
lines changed

2 files changed

+100
-15
lines changed

services/dask-sidecar/src/simcore_service_dask_sidecar/rabbitmq_plugin.py

Lines changed: 94 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import asyncio
22
import logging
3+
import threading
34
from asyncio import AbstractEventLoop
45
from collections.abc import Awaitable
56
from typing import Final
67

78
import distributed
9+
from servicelib.async_utils import cancel_wait_task
810
from servicelib.logging_utils import log_catch, log_context
911
from servicelib.rabbitmq import RabbitMQClient, wait_till_rabbitmq_responsive
12+
from servicelib.rabbitmq._models import RabbitMessage
1013
from settings_library.rabbit import RabbitSettings
1114

1215
from .errors import ConfigurationError
@@ -22,13 +25,40 @@ class RabbitMQPlugin(distributed.WorkerPlugin):
2225
"""Dask Worker Plugin for RabbitMQ integration"""
2326

2427
name = "rabbitmq_plugin"
25-
_loop: AbstractEventLoop | None = None
28+
_main_thread_loop: AbstractEventLoop | None = None
2629
_client: RabbitMQClient | None = None
2730
_settings: RabbitSettings | None = None
31+
_message_queue: asyncio.Queue | None = None
32+
_message_processor: asyncio.Task | None = None
2833

2934
def __init__(self, settings: RabbitSettings):
3035
self._settings = settings
3136

37+
async def _process_messages(self) -> None:
38+
"""Process messages from worker threads in the main thread"""
39+
assert self._message_queue is not None # nosec
40+
assert self._client is not None # nosec
41+
42+
_logger.info("Starting message processor for RabbitMQ")
43+
try:
44+
while True:
45+
# Get message from queue
46+
exchange_name, message_data = await self._message_queue.get()
47+
48+
try:
49+
# Publish to RabbitMQ
50+
await self._client.publish(exchange_name, message_data)
51+
except Exception as e:
52+
_logger.exception("Failed to publish message: %s", str(e))
53+
finally:
54+
# Mark task as done
55+
self._message_queue.task_done()
56+
except asyncio.CancelledError:
57+
_logger.info("RabbitMQ message processor shutting down")
58+
raise
59+
except Exception:
60+
_logger.exception("Unexpected error in RabbitMQ message processor")
61+
3262
def setup(self, worker: distributed.Worker) -> Awaitable[None]:
3363
"""Called when the plugin is attached to a worker"""
3464

@@ -39,17 +69,30 @@ async def _() -> None:
3969
)
4070
return
4171

72+
if threading.current_thread() is threading.main_thread():
73+
_logger.info(
74+
"RabbitMQ client plugin setup is in the main thread! That is good."
75+
)
76+
else:
77+
msg = "RabbitMQ client plugin setup is not the main thread!"
78+
raise ConfigurationError(msg=msg)
79+
4280
with log_context(
4381
_logger,
4482
logging.INFO,
4583
f"RabbitMQ client initialization for worker {worker.address}",
4684
):
47-
self._loop = asyncio.get_event_loop()
85+
self._main_thread_loop = asyncio.get_event_loop()
4886
await wait_till_rabbitmq_responsive(self._settings.dsn)
4987
self._client = RabbitMQClient(
5088
client_name="dask-sidecar", settings=self._settings
5189
)
5290

91+
self._message_queue = asyncio.Queue()
92+
self._message_processor = asyncio.create_task(
93+
self._process_messages(), name="rabbit_message_processor"
94+
)
95+
5396
return _()
5497

5598
def teardown(self, worker: distributed.Worker) -> Awaitable[None]:
@@ -61,17 +104,32 @@ async def _() -> None:
61104
logging.INFO,
62105
f"RabbitMQ client teardown for worker {worker.address}",
63106
):
64-
if self._client:
65-
current_loop = asyncio.get_event_loop()
66-
if self._loop != current_loop:
67-
_logger.warning(
68-
"RabbitMQ client is de-activated (loop mismatch)"
69-
)
70-
assert self._loop # nosec
107+
if not self._client:
108+
return
109+
if threading.current_thread() is threading.main_thread():
110+
_logger.info(
111+
"RabbitMQ client plugin setup is in the main thread! That is good."
112+
)
113+
else:
114+
_logger.warning(
115+
"RabbitMQ client plugin setup is not the main thread!"
116+
)
117+
118+
# Cancel the message processor task
119+
if self._message_processor:
71120
with log_catch(_logger, reraise=False):
72-
await asyncio.wait_for(self._client.close(), timeout=5.0)
121+
await cancel_wait_task(self._message_processor, max_delay=5)
122+
self._message_processor = None
123+
124+
# close client
125+
current_loop = asyncio.get_event_loop()
126+
if self._main_thread_loop != current_loop:
127+
_logger.warning("RabbitMQ client is de-activated (loop mismatch)")
128+
assert self._main_thread_loop # nosec
129+
with log_catch(_logger, reraise=False):
130+
await asyncio.wait_for(self._client.close(), timeout=5.0)
73131

74-
self._client = None
132+
self._client = None
75133

76134
return _()
77135

@@ -81,12 +139,35 @@ def get_client(self) -> RabbitMQClient:
81139
raise ConfigurationError(msg=_RABBITMQ_CONFIGURATION_ERROR)
82140
return self._client
83141

142+
async def publish_message_from_any_thread(
143+
self, exchange_name: str, message_data: RabbitMessage
144+
) -> None:
145+
"""Enqueue a message to be published to RabbitMQ from any thread"""
146+
assert self._message_queue # nosec
147+
148+
if threading.current_thread() is threading.main_thread():
149+
# If we're in the main thread, add directly to the queue
150+
await self._message_queue.put((exchange_name, message_data))
151+
return
152+
153+
# If we're in a worker thread, we need to use a different approach
154+
assert self._main_thread_loop # nosec
155+
156+
# Create a Future in the main thread's event loop
157+
future = asyncio.run_coroutine_threadsafe(
158+
self._message_queue.put((exchange_name, message_data)),
159+
self._main_thread_loop,
160+
)
161+
162+
# waiting here is quick, just queueing
163+
future.result()
164+
84165

85-
def get_rabbitmq_client(worker: distributed.Worker) -> RabbitMQClient:
166+
def get_rabbitmq_client(worker: distributed.Worker) -> RabbitMQPlugin:
86167
"""Returns the RabbitMQ client or raises an error if not available"""
87168
if not worker.plugins:
88169
raise ConfigurationError(msg=_RABBITMQ_CONFIGURATION_ERROR)
89170
rabbitmq_plugin = worker.plugins.get(RabbitMQPlugin.name)
90171
if not isinstance(rabbitmq_plugin, RabbitMQPlugin):
91172
raise ConfigurationError(msg=_RABBITMQ_CONFIGURATION_ERROR)
92-
return rabbitmq_plugin.get_client()
173+
return rabbitmq_plugin

services/dask-sidecar/src/simcore_service_dask_sidecar/utils/dask.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ async def publish_logs(
9797
messages=[message],
9898
log_level=log_level,
9999
)
100-
await rabbitmq_client.publish(base_message.channel_name, base_message)
100+
await rabbitmq_client.publish_message_from_any_thread(
101+
base_message.channel_name, base_message
102+
)
101103
if self.task_owner.has_parent:
102104
assert self.task_owner.parent_project_id # nosec
103105
assert self.task_owner.parent_node_id # nosec
@@ -108,7 +110,9 @@ async def publish_logs(
108110
messages=[message],
109111
log_level=log_level,
110112
)
111-
await rabbitmq_client.publish(parent_message.channel_name, base_message)
113+
await rabbitmq_client.publish_message_from_any_thread(
114+
parent_message.channel_name, base_message
115+
)
112116

113117
_logger.log(log_level, message)
114118

0 commit comments

Comments
 (0)