Commit 0d7535a

Add task timeout (#52)
* Add timeout task param
* Add nice log for when timeouts happen
* Add tests for task timeout
* Implement review comments, better events data structure
1 parent 84f7a97 commit 0d7535a

11 files changed: +200, -36 lines

dispatcher/main.py

Lines changed: 8 additions & 5 deletions

```diff
@@ -71,6 +71,13 @@ async def alive(self, dispatcher, **data):
         return
 
 
+class DispatcherEvents:
+    "Benchmark tests have to re-create this because they use same object in different event loops"
+
+    def __init__(self) -> None:
+        self.exit_event: asyncio.Event = asyncio.Event()
+
+
 class DispatcherMain:
     def __init__(self, config: dict):
         self.delayed_messages: list[SimpleNamespace] = []
@@ -97,11 +104,7 @@ def __init__(self, config: dict):
         if 'scheduled' in producer_config:
             self.producers.append(ScheduledProducer(producer_config['scheduled']))
 
-        self.events = self._create_events()
-
-    def _create_events(self):
-        "Benchmark tests have to re-create this because they use same object in different event loops"
-        return SimpleNamespace(exit_event=asyncio.Event())
+        self.events: DispatcherEvents = DispatcherEvents()
 
     def fatal_error_callback(self, *args) -> None:
         """Method to connect to error callbacks of other tasks, will kick out of main loop"""
```

dispatcher/pool.py

Lines changed: 74 additions & 17 deletions

```diff
@@ -3,8 +3,8 @@
 import multiprocessing
 import os
 import signal
+import time
 from asyncio import Task
-from types import SimpleNamespace
 from typing import Iterator, Optional
 
 from dispatcher.utils import DuplicateBehavior, MessageAction
@@ -19,18 +19,28 @@ def __init__(self, worker_id: int, finished_queue: multiprocessing.Queue):
         # TODO: rename message_queue to call_queue, because this is what cpython ProcessPoolExecutor calls them
         self.message_queue: multiprocessing.Queue = multiprocessing.Queue()
         self.process = multiprocessing.Process(target=work_loop, args=(self.worker_id, self.message_queue, finished_queue))
+
+        # Info specific to the current task being run
         self.current_task: Optional[dict] = None
+        self.started_at: Optional[int] = None
+        self.is_active_cancel: bool = False
+
+        # Tracking information for worker
         self.finished_count = 0
         self.status = 'initialized'
         self.exit_msg_event = asyncio.Event()
-        self.active_cancel = False
 
     async def start(self) -> None:
         self.status = 'spawned'
         self.process.start()
         logger.debug(f'Worker {self.worker_id} pid={self.process.pid} subprocess has spawned')
         self.status = 'starting'  # Not ready until it sends callback message
 
+    async def start_task(self, message: dict) -> None:
+        self.current_task = message  # NOTE: this marks this worker as busy
+        self.message_queue.put(message)
+        self.started_at = time.monotonic_ns()
+
     async def join(self) -> None:
         logger.debug(f'Joining worker {self.worker_id} pid={self.process.pid} subprocess')
         self.process.join()
```
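`started_at` is recorded with `time.monotonic_ns()` rather than `time.time()`: deadlines are computed in integer nanoseconds on a clock that never jumps with wall-clock adjustments. A standalone sketch of the arithmetic the pool performs with this value (illustrative only, not code from the commit):

```python
import time

TIMEOUT_S = 10.0                                 # the task message's 'timeout' value
started_at = time.monotonic_ns()                 # what start_task() records
deadline = started_at + int(TIMEOUT_S * 1.0e9)   # ns, as in process_worker_timeouts()

now = time.monotonic_ns()
elapsed_s = (now - started_at) * 1.0e-9          # convert back to seconds for logging
print(f'elapsed {elapsed_s:.5f}(s), timed out: {now > deadline}')
```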
```diff
@@ -65,12 +75,13 @@ async def stop(self) -> None:
         return
 
     def cancel(self) -> None:
-        self.active_cancel = True  # signal for result callback
+        self.is_active_cancel = True  # signal for result callback
         self.process.terminate()  # SIGTERM
 
     def mark_finished_task(self) -> None:
-        self.active_cancel = False
+        self.is_active_cancel = False
         self.current_task = None
+        self.started_at = None
         self.finished_count += 1
 
     @property
@@ -79,6 +90,16 @@ def inactive(self) -> bool:
         return self.status in ['exited', 'error', 'initialized']
 
 
+class PoolEvents:
+    "Benchmark tests have to re-create this because they use same object in different event loops"
+
+    def __init__(self) -> None:
+        self.queue_cleared: asyncio.Event = asyncio.Event()  # queue is now 0 length
+        self.work_cleared: asyncio.Event = asyncio.Event()  # Totally quiet, no blocked or queued messages, no busy workers
+        self.management_event: asyncio.Event = asyncio.Event()  # Process spawning is backgrounded, so this is the kicker
+        self.timeout_event: asyncio.Event = asyncio.Event()  # Anything that might affect the timeout watcher task
+
+
 class WorkerPool:
     def __init__(self, num_workers: int, fd_lock: Optional[asyncio.Lock] = None):
         self.num_workers = num_workers
@@ -97,7 +118,7 @@ def __init__(self, num_workers: int, fd_lock: Optional[asyncio.Lock] = None):
         self.management_lock = asyncio.Lock()
         self.fd_lock = fd_lock or asyncio.Lock()
 
-        self.events = self._create_events()
+        self.events: PoolEvents = PoolEvents()
 
     @property
     def processed_count(self):
@@ -107,19 +128,13 @@ def processed_count(self):
     def received_count(self):
         return self.processed_count + len(self.queued_messages) + sum(1 for w in self.workers.values() if w.current_task)
 
-    def _create_events(self):
-        "Benchmark tests have to re-create this because they use same object in different event loops"
-        return SimpleNamespace(
-            queue_cleared=asyncio.Event(),  # queue is now 0 length
-            work_cleared=asyncio.Event(),  # Totally quiet, no blocked or queued messages, no busy workers
-            management_event=asyncio.Event(),  # Process spawning is backgrounded, so this is the kicker
-        )
-
     async def start_working(self, dispatcher) -> None:
         self.read_results_task = asyncio.create_task(self.read_results_forever(), name='results_task')
         self.read_results_task.add_done_callback(dispatcher.fatal_error_callback)
         self.management_task = asyncio.create_task(self.manage_workers(), name='management_task')
         self.management_task.add_done_callback(dispatcher.fatal_error_callback)
+        self.timeout_task = asyncio.create_task(self.manage_timeout(), name='timeout_task')
+        self.timeout_task.add_done_callback(dispatcher.fatal_error_callback)
 
     async def manage_workers(self) -> None:
         """Enforces worker policy like min and max workers, and later, auto scale-down"""
```
```diff
@@ -140,6 +155,43 @@ async def manage_workers(self) -> None:
             self.events.management_event.clear()
         logger.debug('Pool worker management task exiting')
 
+    async def process_worker_timeouts(self, current_time: float) -> Optional[int]:
+        """
+        Cancels tasks that have exceeded their timeout.
+        Returns the monotonic clock time (in ns) of the next task timeout, for rescheduling.
+        """
+        next_deadline = None
+        for worker in self.workers.values():
+            if (not worker.is_active_cancel) and worker.current_task and worker.started_at and (worker.current_task.get('timeout')):
+                timeout: float = worker.current_task['timeout']
+                worker_deadline = worker.started_at + int(timeout * 1.0e9)
+
+                # Established that worker is running a task that has a timeout
+                if worker_deadline < current_time:
+                    uuid: str = worker.current_task.get('uuid', '<unknown>')
+                    delta: float = (current_time - worker.started_at) * 1.0e-9
+                    logger.info(f'Worker {worker.worker_id} runtime {delta:.5f}(s) for task uuid={uuid} exceeded timeout {timeout}(s), canceling')
+                    worker.cancel()
+                elif next_deadline is None or worker_deadline < next_deadline:
+                    # worker timeout is closer than any yet seen
+                    next_deadline = worker_deadline
+
+        return next_deadline
+
+    async def manage_timeout(self) -> None:
+        while not self.shutting_down:
+            current_time = time.monotonic_ns()
+            pool_deadline = await self.process_worker_timeouts(current_time)
+            if pool_deadline:
+                time_until_deadline = (pool_deadline - current_time) * 1.0e-9
+                try:
+                    await asyncio.wait_for(self.events.timeout_event.wait(), timeout=time_until_deadline)
+                except asyncio.TimeoutError:
+                    pass  # will handle in next loop run
+            else:
+                await self.events.timeout_event.wait()
+            self.events.timeout_event.clear()
+
     async def up(self) -> None:
         worker = PoolWorker(worker_id=self.next_worker_id, finished_queue=self.finished_queue)
         self.workers[self.next_worker_id] = worker
```
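`manage_timeout()` is a long-lived watcher: it sleeps until the earliest known deadline, while `timeout_event` lets other coroutines wake it early whenever the set of deadlines changes. The hunks below set that event from `dispatch_task()` (a timed task just started), from `process_finished()` (a timed task just ended), and from `shutdown()` (so the loop can observe `shutting_down`). A self-contained sketch of this wait-or-be-kicked pattern (illustrative names, not code from the commit):

```python
import asyncio

kick = asyncio.Event()  # stands in for PoolEvents.timeout_event

async def watcher(seconds_until_deadline: float) -> None:
    while True:
        try:
            # Sleep until the next deadline, unless kicked awake first
            await asyncio.wait_for(kick.wait(), timeout=seconds_until_deadline)
        except asyncio.TimeoutError:
            print('deadline reached: cancel overdue workers here')
            return
        kick.clear()  # kicked: recompute deadlines and wait again
        print('kicked: deadline set changed')

async def main() -> None:
    watch = asyncio.create_task(watcher(0.05))
    await asyncio.sleep(0.01)
    kick.set()  # e.g. dispatch_task() just started a task with a timeout
    await watch

asyncio.run(main())
```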
```diff
@@ -166,6 +218,7 @@ async def force_shutdown(self) -> None:
     async def shutdown(self) -> None:
         self.shutting_down = True
         self.events.management_event.set()
+        self.events.timeout_event.set()
         await self.stop_workers()
         self.finished_queue.put('stop')
 
@@ -277,8 +330,9 @@ async def dispatch_task(self, message: dict) -> None:
 
         if worker := self.get_free_worker():
             logger.debug(f"Dispatching task (uuid={uuid}) to worker (id={worker.worker_id})")
-            worker.current_task = message  # NOTE: this marks the worker as busy
-            worker.message_queue.put(message)
+            await worker.start_task(message)
+            if 'timeout' in message:
+                self.events.timeout_event.set()  # kick timeout task to set wakeup
         else:
             logger.warning(f'Queueing task (uuid={uuid}), ran out of workers, queued_ct={len(self.queued_messages)}')
             self.queued_messages.append(message)
@@ -302,7 +356,7 @@ async def process_finished(self, worker, message) -> None:
         result = None
         if message.get("result"):
             result = message["result"]
-            if worker.active_cancel:
+            if worker.is_active_cancel:
                msg += ', expected cancel'
                if result == '<cancel>':
                    msg += ', canceled'
@@ -312,7 +366,7 @@ async def process_finished(self, worker, message) -> None:
 
         # Mark the worker as no longer busy
         async with self.management_lock:
-            if worker.active_cancel and result == '<cancel>':
+            if worker.is_active_cancel and result == '<cancel>':
                 self.canceled_count += 1
             elif 'control' in worker.current_task:
                 self.control_count += 1
@@ -323,6 +377,9 @@ async def process_finished(self, worker, message) -> None:
             if not self.queued_messages and all(worker.current_task is None for worker in self.workers.values()):
                 self.events.work_cleared.set()
 
+        if 'timeout' in message:
+            self.events.timeout_event.set()
+
     async def read_results_forever(self) -> None:
         """Perpetual task that continuously waits for task completions."""
         loop = asyncio.get_event_loop()
```

dispatcher/producers/base.py

Lines changed: 6 additions & 3 deletions

```diff
@@ -1,7 +1,10 @@
 import asyncio
-from types import SimpleNamespace
+
+
+class ProducerEvents:
+    def __init__(self):
+        self.ready_event = asyncio.Event()
 
 
 class BaseProducer:
-    def _create_events(self):
-        return SimpleNamespace(ready_event=asyncio.Event())
+    pass
```

dispatcher/producers/brokered.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -3,14 +3,14 @@
 from typing import Optional
 
 from dispatcher.brokers.pg_notify import aget_connection, aprocess_notify, apublish_message
-from dispatcher.producers.base import BaseProducer
+from dispatcher.producers.base import BaseProducer, ProducerEvents
 
 logger = logging.getLogger(__name__)
 
 
 class BrokeredProducer(BaseProducer):
     def __init__(self, broker: str = 'pg_notify', config: Optional[dict] = None, channels: tuple = (), connection=None) -> None:
-        self.events = self._create_events()
+        self.events = ProducerEvents()
         self.production_task: Optional[asyncio.Task] = None
         self.broker = broker
         self.config = config
```

dispatcher/producers/scheduled.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,14 +1,14 @@
 import asyncio
 import logging
 
-from dispatcher.producers.base import BaseProducer
+from dispatcher.producers.base import BaseProducer, ProducerEvents
 
 logger = logging.getLogger(__name__)
 
 
 class ScheduledProducer(BaseProducer):
     def __init__(self, task_schedule: dict):
-        self.events = self._create_events()
+        self.events = ProducerEvents()
         self.task_schedule = task_schedule
         self.scheduled_tasks: list[asyncio.Task] = []
         self.produced_count = 0
```

dispatcher/publish.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -9,15 +9,18 @@
 
 
 class DispatcherDecorator:
-    def __init__(self, registry: DispatcherMethodRegistry, *, queue: Optional[str] = None, on_duplicate: Optional[str] = None) -> None:
+    def __init__(
+        self, registry: DispatcherMethodRegistry, *, queue: Optional[str] = None, on_duplicate: Optional[str] = None, timeout: Optional[float] = None
+    ) -> None:
         self.registry = registry
         self.queue = queue
         self.on_duplicate = on_duplicate
+        self.timeout = timeout
 
     def __call__(self, fn: DispatcherCallable, /) -> DispatcherCallable:
         "Concrete task decorator, registers method and glues on some methods from the registry"
 
-        dmethod = self.registry.register(fn, queue=self.queue, on_duplicate=self.on_duplicate)
+        dmethod = self.registry.register(fn, queue=self.queue, on_duplicate=self.on_duplicate, timeout=self.timeout)
 
         setattr(fn, 'apply_async', dmethod.apply_async)
         setattr(fn, 'delay', dmethod.delay)
@@ -29,6 +32,7 @@ def task(
     *,
     queue: Optional[str] = None,
     on_duplicate: Optional[str] = None,
+    timeout: Optional[float] = None,
     registry: DispatcherMethodRegistry = default_registry,
 ) -> DispatcherDecorator:
     """
@@ -68,4 +72,4 @@ def announce():
     # The on_duplicate kwarg controls behavior when multiple instances of the task are running
     # options are documented in dispatcher.utils.DuplicateBehavior
     """
-    return DispatcherDecorator(registry, queue=queue, on_duplicate=on_duplicate)
+    return DispatcherDecorator(registry, queue=queue, on_duplicate=on_duplicate, timeout=timeout)
```
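With the parameter threaded through the decorator, registry, and message body, a task can declare its timeout at registration time. A hypothetical usage sketch (the import path follows this repo's layout; the queue name and task body are invented):

```python
import time

from dispatcher.publish import task

@task(queue='test_channel', timeout=2.0)
def sleep_function(seconds: float = 5.0) -> None:
    time.sleep(seconds)

# sleep_function.delay() publishes a message that includes timeout=2.0; the pool's
# timeout watcher cancels the task if it is still running 2 seconds after it starts.
```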

dispatcher/registry.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -43,15 +43,18 @@ def get_callable(self) -> Callable:
         return self.fn
 
     def publication_defaults(self) -> dict:
-        defaults = self.submission_defaults.copy()
+        defaults = {}
+        for k, v in self.submission_defaults.items():
+            if v:  # all None or falsy values have no effect
+                defaults[k] = v
         defaults['task'] = self.serialize_task()
         defaults['time_pub'] = time.time()
         return defaults
 
     def delay(self, *args, **kwargs) -> Tuple[dict, str]:
         return self.apply_async(args, kwargs)
 
-    def get_async_body(self, args=None, kwargs=None, uuid=None, on_duplicate: Optional[str] = None, delay: float = 0.0) -> dict:
+    def get_async_body(self, args=None, kwargs=None, uuid=None, on_duplicate: Optional[str] = None, timeout: Optional[float] = 0.0, delay: float = 0.0) -> dict:
         """
         Get the python dict to become JSON data in the pg_notify message
         This same message gets passed over the dispatcher IPC queue to workers
@@ -67,6 +70,8 @@ def get_async_body(self, args=None, kwargs=None, uuid=None, on_duplicate: Option
             body['on_duplicate'] = on_duplicate
         if delay:
             body['delay'] = delay
+        if timeout:
+            body['timeout'] = timeout
 
         return body
```
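The reworked `publication_defaults()` drops falsy defaults, so a decorator-level `timeout=None` never lands in the published message, and `get_async_body()` applies the same truthiness rule, meaning `timeout=0` is equivalent to no timeout. A one-liner equivalent of the new filtering loop (standalone sketch, with invented values):

```python
# Equivalent of the loop in publication_defaults(): falsy defaults are dropped
submission_defaults = {'queue': None, 'on_duplicate': None, 'timeout': 2.0}
defaults = {k: v for k, v in submission_defaults.items() if v}
assert defaults == {'timeout': 2.0}
```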
