Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 134 additions & 7 deletions docling_jobkit/orchestrators/rq/orchestrator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import base64
import datetime
import logging
import uuid
import warnings
Expand All @@ -12,6 +13,7 @@
from pydantic import BaseModel
from rq import Queue
from rq.job import Job, JobStatus
from rq.registry import StartedJobRegistry

from docling.datamodel.base_models import DocumentStream

Expand All @@ -34,6 +36,7 @@ class RQOrchestratorConfig(BaseModel):
results_ttl: int = 3_600 * 4
failure_ttl: int = 3_600 * 4
results_prefix: str = "docling:results"
heartbeat_key_prefix: str = "docling:job:alive"
sub_channel: str = "docling:updates"
scratch_dir: Optional[Path] = None
redis_max_connections: int = 50
Expand All @@ -48,6 +51,14 @@ class _TaskUpdate(BaseModel):
error_message: Optional[str] = None


# Liveness-protocol tuning. The worker refreshes its heartbeat key every
# _HEARTBEAT_INTERVAL seconds with a TTL of _HEARTBEAT_TTL, so up to two
# consecutive refreshes may be missed before the key expires and the
# watchdog declares the worker dead.
_HEARTBEAT_TTL = 60  # seconds before an unrefreshed key expires
_HEARTBEAT_INTERVAL = 20  # seconds between heartbeat writes
_WATCHDOG_INTERVAL = 30.0  # seconds between watchdog scans
_WATCHDOG_GRACE_PERIOD = (
    90.0  # don't flag tasks started less than this many seconds ago
)


class RQOrchestrator(BaseOrchestrator):
@staticmethod
def make_rq_queue(config: RQOrchestratorConfig) -> tuple[redis.Redis, Queue]:
Expand Down Expand Up @@ -201,15 +212,26 @@ async def task_result(
self,
task_id: str,
) -> Optional[DoclingTaskResult]:
if task_id not in self._task_result_keys:
return None
result_key = self._task_result_keys[task_id]
result_key = self._task_result_keys.get(
task_id, f"{self.config.results_prefix}:{task_id}"
)
packed = await self._async_redis_conn.get(result_key)
if packed is None:
return None
self._task_result_keys[task_id] = result_key
result = DoclingTaskResult.model_validate(
msgpack.unpackb(packed, raw=False, strict_map_key=False)
)
return result

async def _on_task_status_changed(self, task: Task) -> None:
    """Hook invoked after each in-memory status update received via pub/sub.

    The base implementation does nothing. Subclasses are expected to
    override this to write the updated task to durable storage, so that
    terminal states both survive pod restarts and become visible to pods
    that never saw the pub/sub event.
    """
    return None
async def _listen_for_updates(self):
pubsub = self._async_redis_conn.pubsub()

Expand Down Expand Up @@ -243,6 +265,8 @@ async def _listen_for_updates(self):
):
self._task_result_keys[data.task_id] = data.result_key

await self._on_task_status_changed(task)

if self.notifier:
try:
await self.notifier.notify_task_subscribers(task.task_id)
Expand All @@ -253,14 +277,117 @@ async def _listen_for_updates(self):
except TaskNotFoundError:
_log.warning(f"Task {data.task_id} not found.")

async def _watchdog_task(self) -> None:
    """Detect orphaned STARTED tasks whose worker heartbeat key has expired.

    Runs forever, scanning every _WATCHDOG_INTERVAL seconds. For each task
    in STARTED state that has been tracked for longer than
    _WATCHDOG_GRACE_PERIOD, checks whether the worker's liveness key
    ({heartbeat_key_prefix}:{task_id}) still exists in Redis. If the key is
    absent the worker process has died — publishes a FAILURE update to the
    pub/sub channel so that polling clients and WebSocket subscribers are
    notified within roughly the grace period instead of waiting for the
    multi-hour result TTL.

    Note: the grace period is tracked by the watchdog itself
    (first_seen_started) rather than task.started_at. Orchestrator
    subclasses (e.g. RedisTaskStatusMixin) may recreate Task objects on
    every poll, resetting started_at to the current time and making a
    task.started_at-based check unreliable.

    Exceptions raised inside a scan are logged and swallowed so that one
    failed scan never kills the watchdog loop.
    """
    _log.info("Watchdog starting")
    # Maps task_id -> time the watchdog first observed the task in STARTED
    # state. Independent of task.started_at, which may be reset by polling
    # machinery (see docstring).
    first_seen_started: dict[str, datetime.datetime] = {}
    while True:
        # Sleep first: nothing useful can be observed immediately at startup,
        # and this also rate-limits the loop after an exception below.
        await asyncio.sleep(_WATCHDOG_INTERVAL)
        try:
            # Timezone-aware "now" so age arithmetic is unambiguous.
            now = datetime.datetime.now(datetime.timezone.utc)

            # Snapshot in-memory statuses purely for debug logging.
            all_statuses = {
                tid: t.task_status for tid, t in list(self.tasks.items())
            }
            _log.debug(
                f"Watchdog scan: {len(self.tasks)} tasks in memory, "
                f"statuses={all_statuses}"
            )

            # Determine which tasks are currently in STARTED state by
            # querying StartedJobRegistry directly — the authoritative,
            # cross-pod, durable source that is independent of request
            # routing and pod lifecycle.
            registry = StartedJobRegistry(
                queue=self._rq_queue, connection=self._redis_conn
            )
            # NOTE(review): no cleanup argument is passed here; recent RQ
            # versions clean up expired entries inside get_job_ids itself —
            # confirm against the pinned RQ version. The watchdog still
            # provides fast heartbeat-based failure signaling for
            # SimpleWorker pods either way.
            # get_job_ids is a blocking (sync) Redis call, so run it off the
            # event loop.
            rq_started_ids = await asyncio.to_thread(registry.get_job_ids)
            currently_started = set(rq_started_ids)

            # Remove tasks that are no longer STARTED (completed, failed,
            # gone) so the tracking dict cannot grow without bound.
            for task_id in list(first_seen_started.keys()):
                if task_id not in currently_started:
                    _log.debug(
                        f"Watchdog: task {task_id} left STARTED, removing from tracking"
                    )
                    del first_seen_started[task_id]

            # Record first observation time for newly STARTED tasks.
            for task_id in currently_started:
                if task_id not in first_seen_started:
                    _log.debug(
                        f"Watchdog: first observation of STARTED task {task_id}"
                    )
                    first_seen_started[task_id] = now

            # Check only tasks that have been STARTED long enough to be past
            # the grace period; freshly started jobs may not have written
            # their first heartbeat yet.
            candidates = [
                task_id
                for task_id, first_seen in list(first_seen_started.items())
                if (now - first_seen).total_seconds() > _WATCHDOG_GRACE_PERIOD
            ]

            _log.debug(
                f"Watchdog: {len(currently_started)} started, "
                f"{len(first_seen_started)} tracked, "
                f"{len(candidates)} past grace period"
            )

            for task_id in candidates:
                key = f"{self.config.heartbeat_key_prefix}:{task_id}"
                # EXISTS returns an int count; 0 means the key has expired.
                alive = await self._async_redis_conn.exists(key)
                age = (now - first_seen_started[task_id]).total_seconds()
                _log.debug(
                    f"Watchdog: checking task {task_id} "
                    f"(age={age:.0f}s), heartbeat key alive={bool(alive)}"
                )
                if not alive:
                    _log.warning(
                        f"Task {task_id} heartbeat key missing — "
                        f"worker likely dead, publishing FAILURE"
                    )
                    # Broadcast the failure on the shared channel; every
                    # orchestrator pod's pub/sub listener will pick it up.
                    await self._async_redis_conn.publish(
                        self.config.sub_channel,
                        _TaskUpdate(
                            task_id=task_id,
                            task_status=TaskStatus.FAILURE,
                            error_message=(
                                "Worker heartbeat expired: worker pod "
                                "likely killed mid-job"
                            ),
                        ).model_dump_json(),
                    )
                    # Remove from tracking so we don't re-publish if pub/sub is slow.
                    del first_seen_started[task_id]
        except Exception:
            # Keep the watchdog alive across transient Redis errors etc.
            _log.exception("Watchdog error")

async def process_queue(self):
    """Run the orchestrator's background workers until they exit.

    Spawns the pub/sub update listener and the heartbeat watchdog as
    concurrent asyncio tasks and waits for both. Neither coroutine returns
    under normal operation, so this effectively runs for the lifetime of
    the orchestrator.
    """
    # Fixed: removed the dead `pubsub_worker = []` assignment and the
    # duplicate create_task line left over from a previous revision; the
    # stale "Create a pool of workers" comment described code that no
    # longer exists.
    _log.debug("PubSub worker starting.")
    pubsub_task = asyncio.create_task(self._listen_for_updates())
    watchdog_task = asyncio.create_task(self._watchdog_task())

    # Wait for all workers to complete.
    await asyncio.gather(pubsub_task, watchdog_task)
    _log.debug("PubSub worker completed.")

async def delete_task(self, task_id: str):
Expand Down
46 changes: 46 additions & 0 deletions docling_jobkit/orchestrators/rq/worker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
import shutil
import tempfile
import threading
from pathlib import Path
from typing import Any, Optional, Union

import msgpack
import redis as sync_redis
from rq import SimpleWorker, get_current_job

from docling.datamodel.base_models import DocumentStream
Expand All @@ -20,6 +22,8 @@
from docling_jobkit.datamodel.task import Task
from docling_jobkit.datamodel.task_meta import TaskStatus, TaskType
from docling_jobkit.orchestrators.rq.orchestrator import (
_HEARTBEAT_INTERVAL,
_HEARTBEAT_TTL,
RQOrchestrator,
RQOrchestratorConfig,
_TaskUpdate,
Expand Down Expand Up @@ -75,7 +79,44 @@ def __init__(
# Call parent class constructor
super().__init__(*args, **kwargs)

def _heartbeat_loop(self, job_id: str, stop_event: threading.Event) -> None:
    """Periodically write a liveness key to Redis for the running job.

    Executed in a daemon thread while a single job runs. SimpleWorker
    blocks its main thread during job execution, so the RQ-level
    heartbeat is not maintained; this thread supplies the equivalent
    liveness signal that the standard forking Worker gets for free from
    its parent process.

    The key carries a TTL of _HEARTBEAT_TTL seconds and is refreshed
    every _HEARTBEAT_INTERVAL seconds. Should the worker process be
    killed, this thread dies with it, the key expires on its own, and
    the orchestrator watchdog detects the dead job.
    """
    liveness_key = f"{self.orchestrator_config.heartbeat_key_prefix}:{job_id}"
    client = None
    try:
        client = sync_redis.Redis.from_url(self.orchestrator_config.redis_url)

        def refresh() -> None:
            # SET with EX both creates and re-arms the expiring key.
            client.set(liveness_key, "1", ex=_HEARTBEAT_TTL)

        # Write immediately so the key exists before the first watchdog scan.
        refresh()
        # wait() returns False on timeout (keep refreshing) and True once
        # stop_event is set by perform_job's finally block.
        while not stop_event.wait(timeout=_HEARTBEAT_INTERVAL):
            refresh()
    except Exception as e:
        _log.error(f"Heartbeat thread error for {job_id}: {e}")
    finally:
        if client is not None:
            try:
                client.close()
            except Exception:
                # Best-effort cleanup; a failed close must not mask the job.
                pass

def perform_job(self, job, queue):
stop_event = threading.Event()
heartbeat_thread = threading.Thread(
target=self._heartbeat_loop,
args=(job.id, stop_event),
daemon=True,
name=f"heartbeat-{job.id}",
)
heartbeat_thread.start()
try:
# Add to job's kwargs conversion manager
if hasattr(job, "kwargs"):
Expand All @@ -88,6 +129,9 @@ def perform_job(self, job, queue):
# Custom error handling for individual jobs
self.log.error(f"Job {job.id} failed: {e}")
raise
finally:
stop_event.set()
heartbeat_thread.join(timeout=5)


def docling_task(
Expand Down Expand Up @@ -133,6 +177,8 @@ def docling_task(
raise RuntimeError("No converter")
if not task.convert_options:
raise RuntimeError("No conversion options")

# TODO: potentially move the below code into thread wrapper and wait for it.
conv_results = conversion_manager.convert_documents(
sources=convert_sources,
options=task.convert_options,
Expand Down