33import multiprocessing
44import os
55import signal
6+ import time
67from asyncio import Task
7- from types import SimpleNamespace
88from typing import Iterator , Optional
99
1010from dispatcher .utils import DuplicateBehavior , MessageAction
@@ -19,18 +19,28 @@ def __init__(self, worker_id: int, finished_queue: multiprocessing.Queue):
1919 # TODO: rename message_queue to call_queue, because this is what cpython ProcessPoolExecutor calls them
2020 self .message_queue : multiprocessing .Queue = multiprocessing .Queue ()
2121 self .process = multiprocessing .Process (target = work_loop , args = (self .worker_id , self .message_queue , finished_queue ))
22+
23+ # Info specific to the current task being run
2224 self .current_task : Optional [dict ] = None
25+ self .started_at : Optional [int ] = None
26+ self .is_active_cancel : bool = False
27+
28+ # Tracking information for worker
2329 self .finished_count = 0
2430 self .status = 'initialized'
2531 self .exit_msg_event = asyncio .Event ()
26- self .active_cancel = False
2732
async def start(self) -> None:
    """Spawn the worker subprocess.

    The worker is not considered ready until it sends its callback message;
    until then the status stays at 'starting'.
    """
    self.status = 'spawned'
    # Launches work_loop (set as the Process target in __init__) in a child process.
    self.process.start()
    logger.debug(f'Worker {self.worker_id} pid={self.process.pid} subprocess has spawned')
    self.status = 'starting'  # Not ready until it sends callback message
39+ async def start_task (self , message : dict ) -> None :
40+ self .current_task = message # NOTE: this marks this worker as busy
41+ self .message_queue .put (message )
42+ self .started_at = time .monotonic_ns ()
43+
async def join(self) -> None:
    """Wait for the worker subprocess to terminate."""
    logger.debug(f'Joining worker {self.worker_id} pid={self.process.pid} subprocess')
    # NOTE(review): Process.join blocks the event loop thread — presumably
    # acceptable at shutdown time; confirm against callers.
    self.process.join()
@@ -65,12 +75,13 @@ async def stop(self) -> None:
6575 return
6676
6777 def cancel (self ) -> None :
68- self .active_cancel = True # signal for result callback
78+ self .is_active_cancel = True # signal for result callback
6979 self .process .terminate () # SIGTERM
7080
7181 def mark_finished_task (self ) -> None :
72- self .active_cancel = False
82+ self .is_active_cancel = False
7383 self .current_task = None
84+ self .started_at = None
7485 self .finished_count += 1
7586
7687 @property
@@ -79,6 +90,16 @@ def inactive(self) -> bool:
7990 return self .status in ['exited' , 'error' , 'initialized' ]
8091
8192
class PoolEvents:
    """Event flags shared by the pool's coordination tasks.

    Benchmark tests have to re-create this because they use same object in
    different event loops.
    """

    def __init__(self) -> None:
        # Fired when the message queue length drops to zero.
        self.queue_cleared = asyncio.Event()
        # Fired when totally quiet: no blocked or queued messages, no busy workers.
        self.work_cleared = asyncio.Event()
        # Wakes the worker-management task (process spawning is backgrounded).
        self.management_event = asyncio.Event()
        # Wakes the timeout watcher for anything that might affect deadlines.
        self.timeout_event = asyncio.Event()
101+
102+
82103class WorkerPool :
83104 def __init__ (self , num_workers : int , fd_lock : Optional [asyncio .Lock ] = None ):
84105 self .num_workers = num_workers
@@ -97,7 +118,7 @@ def __init__(self, num_workers: int, fd_lock: Optional[asyncio.Lock] = None):
97118 self .management_lock = asyncio .Lock ()
98119 self .fd_lock = fd_lock or asyncio .Lock ()
99120
100- self .events = self . _create_events ()
121+ self .events : PoolEvents = PoolEvents ()
101122
102123 @property
103124 def processed_count (self ):
@@ -107,19 +128,13 @@ def processed_count(self):
def received_count(self):
    """Total task messages seen: processed plus queued plus currently running."""
    busy_ct = sum(1 for w in self.workers.values() if w.current_task)
    return self.processed_count + len(self.queued_messages) + busy_ct
109130
110- def _create_events (self ):
111- "Benchmark tests have to re-create this because they use same object in different event loops"
112- return SimpleNamespace (
113- queue_cleared = asyncio .Event (), # queue is now 0 length
114- work_cleared = asyncio .Event (), # Totally quiet, no blocked or queued messages, no busy workers
115- management_event = asyncio .Event (), # Process spawning is backgrounded, so this is the kicker
116- )
117-
async def start_working(self, dispatcher) -> None:
    """Launch the pool's long-running coordination tasks.

    :param dispatcher: object providing fatal_error_callback, invoked if any
        coordination task exits.
    """
    self.read_results_task = asyncio.create_task(self.read_results_forever(), name='results_task')
    self.management_task = asyncio.create_task(self.manage_workers(), name='management_task')
    self.timeout_task = asyncio.create_task(self.manage_timeout(), name='timeout_task')
    # A crash in any coordination task is fatal for the dispatcher.
    for task in (self.read_results_task, self.management_task, self.timeout_task):
        task.add_done_callback(dispatcher.fatal_error_callback)
123138
124139 async def manage_workers (self ) -> None :
125140 """Enforces worker policy like min and max workers, and later, auto scale-down"""
@@ -140,6 +155,43 @@ async def manage_workers(self) -> None:
140155 self .events .management_event .clear ()
141156 logger .debug ('Pool worker management task exiting' )
142157
158+ async def process_worker_timeouts (self , current_time : float ) -> Optional [int ]:
159+ """
160+ Cancels tasks that have exceeded their timeout.
161+ Returns the system clock time of the next task timeout, for rescheduling.
162+ """
163+ next_deadline = None
164+ for worker in self .workers .values ():
165+ if (not worker .is_active_cancel ) and worker .current_task and worker .started_at and (worker .current_task .get ('timeout' )):
166+ timeout : float = worker .current_task ['timeout' ]
167+ worker_deadline = worker .started_at + int (timeout * 1.0e9 )
168+
169+ # Established that worker is running a task that has a timeout
170+ if worker_deadline < current_time :
171+ uuid : str = worker .current_task .get ('uuid' , '<unknown>' )
172+ delta : float = (current_time - worker .started_at ) * 1.0e9
173+ logger .info (f'Worker { worker .worker_id } runtime { delta :.5f} (s) for task uuid={ uuid } exceeded timeout { timeout } (s), canceling' )
174+ worker .cancel ()
175+ elif next_deadline is None or worker_deadline < next_deadline :
176+ # worker timeout is closer than any yet seen
177+ next_deadline = worker_deadline
178+
179+ return next_deadline
180+
181+ async def manage_timeout (self ) -> None :
182+ while not self .shutting_down :
183+ current_time = time .monotonic_ns ()
184+ pool_deadline = await self .process_worker_timeouts (current_time )
185+ if pool_deadline :
186+ time_until_deadline = (pool_deadline - current_time ) * 1.0e-9
187+ try :
188+ await asyncio .wait_for (self .events .timeout_event .wait (), timeout = time_until_deadline )
189+ except asyncio .TimeoutError :
190+ pass # will handle in next loop run
191+ else :
192+ await self .events .timeout_event .wait ()
193+ self .events .timeout_event .clear ()
194+
143195 async def up (self ) -> None :
144196 worker = PoolWorker (worker_id = self .next_worker_id , finished_queue = self .finished_queue )
145197 self .workers [self .next_worker_id ] = worker
@@ -166,6 +218,7 @@ async def force_shutdown(self) -> None:
async def shutdown(self) -> None:
    """Flag shutdown, wake the background tasks, and stop the workers."""
    self.shutting_down = True
    # Wake both watcher loops so they can observe shutting_down and exit.
    for event in (self.events.management_event, self.events.timeout_event):
        event.set()
    await self.stop_workers()
    # Sentinel so the blocking read on the finished queue returns.
    self.finished_queue.put('stop')
171224
@@ -277,8 +330,9 @@ async def dispatch_task(self, message: dict) -> None:
277330
278331 if worker := self .get_free_worker ():
279332 logger .debug (f"Dispatching task (uuid={ uuid } ) to worker (id={ worker .worker_id } )" )
280- worker .current_task = message # NOTE: this marks the worker as busy
281- worker .message_queue .put (message )
333+ await worker .start_task (message )
334+ if 'timeout' in message :
335+ self .events .timeout_event .set () # kick timeout task to set wakeup
282336 else :
283337 logger .warning (f'Queueing task (uuid={ uuid } ), ran out of workers, queued_ct={ len (self .queued_messages )} ' )
284338 self .queued_messages .append (message )
@@ -302,7 +356,7 @@ async def process_finished(self, worker, message) -> None:
302356 result = None
303357 if message .get ("result" ):
304358 result = message ["result" ]
305- if worker .active_cancel :
359+ if worker .is_active_cancel :
306360 msg += ', expected cancel'
307361 if result == '<cancel>' :
308362 msg += ', canceled'
@@ -312,7 +366,7 @@ async def process_finished(self, worker, message) -> None:
312366
313367 # Mark the worker as no longer busy
314368 async with self .management_lock :
315- if worker .active_cancel and result == '<cancel>' :
369+ if worker .is_active_cancel and result == '<cancel>' :
316370 self .canceled_count += 1
317371 elif 'control' in worker .current_task :
318372 self .control_count += 1
@@ -323,6 +377,9 @@ async def process_finished(self, worker, message) -> None:
323377 if not self .queued_messages and all (worker .current_task is None for worker in self .workers .values ()):
324378 self .events .work_cleared .set ()
325379
380+ if 'timeout' in message :
381+ self .events .timeout_event .set ()
382+
326383 async def read_results_forever (self ) -> None :
327384 """Perpetual task that continuously waits for task completions."""
328385 loop = asyncio .get_event_loop ()
0 commit comments