11import asyncio
22import logging
3- import multiprocessing
4- import os
5- import signal
63import time
74from asyncio import Task
85from typing import Iterator , Optional
96
7+ from dispatcher .process import ProcessManager , ProcessProxy
108from dispatcher .utils import DuplicateBehavior , MessageAction
11- from dispatcher .worker .task import work_loop
129
1310logger = logging .getLogger (__name__ )
1411
1512
1613class PoolWorker :
17- def __init__ (self , worker_id : int , finished_queue : multiprocessing . Queue ) :
14+ def __init__ (self , worker_id : int , process : ProcessProxy ) -> None :
1815 self .worker_id = worker_id
19- # TODO: rename message_queue to call_queue, because this is what cpython ProcessPoolExecutor calls them
20- self .message_queue : multiprocessing .Queue = multiprocessing .Queue ()
21- self .process = multiprocessing .Process (target = work_loop , args = (self .worker_id , self .message_queue , finished_queue ))
22-
23- # Info specific to the current task being ran
16+ self .process = process
2417 self .current_task : Optional [dict ] = None
2518 self .started_at : Optional [int ] = None
2619 self .is_active_cancel : bool = False
@@ -38,15 +31,15 @@ async def start(self) -> None:
3831
3932 async def start_task (self , message : dict ) -> None :
4033 self .current_task = message # NOTE: this marks this worker as busy
41- self .message_queue .put (message )
34+ self .process . message_queue .put (message )
4235 self .started_at = time .monotonic_ns ()
4336
4437 async def join (self ) -> None :
4538 logger .debug (f'Joining worker { self .worker_id } pid={ self .process .pid } subprocess' )
4639 self .process .join ()
4740
4841 async def stop (self ) -> None :
49- self .message_queue .put ("stop" )
42+ self .process . message_queue .put ("stop" )
5043 if self .current_task :
5144 uuid = self .current_task .get ('uuid' , '<unknown>' )
5245 logger .warning (f'Worker { self .worker_id } is currently running task (uuid={ uuid } ), canceling for shutdown' )
@@ -105,7 +98,7 @@ def __init__(self, num_workers: int, fd_lock: Optional[asyncio.Lock] = None):
10598 self .num_workers = num_workers
10699 self .workers : dict [int , PoolWorker ] = {}
107100 self .next_worker_id = 0
108- self .finished_queue : multiprocessing . Queue = multiprocessing . Queue ()
101+ self .process_manager = ProcessManager ()
109102 self .queued_messages : list [dict ] = [] # TODO: use deque, invent new kinds of logging anxiety
110103 self .read_results_task : Optional [Task ] = None
111104 self .start_worker_task : Optional [Task ] = None
@@ -193,7 +186,8 @@ async def manage_timeout(self) -> None:
193186 self .events .timeout_event .clear ()
194187
195188 async def up (self ) -> None :
196- worker = PoolWorker (worker_id = self .next_worker_id , finished_queue = self .finished_queue )
189+ process = self .process_manager .create_process ((self .next_worker_id ,))
190+ worker = PoolWorker (self .next_worker_id , process )
197191 self .workers [self .next_worker_id ] = worker
198192 self .next_worker_id += 1
199193
@@ -205,7 +199,7 @@ async def force_shutdown(self) -> None:
205199 for worker in self .workers .values ():
206200 if worker .process .pid and worker .process .is_alive ():
207201 logger .warning (f'Force killing worker { worker .worker_id } pid={ worker .process .pid } ' )
208- os . kill ( worker .process .pid , signal . SIGKILL )
202+ worker .process .kill ( )
209203
210204 if self .read_results_task :
211205 self .read_results_task .cancel ()
@@ -220,7 +214,7 @@ async def shutdown(self) -> None:
220214 self .events .management_event .set ()
221215 self .events .timeout_event .set ()
222216 await self .stop_workers ()
223- self .finished_queue .put ('stop' )
217+ self .process_manager . finished_queue .put ('stop' )
224218
225219 if self .read_results_task :
226220 logger .info ('Waiting for the finished watcher to return' )
@@ -382,10 +376,9 @@ async def process_finished(self, worker, message) -> None:
382376
383377 async def read_results_forever (self ) -> None :
384378 """Perpetual task that continuously waits for task completions."""
385- loop = asyncio .get_event_loop ()
386379 while True :
387380 # Wait for a result from the finished queue
388- message = await loop . run_in_executor ( None , self .finished_queue . get )
381+ message = await self .process_manager . read_finished ( )
389382
390383 if message == 'stop' :
391384 if self .shutting_down :
@@ -396,7 +389,7 @@ async def read_results_forever(self) -> None:
396389 logger .error ('Results queue got stop message even through not shutting down' )
397390 continue
398391
399- worker_id = message ["worker" ]
392+ worker_id = int ( message ["worker" ])
400393 event = message ["event" ]
401394 worker = self .workers [worker_id ]
402395
0 commit comments