11import asyncio
22import logging
3+ import threading
34from asyncio import AbstractEventLoop
45from collections .abc import Awaitable
56from typing import Final
67
78import distributed
9+ from servicelib .async_utils import cancel_wait_task
810from servicelib .logging_utils import log_catch , log_context
911from servicelib .rabbitmq import RabbitMQClient , wait_till_rabbitmq_responsive
12+ from servicelib .rabbitmq ._models import RabbitMessage
1013from settings_library .rabbit import RabbitSettings
1114
1215from .errors import ConfigurationError
@@ -22,13 +25,40 @@ class RabbitMQPlugin(distributed.WorkerPlugin):
2225 """Dask Worker Plugin for RabbitMQ integration"""
2326
2427 name = "rabbitmq_plugin"
25- _loop : AbstractEventLoop | None = None
28+ _main_thread_loop : AbstractEventLoop | None = None
2629 _client : RabbitMQClient | None = None
2730 _settings : RabbitSettings | None = None
31+ _message_queue : asyncio .Queue | None = None
32+ _message_processor : asyncio .Task | None = None
2833
2934 def __init__ (self , settings : RabbitSettings ):
3035 self ._settings = settings
3136
37+ async def _process_messages (self ) -> None :
38+ """Process messages from worker threads in the main thread"""
39+ assert self ._message_queue is not None # nosec
40+ assert self ._client is not None # nosec
41+
42+ _logger .info ("Starting message processor for RabbitMQ" )
43+ try :
44+ while True :
45+ # Get message from queue
46+ exchange_name , message_data = await self ._message_queue .get ()
47+
48+ try :
49+ # Publish to RabbitMQ
50+ await self ._client .publish (exchange_name , message_data )
51+ except Exception as e :
52+ _logger .exception ("Failed to publish message: %s" , str (e ))
53+ finally :
54+ # Mark task as done
55+ self ._message_queue .task_done ()
56+ except asyncio .CancelledError :
57+ _logger .info ("RabbitMQ message processor shutting down" )
58+ raise
59+ except Exception :
60+ _logger .exception ("Unexpected error in RabbitMQ message processor" )
61+
3262 def setup (self , worker : distributed .Worker ) -> Awaitable [None ]:
3363 """Called when the plugin is attached to a worker"""
3464
@@ -39,17 +69,30 @@ async def _() -> None:
3969 )
4070 return
4171
72+ if threading .current_thread () is threading .main_thread ():
73+ _logger .info (
74+ "RabbitMQ client plugin setup is in the main thread! That is good."
75+ )
76+ else :
77+ msg = "RabbitMQ client plugin setup is not the main thread!"
78+ raise ConfigurationError (msg = msg )
79+
4280 with log_context (
4381 _logger ,
4482 logging .INFO ,
4583 f"RabbitMQ client initialization for worker { worker .address } " ,
4684 ):
47- self ._loop = asyncio .get_event_loop ()
85+ self ._main_thread_loop = asyncio .get_event_loop ()
4886 await wait_till_rabbitmq_responsive (self ._settings .dsn )
4987 self ._client = RabbitMQClient (
5088 client_name = "dask-sidecar" , settings = self ._settings
5189 )
5290
91+ self ._message_queue = asyncio .Queue ()
92+ self ._message_processor = asyncio .create_task (
93+ self ._process_messages (), name = "rabbit_message_processor"
94+ )
95+
5396 return _ ()
5497
5598 def teardown (self , worker : distributed .Worker ) -> Awaitable [None ]:
@@ -61,17 +104,32 @@ async def _() -> None:
61104 logging .INFO ,
62105 f"RabbitMQ client teardown for worker { worker .address } " ,
63106 ):
64- if self ._client :
65- current_loop = asyncio .get_event_loop ()
66- if self ._loop != current_loop :
67- _logger .warning (
68- "RabbitMQ client is de-activated (loop mismatch)"
69- )
70- assert self ._loop # nosec
107+ if not self ._client :
108+ return
109+ if threading .current_thread () is threading .main_thread ():
110+ _logger .info (
111+ "RabbitMQ client plugin setup is in the main thread! That is good."
112+ )
113+ else :
114+ _logger .warning (
115+ "RabbitMQ client plugin setup is not the main thread!"
116+ )
117+
118+ # Cancel the message processor task
119+ if self ._message_processor :
71120 with log_catch (_logger , reraise = False ):
72- await asyncio .wait_for (self ._client .close (), timeout = 5.0 )
121+ await cancel_wait_task (self ._message_processor , max_delay = 5 )
122+ self ._message_processor = None
123+
124+ # close client
125+ current_loop = asyncio .get_event_loop ()
126+ if self ._main_thread_loop != current_loop :
127+ _logger .warning ("RabbitMQ client is de-activated (loop mismatch)" )
128+ assert self ._main_thread_loop # nosec
129+ with log_catch (_logger , reraise = False ):
130+ await asyncio .wait_for (self ._client .close (), timeout = 5.0 )
73131
74- self ._client = None
132+ self ._client = None
75133
76134 return _ ()
77135
@@ -81,12 +139,35 @@ def get_client(self) -> RabbitMQClient:
81139 raise ConfigurationError (msg = _RABBITMQ_CONFIGURATION_ERROR )
82140 return self ._client
83141
142+ async def publish_message_from_any_thread (
143+ self , exchange_name : str , message_data : RabbitMessage
144+ ) -> None :
145+ """Enqueue a message to be published to RabbitMQ from any thread"""
146+ assert self ._message_queue # nosec
147+
148+ if threading .current_thread () is threading .main_thread ():
149+ # If we're in the main thread, add directly to the queue
150+ await self ._message_queue .put ((exchange_name , message_data ))
151+ return
152+
153+ # If we're in a worker thread, we need to use a different approach
154+ assert self ._main_thread_loop # nosec
155+
156+ # Create a Future in the main thread's event loop
157+ future = asyncio .run_coroutine_threadsafe (
158+ self ._message_queue .put ((exchange_name , message_data )),
159+ self ._main_thread_loop ,
160+ )
161+
162+ # waiting here is quick, just queueing
163+ future .result ()
164+
84165
85- def get_rabbitmq_client (worker : distributed .Worker ) -> RabbitMQClient :
166+ def get_rabbitmq_client (worker : distributed .Worker ) -> RabbitMQPlugin :
86167 """Returns the RabbitMQ client or raises an error if not available"""
87168 if not worker .plugins :
88169 raise ConfigurationError (msg = _RABBITMQ_CONFIGURATION_ERROR )
89170 rabbitmq_plugin = worker .plugins .get (RabbitMQPlugin .name )
90171 if not isinstance (rabbitmq_plugin , RabbitMQPlugin ):
91172 raise ConfigurationError (msg = _RABBITMQ_CONFIGURATION_ERROR )
92- return rabbitmq_plugin . get_client ()
173+ return rabbitmq_plugin
0 commit comments