@@ -172,18 +172,9 @@ def __init__(self, *,
 
         self.cert_dir = cert_dir
         self.zmq_context = curvezmq.ClientContext(self.cert_dir)
-        self.task_incoming = self.zmq_context.socket(zmq.DEALER)
-        self.task_incoming.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
-        # Linger is set to 0, so that the manager can exit even when there might be
-        # messages in the pipe
-        self.task_incoming.setsockopt(zmq.LINGER, 0)
-        self.task_incoming.connect(task_q_url)
 
-        self.result_outgoing = self.zmq_context.socket(zmq.DEALER)
-        self.result_outgoing.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
-        self.result_outgoing.setsockopt(zmq.LINGER, 0)
-        self.result_outgoing.connect(result_q_url)
-        logger.info("Manager connected to interchange")
+        self._task_q_url = task_q_url
+        self._result_q_url = result_q_url
 
         self.uid = uid
         self.block_id = block_id
@@ -282,21 +273,23 @@ def create_reg_message(self):
         b_msg = json.dumps(msg).encode('utf-8')
         return b_msg
 
-    def heartbeat_to_incoming(self):
+    @staticmethod
+    def heartbeat_to_incoming(task_incoming: zmq.Socket) -> None:
         """ Send heartbeat to the incoming task queue
         """
         msg = {'type': 'heartbeat'}
         # don't need to dumps and encode this every time - could do as a global on import?
         b_msg = json.dumps(msg).encode('utf-8')
-        self.task_incoming.send(b_msg)
+        task_incoming.send(b_msg)
         logger.debug("Sent heartbeat")
 
-    def drain_to_incoming(self):
+    @staticmethod
+    def drain_to_incoming(task_incoming: zmq.Socket) -> None:
         """ Send heartbeat to the incoming task queue
         """
         msg = {'type': 'drain'}
         b_msg = json.dumps(msg).encode('utf-8')
-        self.task_incoming.send(b_msg)
+        task_incoming.send(b_msg)
         logger.debug("Sent drain")
 
     @wrap_with_logs
@@ -305,13 +298,22 @@ def pull_tasks(self):
         pending task queue
         """
         logger.info("starting")
+
+        # Linger is set to 0, so that the manager can exit even when there might be
+        # messages in the pipe
+        task_incoming = self.zmq_context.socket(zmq.DEALER)
+        task_incoming.setsockopt(zmq.IDENTITY, self.uid.encode('utf-8'))
+        task_incoming.setsockopt(zmq.LINGER, 0)
+        task_incoming.connect(self._task_q_url)
+        logger.info("Manager task pipe connected to interchange")
+
         poller = zmq.Poller()
-        poller.register(self.task_incoming, zmq.POLLIN)
+        poller.register(task_incoming, zmq.POLLIN)
 
         # Send a registration message
         msg = self.create_reg_message()
         logger.debug("Sending registration message: {}".format(msg))
-        self.task_incoming.send(msg)
+        task_incoming.send(msg)
         last_beat = time.time()
         last_interchange_contact = time.time()
         task_recv_counter = 0
@@ -338,12 +340,12 @@ def pull_tasks(self):
                                                                       pending_task_count))
 
             if time.time() >= last_beat + self.heartbeat_period:
-                self.heartbeat_to_incoming()
+                self.heartbeat_to_incoming(task_incoming)
                 last_beat = time.time()
 
             if time.time() > self.drain_time:
                 logger.info("Requesting drain")
-                self.drain_to_incoming()
+                self.drain_to_incoming(task_incoming)
                 # This will start the pool draining...
                 # Drained exit behaviour does not happen here. It will be
                 # driven by the interchange sending a DRAINED_CODE message.
@@ -355,8 +357,8 @@ def pull_tasks(self):
             poll_duration_s = max(0, next_interesting_event_time - time.time())
             socks = dict(poller.poll(timeout=poll_duration_s * 1000))
 
-            if self.task_incoming in socks and socks[self.task_incoming] == zmq.POLLIN:
-                _, pkl_msg = self.task_incoming.recv_multipart()
+            if socks.get(task_incoming) == zmq.POLLIN:
+                _, pkl_msg = task_incoming.recv_multipart()
                 tasks = pickle.loads(pkl_msg)
                 last_interchange_contact = time.time()
 
@@ -384,12 +386,23 @@ def pull_tasks(self):
                 logger.critical("Exiting")
                 break
 
+        task_incoming.close()
+        logger.info("Exiting")
+
     @wrap_with_logs
     def push_results(self):
         """ Listens on the pending_result_queue and sends out results via zmq
         """
         logger.debug("Starting result push thread")
 
+        # Linger is set to 0, so that the manager can exit even when there might be
+        # messages in the pipe
+        result_outgoing = self.zmq_context.socket(zmq.DEALER)
+        result_outgoing.setsockopt(zmq.IDENTITY, self.uid.encode('utf-8'))
+        result_outgoing.setsockopt(zmq.LINGER, 0)
+        result_outgoing.connect(self._result_q_url)
+        logger.info("Manager result pipe connected to interchange")
+
         push_poll_period = max(10, self.poll_period) / 1000  # push_poll_period must be atleast 10 ms
         logger.debug("push poll period: {}".format(push_poll_period))
 
@@ -418,15 +431,16 @@ def push_results(self):
                 last_beat = time.time()
                 if items:
                     logger.debug(f"Result send: Pushing {len(items)} items")
-                    self.result_outgoing.send_multipart(items)
+                    result_outgoing.send_multipart(items)
                     logger.debug("Result send: Pushed")
                     items = []
                 else:
                     logger.debug("Result send: No items to push")
             else:
                 logger.debug(f"Result send: check condition not met - deferring {len(items)} result items")
 
-        logger.critical("Exiting")
+        result_outgoing.close()
+        logger.info("Exiting")
 
     @wrap_with_logs
     def worker_watchdog(self):
@@ -533,8 +547,6 @@ def start(self):
             self.procs[proc_id].join()
             logger.debug("Worker {} joined successfully".format(self.procs[proc_id]))
 
-        self.task_incoming.close()
-        self.result_outgoing.close()
         self.zmq_context.term()
         delta = time.time() - self._start_time
         logger.info("process_worker_pool ran for {} seconds".format(delta))
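
Not part of the commit, just a minimal sketch of the pattern the change adopts: the worker thread creates, uses and closes its own DEALER socket, while only the zmq.Context is shared with the main thread (pyzmq contexts are thread-safe, sockets are not). The endpoint URL and identity string below are illustrative assumptions, not values from the Parsl code.

    import threading
    import zmq

    URL = "tcp://127.0.0.1:5560"   # assumed endpoint for this sketch

    def sender(ctx: zmq.Context, uid: str) -> None:
        # Socket is created, used and closed entirely within this thread.
        sock = ctx.socket(zmq.DEALER)
        sock.setsockopt(zmq.IDENTITY, uid.encode('utf-8'))
        sock.setsockopt(zmq.LINGER, 0)   # don't block exit on unsent messages
        sock.connect(URL)
        sock.send(b"hello")
        sock.close()

    ctx = zmq.Context()
    t = threading.Thread(target=sender, args=(ctx, "manager-0"))
    t.start()
    t.join()
    ctx.term()   # terminate the context only after all sockets are closed

Closing each socket in the thread that owns it, before the main thread calls term(), is what lets the commit drop the close() calls from start() without hanging on context termination.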