diff --git a/docling_jobkit/orchestrators/rq/orchestrator.py b/docling_jobkit/orchestrators/rq/orchestrator.py index a5d0ba1..bb30314 100644 --- a/docling_jobkit/orchestrators/rq/orchestrator.py +++ b/docling_jobkit/orchestrators/rq/orchestrator.py @@ -1,5 +1,6 @@ import asyncio import base64 +import datetime import logging import uuid import warnings @@ -12,6 +13,7 @@ from pydantic import BaseModel from rq import Queue from rq.job import Job, JobStatus +from rq.registry import StartedJobRegistry from docling.datamodel.base_models import DocumentStream @@ -34,6 +36,7 @@ class RQOrchestratorConfig(BaseModel): results_ttl: int = 3_600 * 4 failure_ttl: int = 3_600 * 4 results_prefix: str = "docling:results" + heartbeat_key_prefix: str = "docling:job:alive" sub_channel: str = "docling:updates" scratch_dir: Optional[Path] = None redis_max_connections: int = 50 @@ -48,6 +51,14 @@ class _TaskUpdate(BaseModel): error_message: Optional[str] = None +_HEARTBEAT_TTL = 60 # seconds before an unrefreshed key expires +_HEARTBEAT_INTERVAL = 20 # seconds between heartbeat writes +_WATCHDOG_INTERVAL = 30.0 # seconds between watchdog scans +_WATCHDOG_GRACE_PERIOD = ( + 90.0 # don't flag tasks started less than this many seconds ago +) + + class RQOrchestrator(BaseOrchestrator): @staticmethod def make_rq_queue(config: RQOrchestratorConfig) -> tuple[redis.Redis, Queue]: @@ -201,15 +212,26 @@ async def task_result( self, task_id: str, ) -> Optional[DoclingTaskResult]: - if task_id not in self._task_result_keys: - return None - result_key = self._task_result_keys[task_id] + result_key = self._task_result_keys.get( + task_id, f"{self.config.results_prefix}:{task_id}" + ) packed = await self._async_redis_conn.get(result_key) + if packed is None: + return None + self._task_result_keys[task_id] = result_key result = DoclingTaskResult.model_validate( msgpack.unpackb(packed, raw=False, strict_map_key=False) ) return result + async def _on_task_status_changed(self, task: Task) -> None: + """Called after every in-memory status update from pub/sub. + + No-op by default. Subclasses should override to persist the updated + task to durable storage so that terminal states survive pod restarts + and are visible to pods that missed the pub/sub event. + """ + async def _listen_for_updates(self): pubsub = self._async_redis_conn.pubsub() @@ -243,6 +265,8 @@ async def _listen_for_updates(self): ): self._task_result_keys[data.task_id] = data.result_key + await self._on_task_status_changed(task) + if self.notifier: try: await self.notifier.notify_task_subscribers(task.task_id) @@ -253,14 +277,117 @@ async def _listen_for_updates(self): except TaskNotFoundError: _log.warning(f"Task {data.task_id} not found.") + async def _watchdog_task(self) -> None: + """Detect orphaned STARTED tasks whose worker heartbeat key has expired. + + Runs every _WATCHDOG_INTERVAL seconds. For each task in STARTED state + that is older than _WATCHDOG_GRACE_PERIOD, checks whether the worker's + liveness key ({heartbeat_key_prefix}:{task_id}) still exists in Redis. If + the key is absent the worker process has died — publishes a FAILURE update to + the pub/sub channel so that polling clients and WebSocket subscribers are + notified within ~90 seconds of the kill instead of waiting 4 hours. + + Note: the grace period is tracked by the watchdog itself (first_seen_started) + rather than task.started_at. Orchestrator subclasses (e.g. RedisTaskStatusMixin) + may recreate Task objects on every poll, resetting started_at to the current + time and making a task.started_at-based check unreliable. + """ + _log.info("Watchdog starting") + # Maps task_id -> time the watchdog first observed the task in STARTED state. + # Independent of task.started_at which may be reset by polling machinery. + first_seen_started: dict[str, datetime.datetime] = {} + while True: + await asyncio.sleep(_WATCHDOG_INTERVAL) + try: + now = datetime.datetime.now(datetime.timezone.utc) + + all_statuses = { + tid: t.task_status for tid, t in list(self.tasks.items()) + } + _log.debug( + f"Watchdog scan: {len(self.tasks)} tasks in memory, " + f"statuses={all_statuses}" + ) + + # Determine which tasks are currently in STARTED state by + # querying StartedJobRegistry directly — the authoritative, + # cross-pod, durable source that is independent of request + # routing and pod lifecycle. + registry = StartedJobRegistry( + queue=self._rq_queue, connection=self._redis_conn + ) + # cleanup=True so RQ can remove abandoned STARTED entries while + # the watchdog still provides fast heartbeat-based failure + # signaling for SimpleWorker pods. + rq_started_ids = await asyncio.to_thread(registry.get_job_ids) + currently_started = set(rq_started_ids) + + # Remove tasks that are no longer STARTED (completed, failed, gone). + for task_id in list(first_seen_started.keys()): + if task_id not in currently_started: + _log.debug( + f"Watchdog: task {task_id} left STARTED, removing from tracking" + ) + del first_seen_started[task_id] + + # Record first observation time for newly STARTED tasks. + for task_id in currently_started: + if task_id not in first_seen_started: + _log.debug( + f"Watchdog: first observation of STARTED task {task_id}" + ) + first_seen_started[task_id] = now + + # Check tasks that have been STARTED long enough to be past grace period. + candidates = [ + task_id + for task_id, first_seen in list(first_seen_started.items()) + if (now - first_seen).total_seconds() > _WATCHDOG_GRACE_PERIOD + ] + + _log.debug( + f"Watchdog: {len(currently_started)} started, " + f"{len(first_seen_started)} tracked, " + f"{len(candidates)} past grace period" + ) + + for task_id in candidates: + key = f"{self.config.heartbeat_key_prefix}:{task_id}" + alive = await self._async_redis_conn.exists(key) + age = (now - first_seen_started[task_id]).total_seconds() + _log.debug( + f"Watchdog: checking task {task_id} " + f"(age={age:.0f}s), heartbeat key alive={bool(alive)}" + ) + if not alive: + _log.warning( + f"Task {task_id} heartbeat key missing — " + f"worker likely dead, publishing FAILURE" + ) + await self._async_redis_conn.publish( + self.config.sub_channel, + _TaskUpdate( + task_id=task_id, + task_status=TaskStatus.FAILURE, + error_message=( + "Worker heartbeat expired: worker pod " + "likely killed mid-job" + ), + ).model_dump_json(), + ) + # Remove from tracking so we don't re-publish if pub/sub is slow. + del first_seen_started[task_id] + except Exception: + _log.exception("Watchdog error") + async def process_queue(self): # Create a pool of workers - pubsub_worker = [] _log.debug("PubSub worker starting.") - pubsub_worker = asyncio.create_task(self._listen_for_updates()) + pubsub_task = asyncio.create_task(self._listen_for_updates()) + watchdog_task = asyncio.create_task(self._watchdog_task()) - # Wait for all worker to complete - await asyncio.gather(pubsub_worker) + # Wait for all workers to complete + await asyncio.gather(pubsub_task, watchdog_task) _log.debug("PubSub worker completed.") async def delete_task(self, task_id: str): diff --git a/docling_jobkit/orchestrators/rq/worker.py b/docling_jobkit/orchestrators/rq/worker.py index 4011dae..dad4563 100644 --- a/docling_jobkit/orchestrators/rq/worker.py +++ b/docling_jobkit/orchestrators/rq/worker.py @@ -1,10 +1,12 @@ import logging import shutil import tempfile +import threading from pathlib import Path from typing import Any, Optional, Union import msgpack +import redis as sync_redis from rq import SimpleWorker, get_current_job from docling.datamodel.base_models import DocumentStream @@ -20,6 +22,8 @@ from docling_jobkit.datamodel.task import Task from docling_jobkit.datamodel.task_meta import TaskStatus, TaskType from docling_jobkit.orchestrators.rq.orchestrator import ( + _HEARTBEAT_INTERVAL, + _HEARTBEAT_TTL, RQOrchestrator, RQOrchestratorConfig, _TaskUpdate, @@ -75,7 +79,44 @@ def __init__( # Call parent class constructor super().__init__(*args, **kwargs) + def _heartbeat_loop(self, job_id: str, stop_event: threading.Event) -> None: + """Write a liveness key to Redis every _HEARTBEAT_INTERVAL seconds. + + Runs in a daemon thread for the duration of a single job. SimpleWorker + blocks its main thread during job execution, so the RQ-level heartbeat + is not maintained. This thread provides the equivalent liveness signal + that the standard forking Worker gets for free from its parent process. + + The key expires after _HEARTBEAT_TTL seconds without a refresh. If the + worker process is killed the thread dies with it and the key expires + naturally, allowing the orchestrator watchdog to detect the dead job. + """ + key = f"{self.orchestrator_config.heartbeat_key_prefix}:{job_id}" + conn = None + try: + conn = sync_redis.Redis.from_url(self.orchestrator_config.redis_url) + # Write immediately so the key exists before the first watchdog scan. + conn.set(key, "1", ex=_HEARTBEAT_TTL) + while not stop_event.wait(timeout=_HEARTBEAT_INTERVAL): + conn.set(key, "1", ex=_HEARTBEAT_TTL) + except Exception as e: + _log.error(f"Heartbeat thread error for {job_id}: {e}") + finally: + if conn is not None: + try: + conn.close() + except Exception: + pass + def perform_job(self, job, queue): + stop_event = threading.Event() + heartbeat_thread = threading.Thread( + target=self._heartbeat_loop, + args=(job.id, stop_event), + daemon=True, + name=f"heartbeat-{job.id}", + ) + heartbeat_thread.start() try: # Add to job's kwargs conversion manager if hasattr(job, "kwargs"): @@ -88,6 +129,9 @@ def perform_job(self, job, queue): # Custom error handling for individual jobs self.log.error(f"Job {job.id} failed: {e}") raise + finally: + stop_event.set() + heartbeat_thread.join(timeout=5) def docling_task( @@ -133,6 +177,8 @@ def docling_task( raise RuntimeError("No converter") if not task.convert_options: raise RuntimeError("No conversion options") + + # TODO: potentially move the below code into thread wrapper and wait for it. conv_results = conversion_manager.convert_documents( sources=convert_sources, options=task.convert_options,